Browse code

Merge pull request #274 from python-adaptive/feature/min_npoints

AverageLearner: implement min_npoints

Bas Nijholt authored on 14/05/2020 20:17:38 • GitHub committed on 14/05/2020 20:17:38
Showing 2 changed files
... ...
@@ -19,6 +19,8 @@ class AverageLearner(BaseLearner):
19 19
         Desired absolute tolerance.
20 20
     rtol : float
21 21
         Desired relative tolerance.
22
+    min_npoints : int
23
+        Minimum number of points to sample.
22 24
 
23 25
     Attributes
24 26
     ----------
... ...
@@ -30,7 +32,7 @@ class AverageLearner(BaseLearner):
30 32
         Number of evaluated points.
31 33
     """
32 34
 
33
-    def __init__(self, function, atol=None, rtol=None):
35
+    def __init__(self, function, atol=None, rtol=None, min_npoints=2):
34 36
         if atol is None and rtol is None:
35 37
             raise Exception("At least one of `atol` and `rtol` should be set.")
36 38
         if atol is None:
... ...
@@ -44,6 +46,8 @@ class AverageLearner(BaseLearner):
44 46
         self.atol = atol
45 47
         self.rtol = rtol
46 48
         self.npoints = 0
49
+        # Cannot estimate standard deviation with fewer than 2 points.
50
+        self.min_npoints = max(min_npoints, 2)
47 51
         self.sum_f = 0
48 52
         self.sum_f_sq = 0
49 53
 
... ...
@@ -92,7 +96,7 @@ class AverageLearner(BaseLearner):
92 96
         """The corrected sample standard deviation of the values
93 97
         in `data`."""
94 98
         n = self.npoints
95
-        if n < 2:
99
+        if n < self.min_npoints:
96 100
             return np.inf
97 101
         numerator = self.sum_f_sq - n * self.mean ** 2
98 102
         if numerator < 0:
... ...
@@ -106,7 +110,7 @@ class AverageLearner(BaseLearner):
106 110
             n = self.npoints if real else self.n_requested
107 111
         else:
108 112
             n = n
109
-        if n < 2:
113
+        if n < self.min_npoints:
110 114
             return np.inf
111 115
         standard_error = self.std / sqrt(n)
112 116
         return max(
... ...
@@ -150,10 +154,11 @@ class AverageLearner(BaseLearner):
150 154
             self.function,
151 155
             self.atol,
152 156
             self.rtol,
157
+            self.min_npoints,
153 158
             self._get_data(),
154 159
         )
155 160
 
156 161
     def __setstate__(self, state):
157
-        function, atol, rtol, data = state
158
-        self.__init__(function, atol, rtol)
162
+        function, atol, rtol, min_npoints, data = state
163
+        self.__init__(function, atol, rtol, min_npoints)
159 164
         self._set_data(data)
... ...
@@ -4,6 +4,7 @@ import flaky
4 4
 import numpy as np
5 5
 
6 6
 from adaptive.learner import AverageLearner
7
+from adaptive.runner import simple
7 8
 
8 9
 
9 10
 def test_only_returns_new_points():
... ...
@@ -46,3 +47,15 @@ def test_avg_std_and_npoints():
46 47
             assert learner.npoints == len(learner.data)
47 48
             assert abs(learner.sum_f - values.sum()) < 1e-13
48 49
             assert abs(learner.std - std) < 1e-13
50
+
51
+
52
+def test_min_npoints():
53
+    def constant_function(seed):
54
+        return 0.1
55
+
56
+    for min_npoints in [1, 2, 3]:
57
+        learner = AverageLearner(
58
+            constant_function, atol=0.01, rtol=0.01, min_npoints=min_npoints
59
+        )
60
+        simple(learner, lambda l: l.loss() < 1)
61
+        assert learner.npoints >= max(2, min_npoints)