Browse code

2D: implement choose_points with add_data=False

Bas Nijholt authored on 24/11/2017 14:49:51
Showing 1 changed files
... ...
@@ -1,6 +1,6 @@
1 1
 # -*- coding: utf-8 -*-
2 2
 import collections
3
-from copy import deepcopy
3
+from copy import copy, deepcopy
4 4
 import itertools
5 5
 import math
6 6
 
... ...
@@ -217,9 +217,6 @@ class Learner2D(BaseLearner):
217 217
 
218 218
     def ip_combined(self):
219 219
         if self._ip_combined is None:
220
-            points = self.scale(self.points_combined)
221
-            values = self.values_combined
222
-
223 220
             # Interpolate the unfinished points
224 221
             if self._interp:
225 222
                 points_interp = list(self._interp)
... ...
@@ -231,9 +228,10 @@ class Learner2D(BaseLearner):
231 228
                 for point, value in zip(points_interp, values_interp):
232 229
                     self.data_combined[point] = value
233 230
 
234
-            points_combined = self.scale(self.points_combined)
235
-            self._ip_combined = interpolate.LinearNDInterpolator(points_combined,
236
-                                                                 self.values_combined)
231
+            points = self.scale(self.points_combined)
232
+            values = self.values_combined
233
+            self._ip_combined = interpolate.LinearNDInterpolator(points,
234
+                                                                 values)
237 235
         return self._ip_combined
238 236
 
239 237
     def add_point(self, point, value):
... ...
@@ -288,8 +286,9 @@ class Learner2D(BaseLearner):
288 286
         else:
289 287
             return [], []
290 288
 
291
-    def _choose_and_add_points(self, n):
289
+    def choose_points(self, n, add_data=True):
292 290
         n_left = n
291
+        old_stack = copy(self._stack)
293 292
         points, loss_improvements = self._split_stack(n_left)
294 293
         self.add_data(points, itertools.repeat(None))
295 294
         n_left -= len(points)
... ...
@@ -299,19 +298,22 @@ class Learner2D(BaseLearner):
299 298
             # than the number of triangles between the points. Therefore
300 299
             # it could fill up till a length smaller than `stack_till`.
301 300
             new_points, new_loss_improvements = self._fill_stack(stack_till=max(n_left, 10))
302
-            points += new_points
303
-            loss_improvements += new_loss_improvements
304 301
             n_left -= len(new_points)
305 302
             self.add_data(new_points, itertools.repeat(None))
306 303
 
307
-        return points[:n], loss_improvements[:n]
304
+            points += new_points
305
+            loss_improvements += new_loss_improvements
308 306
 
309
-    def choose_points(self, n, add_data=True):
310
-        if not add_data:
311
-            with restore(self):
312
-                return self._choose_and_add_points(n)
307
+        if add_data:
308
+            for point, loss_improvement in zip(points[n:], loss_improvements[n:]):
309
+                self._stack[point] = loss_improvement
313 310
         else:
314
-            return self._choose_and_add_points(n)
311
+            self._stack = old_stack
312
+            for point in points:
313
+                self.data_combined.pop(point)
314
+                self._interp.remove(point)
315
+
316
+        return points[:n], loss_improvements[:n]
315 317
 
316 318
     def loss(self, real=True):
317 319
         if not self.bounds_are_done: