Browse code

2D: code cleanup

Bas Nijholt authored on 16/11/2017 19:42:55
Showing 1 changed files
... ...
@@ -108,10 +108,8 @@ class Learner2D(BaseLearner):
108 108
 
109 109
     def __init__(self, function, bounds, loss_per_triangle=None):
110 110
         self.ndim = len(bounds)
111
-        self.loss_per_triangle = loss_per_triangle or _default_loss_per_triangle
112 111
         self._vdim = None
113
-        if self.ndim != 2:
114
-            raise ValueError("Only 2-D sampling supported.")
112
+        self.loss_per_triangle = loss_per_triangle or _default_loss_per_triangle
115 113
         self.bounds = tuple((float(a), float(b)) for a, b in bounds)
116 114
         self._points = np.zeros([100, self.ndim])
117 115
         self._values = np.zeros([100, self.vdim], dtype=float)
... ...
@@ -122,9 +120,11 @@ class Learner2D(BaseLearner):
122 120
         xy_scale = np.ptp(self.bounds, axis=1)
123 121
 
124 122
         def scale(points):
123
+            points = np.asarray(points)
125 124
             return (points - xy_mean) / xy_scale
126 125
 
127 126
         def unscale(points):
127
+            points = np.asarray(points)
128 128
             return points * xy_scale + xy_mean
129 129
 
130 130
         self.scale = scale
... ...
@@ -166,6 +166,11 @@ class Learner2D(BaseLearner):
166 166
     def n_real(self):
167 167
         return self.n - len(self._interp)
168 168
 
169
+    @property
170
+    def bounds_are_done(self):
171
+        return not any((p in self._interp or p in self._stack)
172
+                       for p in self._bounds_points)
173
+
169 174
     def ip(self):
170 175
         points = self.scale(self.points)
171 176
         return interpolate.LinearNDInterpolator(points, self.values)
... ...
@@ -177,16 +182,10 @@ class Learner2D(BaseLearner):
177 182
         # Interpolate the unfinished points
178 183
         if self._interp:
179 184
             n_interp = list(self._interp.values())
180
-            bounds_are_done = not any(p in self._interp
181
-                                      for p in self._bounds_points)
182
-            if bounds_are_done:
185
+            if self.bounds_are_done:
183 186
                 values[n_interp] = self.ip()(points[n_interp])
184 187
             else:
185
-                # It is important not to return exact zeros because
186
-                # otherwise the algo will try to add the same point
187
-                # to the stack each time.
188
-                values[n_interp] = np.random.rand(
189
-                    len(n_interp), self.vdim) * 1e-15
188
+                values[n_interp] = np.zeros((len(n_interp), self.vdim))
190 189
 
191 190
         return interpolate.LinearNDInterpolator(points, values)
192 191
 
... ...
@@ -216,25 +215,20 @@ class Learner2D(BaseLearner):
216 215
 
217 216
         self._points[n] = point
218 217
 
219
-        try:
220
-            self._values[n] = value
221
-        except ValueError:
218
+        if self._vdim is None and hasattr(value, '__len__'):
222 219
             self._vdim = len(value)
223 220
             self._values = np.resize(self._values, (nmax, self.vdim))
224
-            self._values[n] = value
225 221
 
226
-        self._stack.pop(point, None)
222
+        self._values[n] = value
227 223
 
228
-    def _fill_stack(self, stack_till=None):
229
-        if stack_till is None:
230
-            stack_till = 1
224
+        self._stack.pop(point, None)
231 225
 
226
+    def _fill_stack(self, stack_till=1):
232 227
         if self.values_combined.shape[0] < self.ndim + 1:
233 228
             raise ValueError("too few points...")
234 229
 
235 230
         # Interpolate
236 231
         ip = self.ip_combined()
237
-        tri = ip.tri
238 232
 
239 233
         losses = self.loss_per_triangle(ip)
240 234
 
... ...
@@ -249,53 +243,42 @@ class Learner2D(BaseLearner):
249 243
             return False
250 244
 
251 245
         for j, _ in enumerate(losses):
252
-            # Estimate point of maximum curvature inside the simplex
253 246
             jsimplex = np.argmax(losses)
254
-            p = tri.points[tri.vertices[jsimplex]]
255
-            point_new = self.unscale(p.mean(axis=-2))
256
-
257
-            # XXX: not sure whether this is necessary it was there
258
-            # originally.
247
+            point_new = ip.tri.points[ip.tri.vertices[jsimplex]]
248
+            point_new = self.unscale(point_new.mean(axis=-2))
259 249
             point_new = np.clip(point_new, *zip(*self.bounds))
260 250
 
261 251
             # Check if it is really new
262 252
             if point_exists(point_new):
263
-                losses[jsimplex] = 0
253
+                losses[jsimplex] = -np.inf
264 254
                 continue
265 255
 
266
-            # Add to stack
267 256
             self._stack[tuple(point_new)] = losses[jsimplex]
268 257
 
269 258
             if len(self._stack) >= stack_till:
270 259
                 break
271 260
             else:
272
-                losses[jsimplex] = 0
261
+                losses[jsimplex] = -np.inf
273 262
 
274 263
     def _split_stack(self, n=None):
275 264
         points, loss_improvements = zip(*reversed(self._stack.items()))
276 265
         return points[:n], loss_improvements[:n]
277 266
 
278 267
     def _choose_and_add_points(self, n):
279
-        if n <= len(self._stack):
280
-            points, loss_improvements = self._split_stack(n)
281
-            self.add_data(points, itertools.repeat(None))
282
-        else:
283
-            points = []
284
-            loss_improvements = []
285
-            n_left = n
286
-            while n_left > 0:
287
-                # The while loop is needed because `stack_till` could be larger
288
-                # than the number of triangles between the points. Therefore
289
-                # it could fill up till a length smaller than `stack_till`.
290
-                no_bounds_in_stack = not any(p in self._stack
291
-                                             for p in self._bounds_points)
292
-                if no_bounds_in_stack:
293
-                    self._fill_stack(stack_till=n_left)
294
-                new_points, new_loss_improvements = self._split_stack(n_left)
295
-                points += new_points
296
-                loss_improvements += new_loss_improvements
297
-                self.add_data(new_points, itertools.repeat(None))
298
-                n_left -= len(new_points)
268
+        points = []
269
+        loss_improvements = []
270
+        n_left = n
271
+        while n_left > 0:
272
+            # The while loop is needed because `stack_till` could be larger
273
+            # than the number of triangles between the points. Therefore
274
+            # it could fill up till a length smaller than `stack_till`.
275
+            if not any(p in self._stack for p in self._bounds_points):
276
+                self._fill_stack(stack_till=n_left)
277
+            new_points, new_loss_improvements = self._split_stack(n_left)
278
+            points += new_points
279
+            loss_improvements += new_loss_improvements
280
+            self.add_data(new_points, itertools.repeat(None))
281
+            n_left -= len(new_points)
299 282
 
300 283
         return points, loss_improvements
301 284
 
... ...
@@ -307,10 +290,7 @@ class Learner2D(BaseLearner):
307 290
             return self._choose_and_add_points(n)
308 291
 
309 292
     def loss(self, real=True):
310
-        n = self.n_real if real else self.n
311
-        bounds_are_not_done = any(p in self._interp
312
-                                  for p in self._bounds_points)
313
-        if n <= 4 or bounds_are_not_done:
293
+        if not self.bounds_are_done:
314 294
             return np.inf
315 295
         ip = self.ip() if real else self.ip_combined()
316 296
         losses = self.loss_per_triangle(ip)