...
|
...
|
@@ -1,6 +1,7 @@
|
1
|
1
|
# -*- coding: utf-8 -*-
|
2
|
2
|
|
3
|
3
|
from adaptive.learner import Learner1D, BalancingLearner
|
|
4
|
+from adaptive.runner import simple
|
4
|
5
|
|
5
|
6
|
|
6
|
7
|
def test_balancing_learner_loss_cache():
|
...
|
...
|
@@ -24,9 +25,23 @@ def test_balancing_learner_loss_cache():
|
24
|
25
|
|
25
|
26
|
|
26
|
27
|
def test_distribute_first_points_over_learners():
|
27
|
|
- learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
|
28
|
|
- learner = BalancingLearner(learners)
|
29
|
|
- points, _ = learner.ask(100)
|
30
|
|
- i_learner, xs = zip(*points)
|
31
|
|
- # assert that are all learners in the suggested points
|
32
|
|
- assert len(set(i_learner)) == len(learners)
|
|
28
|
+ for strategy in ['loss', 'loss_improvements', 'npoints']:
|
|
29
|
+ learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
|
|
30
|
+ learner = BalancingLearner(learners, strategy=strategy)
|
|
31
|
+ points, _ = learner.ask(100)
|
|
32
|
+ i_learner, xs = zip(*points)
|
|
33
|
+ # assert that are all learners in the suggested points
|
|
34
|
+ assert len(set(i_learner)) == len(learners)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+def test_strategies():
|
|
38
|
+ goals = {
|
|
39
|
+ 'loss': lambda l: l.loss() < 0.1,
|
|
40
|
+ 'loss_improvements': lambda l: l.loss() < 0.1,
|
|
41
|
+ 'npoints': lambda bl: all(l.npoints > 10 for l in bl.learners)
|
|
42
|
+ }
|
|
43
|
+
|
|
44
|
+ for strategy, goal in goals.items():
|
|
45
|
+ learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
|
|
46
|
+ learner = BalancingLearner(learners, strategy=strategy)
|
|
47
|
+ simple(learner, goal=goal)
|