AverageLearner: implement min_npoints
Bas Nijholt authored on 14/05/2020 20:17:38 • GitHub committed on 14/05/2020 20:17:38... | ... |
@@ -19,6 +19,8 @@ class AverageLearner(BaseLearner): |
19 | 19 |
Desired absolute tolerance. |
20 | 20 |
rtol : float |
21 | 21 |
Desired relative tolerance. |
22 |
+ min_npoints : int |
|
23 |
+ Minimum number of points to sample. |
|
22 | 24 |
|
23 | 25 |
Attributes |
24 | 26 |
---------- |
... | ... |
@@ -30,7 +32,7 @@ class AverageLearner(BaseLearner): |
30 | 32 |
Number of evaluated points. |
31 | 33 |
""" |
32 | 34 |
|
33 |
- def __init__(self, function, atol=None, rtol=None): |
|
35 |
+ def __init__(self, function, atol=None, rtol=None, min_npoints=2): |
|
34 | 36 |
if atol is None and rtol is None: |
35 | 37 |
raise Exception("At least one of `atol` and `rtol` should be set.") |
36 | 38 |
if atol is None: |
... | ... |
@@ -44,6 +46,8 @@ class AverageLearner(BaseLearner): |
44 | 46 |
self.atol = atol |
45 | 47 |
self.rtol = rtol |
46 | 48 |
self.npoints = 0 |
49 |
+ # Cannot estimate standard deviation with fewer than 2 points. |
|
50 |
+ self.min_npoints = max(min_npoints, 2) |
|
47 | 51 |
self.sum_f = 0 |
48 | 52 |
self.sum_f_sq = 0 |
49 | 53 |
|
... | ... |
@@ -92,7 +96,7 @@ class AverageLearner(BaseLearner): |
92 | 96 |
"""The corrected sample standard deviation of the values |
93 | 97 |
in `data`.""" |
94 | 98 |
n = self.npoints |
95 |
- if n < 2: |
|
99 |
+ if n < self.min_npoints: |
|
96 | 100 |
return np.inf |
97 | 101 |
numerator = self.sum_f_sq - n * self.mean ** 2 |
98 | 102 |
if numerator < 0: |
... | ... |
@@ -106,7 +110,7 @@ class AverageLearner(BaseLearner): |
106 | 110 |
n = self.npoints if real else self.n_requested |
107 | 111 |
else: |
108 | 112 |
n = n |
109 |
- if n < 2: |
|
113 |
+ if n < self.min_npoints: |
|
110 | 114 |
return np.inf |
111 | 115 |
standard_error = self.std / sqrt(n) |
112 | 116 |
return max( |
... | ... |
@@ -150,10 +154,11 @@ class AverageLearner(BaseLearner): |
150 | 154 |
self.function, |
151 | 155 |
self.atol, |
152 | 156 |
self.rtol, |
157 |
+ self.min_npoints, |
|
153 | 158 |
self._get_data(), |
154 | 159 |
) |
155 | 160 |
|
156 | 161 |
def __setstate__(self, state): |
157 |
- function, atol, rtol, data = state |
|
158 |
- self.__init__(function, atol, rtol) |
|
162 |
+ function, atol, rtol, min_npoints, data = state |
|
163 |
+ self.__init__(function, atol, rtol, min_npoints) |
|
159 | 164 |
self._set_data(data) |
... | ... |
@@ -4,6 +4,7 @@ import flaky |
4 | 4 |
import numpy as np |
5 | 5 |
|
6 | 6 |
from adaptive.learner import AverageLearner |
7 |
+from adaptive.runner import simple |
|
7 | 8 |
|
8 | 9 |
|
9 | 10 |
def test_only_returns_new_points(): |
... | ... |
@@ -46,3 +47,15 @@ def test_avg_std_and_npoints(): |
46 | 47 |
assert learner.npoints == len(learner.data) |
47 | 48 |
assert abs(learner.sum_f - values.sum()) < 1e-13 |
48 | 49 |
assert abs(learner.std - std) < 1e-13 |
50 |
+ |
|
51 |
+ |
|
52 |
+def test_min_npoints(): |
|
53 |
+ def constant_function(seed): |
|
54 |
+ return 0.1 |
|
55 |
+ |
|
56 |
+ for min_npoints in [1, 2, 3]: |
|
57 |
+ learner = AverageLearner( |
|
58 |
+ constant_function, atol=0.01, rtol=0.01, min_npoints=min_npoints |
|
59 |
+ ) |
|
60 |
+ simple(learner, lambda l: l.loss() < 1) |
|
61 |
+ assert learner.npoints >= max(2, min_npoints) |