... | ... |
@@ -72,6 +72,7 @@ class BaseLearner(metaclass=abc.ABCMeta): |
72 | 72 |
(possibly by interpolation). |
73 | 73 |
""" |
74 | 74 |
|
75 |
+ @abc.abstractmethod |
|
75 | 76 |
def choose_points(self, n, add_data=True): |
76 | 77 |
"""Choose the next 'n' points to evaluate. |
77 | 78 |
|
... | ... |
@@ -85,22 +86,7 @@ class BaseLearner(metaclass=abc.ABCMeta): |
85 | 86 |
values. Set this to False if you do not |
86 | 87 |
want to modify the state of the learner. |
87 | 88 |
""" |
88 |
- points, loss_improvements = self._choose_points(n) |
|
89 |
- if add_data: |
|
90 |
- self.add_data(points, itertools.repeat(None)) |
|
91 |
- return points, loss_improvements |
|
92 |
- |
|
93 |
- @abc.abstractmethod |
|
94 |
- def _choose_points(self, n): |
|
95 |
- """Choose the next 'n' points to evaluate. |
|
96 |
- |
|
97 |
- Should be overridden by subclasses. |
|
98 |
- |
|
99 |
- Parameters |
|
100 |
- ---------- |
|
101 |
- n : int |
|
102 |
- The number of points to choose. |
|
103 |
- """ |
|
89 |
+ pass |
|
104 | 90 |
|
105 | 91 |
def __getstate__(self): |
106 | 92 |
return copy(self.__dict__) |
... | ... |
@@ -139,9 +125,11 @@ class AverageLearner(BaseLearner): |
139 | 125 |
self.sum_f = 0 |
140 | 126 |
self.sum_f_sq = 0 |
141 | 127 |
|
142 |
- def _choose_points(self, n=10): |
|
128 |
+ def choose_points(self, n=10, add_data=True): |
|
143 | 129 |
points = list(range(self.n_requested, self.n_requested + n)) |
144 | 130 |
loss_improvements = [None] * n |
131 |
+ if add_data: |
|
132 |
+ self.add_data(points, itertools.repeat(None)) |
|
145 | 133 |
return points, loss_improvements |
146 | 134 |
|
147 | 135 |
def add_point(self, n, value): |
... | ... |
@@ -324,7 +312,7 @@ class Learner1D(BaseLearner): |
324 | 312 |
self._oldscale = self._scale |
325 | 313 |
|
326 | 314 |
|
327 |
- def _choose_points(self, n=10): |
|
315 |
+ def choose_points(self, n=10, add_data=True): |
|
328 | 316 |
"""Return n points that are expected to maximally reduce the loss.""" |
329 | 317 |
# Find out how to divide the n points over the intervals |
330 | 318 |
# by finding positive integer n_i that minimize max(L_i / n_i) subject |
... | ... |
@@ -373,7 +361,10 @@ class Learner1D(BaseLearner): |
373 | 361 |
itertools.repeat(-quality, n) |
374 | 362 |
for quality, x, n in quals)) |
375 | 363 |
|
376 |
- return (xs, loss_improvements) |
|
364 |
+ if add_data: |
|
365 |
+ self.add_data(points, itertools.repeat(None)) |
|
366 |
+ |
|
367 |
+ return xs, loss_improvements |
|
377 | 368 |
|
378 | 369 |
def interpolate(self, extra_points=None): |
379 | 370 |
xs = list(self.data.keys()) |
... | ... |
@@ -462,9 +453,6 @@ class BalancingLearner(BaseLearner): |
462 | 453 |
else: |
463 | 454 |
return self._choose_and_add_points(n) |
464 | 455 |
|
465 |
- def _choose_points(self, n): |
|
466 |
- pass |
|
467 |
- |
|
468 | 456 |
def add_point(self, x, y): |
469 | 457 |
index, x = x |
470 | 458 |
self.learners[index].add_point(x, y) |
... | ... |
@@ -793,9 +781,6 @@ class Learner2D(BaseLearner): |
793 | 781 |
else: |
794 | 782 |
dev[jsimplex] = 0 |
795 | 783 |
|
796 |
- def _choose_points(self, n): |
|
797 |
- pass |
|
798 |
- |
|
799 | 784 |
def _choose_and_add_points(self, n): |
800 | 785 |
if n <= len(self._stack): |
801 | 786 |
points = self._stack[:n] |