Browse code

1D: also check for abs(dx) < self._dx_eps in ask

Bas Nijholt authored on 08/06/2018 23:05:56
Showing 1 changed files
... ...
@@ -288,7 +288,17 @@ class Learner1D(BaseLearner):
288 288
 
289 289
             # Calculate how many points belong to each interval.
290 290
             x_scale = self._scale[0]
291
-            quals = [(-loss if not math.isinf(loss) else (x0 - x1) / x_scale, (x0, x1), 1)
291
+
292
+            def _calc_quality(x0, x1, loss):
293
+                dx = x0 - x1
294
+                if abs(dx) < self._dx_eps:
295
+                    return 0
296
+                elif not math.isinf(loss):
297
+                    return -loss
298
+                else:
299
+                    return dx / x_scale
300
+
301
+            quals = [(_calc_quality(x0, x1, loss), (x0, x1), 1)
292 302
                      for ((x0, x1), loss) in self.losses_combined.items()]
293 303
 
294 304
             heapq.heapify(quals)