Browse code

1D: simplify logic in bounds chosing

Bas Nijholt authored on 10/09/2018 12:58:12
Showing 1 changed files
... ...
@@ -264,19 +264,14 @@ class Learner1D(BaseLearner):
264 264
         missing_bounds = [b for b in self.bounds if b not in self.data
265 265
                           and b not in self.pending_points]
266 266
 
267
-        if len(missing_bounds) == 2:
268
-            # First time
267
+        if missing_bounds:
269 268
             loss_improvements = [np.inf] * n
270
-            points = np.linspace(*self.bounds, n).tolist()
271
-        elif len(missing_bounds) == 1:
272
-            loss_improvements = [np.inf] * n
273
-            points = np.linspace(*self.bounds, n + 1).tolist()
274
-            if missing_bounds[0] == self.bounds[1]:
275
-                # Second time, if we previously returned just self.bounds[0]
276
-                points = points[1:]
277
-            else:
278
-                # Rare case in which self.bounds[1] is present before self.bounds[1]
279
-                points = points[:1]
269
+            points = np.linspace(*self.bounds, n + 2 - len(missing_bounds)).tolist()
270
+            if len(missing_bounds) == 1:
271
+                # If we previously returned just self.bounds[0] we exclude that point.
272
+                # In the rare case in which self.bounds[1] is present before self.bounds[1]
273
+                # we exclude that point.
274
+                points = points[1:] if missing_bounds[0] == self.bounds[1] else points[:-1]
280 275
         else:
281 276
             def xs(x_left, x_right, n):
282 277
                 if n == 1:
... ...
@@ -311,7 +306,6 @@ class Learner1D(BaseLearner):
311 306
 
312 307
         return points, loss_improvements
313 308
 
314
-
315 309
     def plot(self):
316 310
         hv = ensure_holoviews()
317 311
         if not self.data: