... | ... |
@@ -50,17 +50,11 @@ def test_avg_std_and_npoints(): |
50 | 50 |
|
51 | 51 |
|
52 | 52 |
def test_min_npoints(): |
53 |
- def f(npoints_similar: int): |
|
54 |
- def _f(seed): |
|
55 |
- if seed < npoints_similar: |
|
56 |
- return 0.1 + 1e-8 * random.random() |
|
57 |
- return random.random() |
|
58 |
- |
|
59 |
- return _f |
|
60 |
- |
|
61 |
- for npoints_similar in range(1, 5): |
|
62 |
- learner = AverageLearner( |
|
63 |
- f(npoints_similar), atol=0.01, rtol=0.01, min_npoints=npoints_similar + 1 |
|
64 |
- ) |
|
65 |
- simple(learner, lambda l: l.loss() < 1) |
|
66 |
- assert learner.npoints > npoints_similar |
|
53 |
+ def f(seed): |
|
54 |
+ if seed < 2: # first two numbers are similar |
|
55 |
+ return 0.1 + 1e-8 * random.random() |
|
56 |
+ return random.random() |
|
57 |
+ |
|
58 |
+ learner = AverageLearner(f, atol=0.01, rtol=0.01, min_npoints=3) |
|
59 |
+ simple(learner, lambda l: l.loss() < 1) |
|
60 |
+ assert learner.npoints > 2 |