Browse code

simplify test_average_learner.py::test_min_npoints further

Bas Nijholt authored on 14/05/2020 14:53:03
Showing 1 changed files
... ...
@@ -50,11 +50,12 @@ def test_avg_std_and_npoints():
50 50
 
51 51
 
52 52
 def test_min_npoints():
53
-    def f(seed):
54
-        if seed < 2:  # first two numbers are similar
55
-            return 0.1 + 1e-8 * random.random()
56
-        return random.random()
57
-
58
-    learner = AverageLearner(f, atol=0.01, rtol=0.01, min_npoints=3)
59
-    simple(learner, lambda l: l.loss() < 1)
60
-    assert learner.npoints > 2
53
+    def constant_function(seed):
54
+        return 0.1
55
+
56
+    for min_npoints in [1, 2, 3]:
57
+        learner = AverageLearner(
58
+            constant_function, atol=0.01, rtol=0.01, min_npoints=min_npoints
59
+        )
60
+        simple(learner, lambda l: l.loss() < 1)
61
+        assert learner.npoints >= max(2, min_npoints)