Browse code

test the different strategies

Bas Nijholt authored on 17/03/2019 11:18:17
Showing 1 changed files
... ...
@@ -1,6 +1,7 @@
1 1
 # -*- coding: utf-8 -*-
2 2
 
3 3
 from adaptive.learner import Learner1D, BalancingLearner
4
+from adaptive.runner import simple
4 5
 
5 6
 
6 7
 def test_balancing_learner_loss_cache():
... ...
@@ -24,9 +25,23 @@ def test_balancing_learner_loss_cache():
24 25
 
25 26
 
26 27
 def test_distribute_first_points_over_learners():
27
-    learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
28
-    learner = BalancingLearner(learners)
29
-    points, _ = learner.ask(100)
30
-    i_learner, xs = zip(*points)
31
-    # assert that are all learners in the suggested points
32
-    assert len(set(i_learner)) == len(learners)
28
+    for strategy in ['loss', 'loss_improvements', 'npoints']:
29
+        learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
30
+        learner = BalancingLearner(learners, strategy=strategy)
31
+        points, _ = learner.ask(100)
32
+        i_learner, xs = zip(*points)
33
+        # assert that are all learners in the suggested points
34
+        assert len(set(i_learner)) == len(learners)
35
+
36
+
37
+def test_strategies():
38
+    goals = {
39
+        'loss': lambda l: l.loss() < 0.1,
40
+        'loss_improvements': lambda l: l.loss() < 0.1,
41
+        'npoints': lambda bl: all(l.npoints > 10 for l in bl.learners)
42
+    }
43
+
44
+    for strategy, goal in goals.items():
45
+        learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
46
+        learner = BalancingLearner(learners, strategy=strategy)
47
+        simple(learner, goal=goal)