...
|
...
|
@@ -269,19 +269,16 @@ class Learner2D(BaseLearner):
|
269
|
269
|
|
270
|
270
|
return points_new, losses_new
|
271
|
271
|
|
272
|
|
- def _split_stack(self, n=None):
|
273
|
|
- if self._stack:
|
274
|
|
- points, loss_improvements = zip(*reversed(self._stack.items()))
|
275
|
|
- return list(points[:n]), list(loss_improvements[:n])
|
276
|
|
- else:
|
277
|
|
- return [], []
|
278
|
272
|
|
279
|
273
|
def choose_points(self, n, add_data=True):
|
280
|
|
- n_left = n
|
281
|
|
- points, loss_improvements = self._split_stack(n_left)
|
282
|
|
- n_left -= len(points)
|
283
|
|
- # Even if add_data is False we add the point such that
|
284
|
|
- # _fill_stack will return new points, later we remove these points.
|
|
274
|
+ # Even if add_data is False we add the point such that _fill_stack
|
|
275
|
+ # will return new points, later we remove these points if needed.
|
|
276
|
+ points, loss_improvements = [], []
|
|
277
|
+ for i, (point, loss_improvement) in enumerate(self._stack.items()):
|
|
278
|
+ if i < n:
|
|
279
|
+ points.append(point)
|
|
280
|
+ loss_improvements.append(loss_improvement)
|
|
281
|
+ n_left = n - len(points)
|
285
|
282
|
self.add_data(points, itertools.repeat(None))
|
286
|
283
|
|
287
|
284
|
while n_left > 0:
|