... | ... |
@@ -50,11 +50,12 @@ def test_avg_std_and_npoints(): |
50 | 50 |
|
51 | 51 |
|
52 | 52 |
def test_min_npoints(): |
53 |
- def f(seed): |
|
54 |
- if seed < 2: # first two numbers are similar |
|
55 |
- return 0.1 + 1e-8 * random.random() |
|
56 |
- return random.random() |
|
57 |
- |
|
58 |
- learner = AverageLearner(f, atol=0.01, rtol=0.01, min_npoints=3) |
|
59 |
- simple(learner, lambda l: l.loss() < 1) |
|
60 |
- assert learner.npoints > 2 |
|
53 |
+ def constant_function(seed): |
|
54 |
+ return 0.1 |
|
55 |
+ |
|
56 |
+ for min_npoints in [1, 2, 3]: |
|
57 |
+ learner = AverageLearner( |
|
58 |
+ constant_function, atol=0.01, rtol=0.01, min_npoints=min_npoints |
|
59 |
+ ) |
|
60 |
+ simple(learner, lambda l: l.loss() < 1) |
|
61 |
+ assert learner.npoints >= max(2, min_npoints) |