... | ... |
@@ -260,40 +260,51 @@ class Learner2D(BaseLearner): |
260 | 260 |
|
261 | 261 |
losses = self.loss_per_triangle(ip) |
262 | 262 |
|
263 |
+ points_new = [] |
|
264 |
+ losses_new = [] |
|
263 | 265 |
for j, _ in enumerate(losses): |
264 | 266 |
jsimplex = np.argmax(losses) |
265 | 267 |
triangle = ip.tri.points[ip.tri.vertices[jsimplex]] |
266 | 268 |
point_new = choose_point_in_triangle(triangle, max_badness=5) |
267 | 269 |
point_new = tuple(self.unscale(point_new)) |
270 |
+ loss_new = losses[jsimplex] |
|
268 | 271 |
|
269 |
- self._stack[point_new] = losses[jsimplex] |
|
272 |
+ points_new.append(point_new) |
|
273 |
+ losses_new.append(loss_new) |
|
274 |
+ |
|
275 |
+ self._stack[point_new] = loss_new |
|
270 | 276 |
|
271 | 277 |
if len(self._stack) >= stack_till: |
272 | 278 |
break |
273 | 279 |
else: |
274 | 280 |
losses[jsimplex] = -np.inf |
275 | 281 |
|
282 |
+ return points_new, losses_new |
|
283 |
+ |
|
276 | 284 |
def _split_stack(self, n=None): |
277 |
- points, loss_improvements = zip(*reversed(self._stack.items())) |
|
278 |
- return points[:n], loss_improvements[:n] |
|
285 |
+ if self._stack: |
|
286 |
+ points, loss_improvements = zip(*reversed(self._stack.items())) |
|
287 |
+ return list(points[:n]), list(loss_improvements[:n]) |
|
288 |
+ else: |
|
289 |
+ return [], [] |
|
279 | 290 |
|
280 | 291 |
def _choose_and_add_points(self, n): |
281 |
- points = [] |
|
282 |
- loss_improvements = [] |
|
283 | 292 |
n_left = n |
293 |
+ points, loss_improvements = self._split_stack(n_left) |
|
294 |
+ self.add_data(points, itertools.repeat(None)) |
|
295 |
+ n_left -= len(points) |
|
296 |
+ |
|
284 | 297 |
while n_left > 0: |
285 | 298 |
# The while loop is needed because `stack_till` could be larger |
286 | 299 |
# than the number of triangles between the points. Therefore |
287 | 300 |
# it could fill up till a length smaller than `stack_till`. |
288 |
- if not any(p in self._stack for p in self._bounds_points): |
|
289 |
- self._fill_stack(stack_till=max(n_left, 10)) |
|
290 |
- new_points, new_loss_improvements = self._split_stack(n_left) |
|
301 |
+ new_points, new_loss_improvements = self._fill_stack(stack_till=max(n_left, 10)) |
|
291 | 302 |
points += new_points |
292 | 303 |
loss_improvements += new_loss_improvements |
293 |
- self.add_data(new_points, itertools.repeat(None)) |
|
294 | 304 |
n_left -= len(new_points) |
305 |
+ self.add_data(new_points, itertools.repeat(None)) |
|
295 | 306 |
|
296 |
- return points, loss_improvements |
|
307 |
+ return points[:n], loss_improvements[:n] |
|
297 | 308 |
|
298 | 309 |
def choose_points(self, n, add_data=True): |
299 | 310 |
if not add_data: |