Browse code

AverageLearner: implement min_npoints

Closes https://github.com/python-adaptive/adaptive/issues/273

Bas Nijholt authored on 14/05/2020 13:56:08
Showing 2 changed files
... ...
@@ -19,6 +19,8 @@ class AverageLearner(BaseLearner):
19 19
         Desired absolute tolerance.
20 20
     rtol : float
21 21
         Desired relative tolerance.
22
+    min_npoints : int
23
+        Minimum number of points required to estimate the standard deviation.
22 24
 
23 25
     Attributes
24 26
     ----------
... ...
@@ -30,7 +32,9 @@ class AverageLearner(BaseLearner):
30 32
         Number of evaluated points.
31 33
     """
32 34
 
33
-    def __init__(self, function, atol=None, rtol=None):
35
+    def __init__(self, function, atol=None, rtol=None, min_npoints=2):
36
+        if min_npoints < 2:
37
+            raise ValueError("`min_npoints` should be at least 2.")
34 38
         if atol is None and rtol is None:
35 39
             raise Exception("At least one of `atol` and `rtol` should be set.")
36 40
         if atol is None:
... ...
@@ -44,6 +48,7 @@ class AverageLearner(BaseLearner):
44 48
         self.atol = atol
45 49
         self.rtol = rtol
46 50
         self.npoints = 0
51
+        self.min_npoints = min_npoints
47 52
         self.sum_f = 0
48 53
         self.sum_f_sq = 0
49 54
 
... ...
@@ -92,7 +97,7 @@ class AverageLearner(BaseLearner):
92 97
         """The corrected sample standard deviation of the values
93 98
         in `data`."""
94 99
         n = self.npoints
95
-        if n < 2:
100
+        if n < self.min_npoints:
96 101
             return np.inf
97 102
         numerator = self.sum_f_sq - n * self.mean ** 2
98 103
         if numerator < 0:
... ...
@@ -106,7 +111,7 @@ class AverageLearner(BaseLearner):
106 111
             n = self.npoints if real else self.n_requested
107 112
         else:
108 113
             n = n
109
-        if n < 2:
114
+        if n < self.min_npoints:
110 115
             return np.inf
111 116
         standard_error = self.std / sqrt(n)
112 117
         return max(
... ...
@@ -4,6 +4,7 @@ import flaky
4 4
 import numpy as np
5 5
 
6 6
 from adaptive.learner import AverageLearner
7
+from adaptive.runner import simple
7 8
 
8 9
 
9 10
 def test_only_returns_new_points():
... ...
@@ -46,3 +47,20 @@ def test_avg_std_and_npoints():
46 47
             assert learner.npoints == len(learner.data)
47 48
             assert abs(learner.sum_f - values.sum()) < 1e-13
48 49
             assert abs(learner.std - std) < 1e-13
50
+
51
+
52
+def test_min_npoints():
53
+    def f(npoints_similar: int):
54
+        def _f(seed):
55
+            if seed < npoints_similar:
56
+                return 0.1 + 1e-8 * random.random()
57
+            return random.random()
58
+
59
+        return _f
60
+
61
+    for npoints_similar in range(1, 5):
62
+        learner = AverageLearner(
63
+            f(npoints_similar), atol=0.01, rtol=0.01, min_npoints=npoints_similar + 1
64
+        )
65
+        simple(learner, lambda l: l.loss() < 1)
66
+        assert learner.npoints > npoints_similar