Browse code

fix the bug where intervals could be smaller than dx_eps

Bas Nijholt authored on 11/07/2018 02:01:28
Showing 1 changed files
... ...
@@ -278,39 +278,36 @@ class Learner1D(BaseLearner):
278 278
             loss_improvements = [np.inf] * n
279 279
             points = np.linspace(*self.bounds, n + 1)[1:].tolist()
280 280
         else:
281
-            def xs(x, n):
281
+            def xs(x_left, x_right, n):
282 282
                 if n == 1:
283
+                    # This is just an optimization
283 284
                     return []
284 285
                 else:
285
-                    step = (x[1] - x[0]) / n
286
-                    return [x[0] + step * i for i in range(1, n)]
286
+                    step = (x_right - x_left) / n
287
+                    return [x_left + step * i for i in range(1, n)]
287 288
 
288 289
             # Calculate how many points belong to each interval.
289 290
             x_scale = self._scale[0]
290 291
 
291 292
             quals = []
292 293
             for ((x_left, x_right), loss) in self.losses_combined.items():
293
-                dx = x_right - x_left
294
-                if abs(dx) < self._dx_eps:
295
-                    # The interval is too small and should not be subdivided
296
-                    quality = 0
297
-                elif not math.isinf(loss):
298
-                    quality = -loss
299
-                else:
300
-                    quality = -dx / x_scale
294
+                quality = -loss if not math.isinf(loss) else -(x_right - x_left) / x_scale
301 295
                 quals.append((quality, (x_left, x_right), 1))
302 296
 
303 297
             heapq.heapify(quals)
304 298
 
305 299
             for point_number in range(n):
306 300
                 quality, x, n = quals[0]
301
+                if abs(x[1] - x[0]) / (n + 1) <= self._dx_eps:
302
+                    # The interval is too small and should not be subdivided
303
+                    quality = np.inf
307 304
                 heapq.heapreplace(quals, (quality * n / (n + 1), x, n + 1))
308 305
 
309
-            points = list(itertools.chain.from_iterable(xs(x, n)
310
-                          for quality, x, n in quals))
306
+            points = list(itertools.chain.from_iterable(
307
+                xs(*x, n) for quality, x, n in quals))
311 308
 
312 309
             loss_improvements = list(itertools.chain.from_iterable(
313
-                                     itertools.repeat(-quality, n-1)
310
+                                     itertools.repeat(-quality, n - 1)
314 311
                                      for quality, x, n in quals))
315 312
 
316 313
         if add_data: