Browse code

use pytest.mark.parametrize

Bas Nijholt authored on 17/03/2019 17:31:04
Showing 1 changed files
... ...
@@ -1,5 +1,7 @@
1 1
 # -*- coding: utf-8 -*-
2 2
 
3
+import pytest
4
+
3 5
 from adaptive.learner import Learner1D, BalancingLearner
4 6
 from adaptive.runner import simple
5 7
 
... ...
@@ -24,32 +26,30 @@ def test_balancing_learner_loss_cache():
24 26
     assert bl.loss(real=True) == real_loss
25 27
 
26 28
 
27
-def test_distribute_first_points_over_learners():
28
-    for strategy in ['loss', 'loss_improvements', 'npoints']:
29
-        learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
30
-        learner = BalancingLearner(learners, strategy=strategy)
31
-        points, _ = learner.ask(100)
32
-        i_learner, xs = zip(*points)
33
-        # assert that are all learners in the suggested points
34
-        assert len(set(i_learner)) == len(learners)
35
-
36
-
37
-def test_ask_0():
38
-    for strategy in ['loss', 'loss_improvements', 'npoints']:
39
-        learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
40
-        learner = BalancingLearner(learners, strategy=strategy)
41
-        points, _ = learner.ask(0)
42
-        assert len(points) == 0
43
-
44
-
45
-def test_strategies():
46
-    goals = {
47
-        'loss': lambda l: l.loss() < 0.1,
48
-        'loss_improvements': lambda l: l.loss() < 0.1,
49
-        'npoints': lambda bl: all(l.npoints > 10 for l in bl.learners)
50
-    }
51
-
52
-    for strategy, goal in goals.items():
53
-        learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
54
-        learner = BalancingLearner(learners, strategy=strategy)
55
-        simple(learner, goal=goal)
29
+@pytest.mark.parametrize('strategy', ['loss', 'loss_improvements', 'npoints'])
30
+def test_distribute_first_points_over_learners(strategy):
31
+    learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
32
+    learner = BalancingLearner(learners, strategy=strategy)
33
+    points, _ = learner.ask(100)
34
+    i_learner, xs = zip(*points)
35
+    # assert that are all learners in the suggested points
36
+    assert len(set(i_learner)) == len(learners)
37
+
38
+
39
+@pytest.mark.parametrize('strategy', ['loss', 'loss_improvements', 'npoints'])
40
+def test_ask_0(strategy):
41
+    learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
42
+    learner = BalancingLearner(learners, strategy=strategy)
43
+    points, _ = learner.ask(0)
44
+    assert len(points) == 0
45
+
46
+
47
+@pytest.mark.parametrize('strategy, goal', [
48
+    ('loss', lambda l: l.loss() < 0.1),
49
+    ('loss_improvements', lambda l: l.loss() < 0.1),
50
+    ('npoints', lambda bl: all(l.npoints > 10 for l in bl.learners)),
51
+])
52
+def test_strategies(strategy, goal):
53
+    learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
54
+    learner = BalancingLearner(learners, strategy=strategy)
55
+    simple(learner, goal=goal)