Browse code

remove _choose_points, closes #19

Bas Nijholt authored on 07/09/2017 10:49:15
Showing 1 changed files
... ...
@@ -72,6 +72,7 @@ class BaseLearner(metaclass=abc.ABCMeta):
72 72
             (possibly by interpolation).
73 73
         """
74 74
 
75
+    @abc.abstractmethod
75 76
     def choose_points(self, n, add_data=True):
76 77
         """Choose the next 'n' points to evaluate.
77 78
 
... ...
@@ -85,22 +86,7 @@ class BaseLearner(metaclass=abc.ABCMeta):
85 86
             values. Set this to False if you do not
86 87
             want to modify the state of the learner.
87 88
         """
88
-        points, loss_improvements = self._choose_points(n)
89
-        if add_data:
90
-            self.add_data(points, itertools.repeat(None))
91
-        return points, loss_improvements
92
-
93
-    @abc.abstractmethod
94
-    def _choose_points(self, n):
95
-        """Choose the next 'n' points to evaluate.
96
-
97
-        Should be overridden by subclasses.
98
-
99
-        Parameters
100
-        ----------
101
-        n : int
102
-            The number of points to choose.
103
-        """
89
+        pass
104 90
 
105 91
     def __getstate__(self):
106 92
         return copy(self.__dict__)
... ...
@@ -139,9 +125,11 @@ class AverageLearner(BaseLearner):
139 125
         self.sum_f = 0
140 126
         self.sum_f_sq = 0
141 127
 
142
-    def _choose_points(self, n=10):
128
+    def choose_points(self, n=10, add_data=True):
143 129
         points = list(range(self.n_requested, self.n_requested + n))
144 130
         loss_improvements = [None] * n
131
+        if add_data:
132
+            self.add_data(points, itertools.repeat(None))
145 133
         return points, loss_improvements
146 134
 
147 135
     def add_point(self, n, value):
... ...
@@ -324,7 +312,7 @@ class Learner1D(BaseLearner):
324 312
                     self._oldscale = self._scale
325 313
 
326 314
 
327
-    def _choose_points(self, n=10):
315
+    def choose_points(self, n=10, add_data=True):
328 316
         """Return n points that are expected to maximally reduce the loss."""
329 317
         # Find out how to divide the n points over the intervals
330 318
         # by finding  positive integer n_i that minimize max(L_i / n_i) subject
... ...
@@ -373,7 +361,10 @@ class Learner1D(BaseLearner):
373 361
                                  itertools.repeat(-quality, n)
374 362
                                  for quality, x, n in quals))
375 363
 
376
-        return (xs, loss_improvements)
364
+        if add_data:
365
+            self.add_data(points, itertools.repeat(None))
366
+
367
+        return xs, loss_improvements
377 368
 
378 369
     def interpolate(self, extra_points=None):
379 370
         xs = list(self.data.keys())
... ...
@@ -462,9 +453,6 @@ class BalancingLearner(BaseLearner):
462 453
         else:
463 454
             return self._choose_and_add_points(n)
464 455
 
465
-    def _choose_points(self, n):
466
-        pass
467
-
468 456
     def add_point(self, x, y):
469 457
         index, x = x
470 458
         self.learners[index].add_point(x, y)
... ...
@@ -793,9 +781,6 @@ class Learner2D(BaseLearner):
793 781
             else:
794 782
                 dev[jsimplex] = 0
795 783
 
796
-    def _choose_points(self, n):
797
-        pass
798
-
799 784
     def _choose_and_add_points(self, n):
800 785
         if n <= len(self._stack):
801 786
             points = self._stack[:n]