Browse code

2D: simplify choose_points

Bas Nijholt authored on 24/11/2017 14:34:44
Showing 1 changed files
... ...
@@ -260,40 +260,51 @@ class Learner2D(BaseLearner):
260 260
 
261 261
         losses = self.loss_per_triangle(ip)
262 262
 
263
+        points_new = []
264
+        losses_new = []
263 265
         for j, _ in enumerate(losses):
264 266
             jsimplex = np.argmax(losses)
265 267
             triangle = ip.tri.points[ip.tri.vertices[jsimplex]]
266 268
             point_new = choose_point_in_triangle(triangle, max_badness=5)
267 269
             point_new = tuple(self.unscale(point_new))
270
+            loss_new = losses[jsimplex]
268 271
 
269
-            self._stack[point_new] = losses[jsimplex]
272
+            points_new.append(point_new)
273
+            losses_new.append(loss_new)
274
+
275
+            self._stack[point_new] = loss_new
270 276
 
271 277
             if len(self._stack) >= stack_till:
272 278
                 break
273 279
             else:
274 280
                 losses[jsimplex] = -np.inf
275 281
 
282
+        return points_new, losses_new
283
+
276 284
     def _split_stack(self, n=None):
277
-        points, loss_improvements = zip(*reversed(self._stack.items()))
278
-        return points[:n], loss_improvements[:n]
285
+        if self._stack:
286
+            points, loss_improvements = zip(*reversed(self._stack.items()))
287
+            return list(points[:n]), list(loss_improvements[:n])
288
+        else:
289
+            return [], []
279 290
 
280 291
     def _choose_and_add_points(self, n):
281
-        points = []
282
-        loss_improvements = []
283 292
         n_left = n
293
+        points, loss_improvements = self._split_stack(n_left)
294
+        self.add_data(points, itertools.repeat(None))
295
+        n_left -= len(points)
296
+
284 297
         while n_left > 0:
285 298
             # The while loop is needed because `stack_till` could be larger
286 299
             # than the number of triangles between the points. Therefore
287 300
             # it could fill up till a length smaller than `stack_till`.
288
-            if not any(p in self._stack for p in self._bounds_points):
289
-                self._fill_stack(stack_till=max(n_left, 10))
290
-            new_points, new_loss_improvements = self._split_stack(n_left)
301
+            new_points, new_loss_improvements = self._fill_stack(stack_till=max(n_left, 10))
291 302
             points += new_points
292 303
             loss_improvements += new_loss_improvements
293
-            self.add_data(new_points, itertools.repeat(None))
294 304
             n_left -= len(new_points)
305
+            self.add_data(new_points, itertools.repeat(None))
295 306
 
296
-        return points, loss_improvements
307
+        return points[:n], loss_improvements[:n]
297 308
 
298 309
     def choose_points(self, n, add_data=True):
299 310
         if not add_data: