Browse code

2D: remove _split_stack

Bas Nijholt authored on 27/11/2017 17:16:46
Showing 1 changed files
... ...
@@ -269,19 +269,16 @@ class Learner2D(BaseLearner):
269 269
 
270 270
         return points_new, losses_new
271 271
 
272
-    def _split_stack(self, n=None):
273
-        if self._stack:
274
-            points, loss_improvements = zip(*reversed(self._stack.items()))
275
-            return list(points[:n]), list(loss_improvements[:n])
276
-        else:
277
-            return [], []
278 272
 
279 273
     def choose_points(self, n, add_data=True):
280
-        n_left = n
281
-        points, loss_improvements = self._split_stack(n_left)
282
-        n_left -= len(points)
283
-        # Even if add_data is False we add the point such that
284
-        # _fill_stack will return new points, later we remove these points.
274
+        # Even if add_data is False we add the point such that _fill_stack
275
+        # will return new points, later we remove these points if needed.
276
+        points, loss_improvements = [], []
277
+        for i, (point, loss_improvement) in enumerate(self._stack.items()):
278
+            if i < n:
279
+                points.append(point)
280
+                loss_improvements.append(loss_improvement)
281
+        n_left = n - len(points)
285 282
         self.add_data(points, itertools.repeat(None))
286 283
 
287 284
         while n_left > 0: