
Merge branch 'stable-0.7' which is tagged as v0.7.4

Bas Nijholt authored on 19/02/2019 14:27:52
Showing 3 changed files
@@ -129,7 +129,7 @@ def curvature_loss_function(area_factor=1, euclid_factor=0.02, horizontal_factor
     @uses_nth_neighbors(1)
     def curvature_loss(xs, ys):
         xs_middle = xs[1:3]
-        ys_middle = xs[1:3]
+        ys_middle = ys[1:3]
 
         triangle_loss_ = triangle_loss(xs, ys)
         default_loss_ = default_loss(xs_middle, ys_middle)
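A quick sketch, not part of the diff, of why this one-character typo mattered: before the fix the "y" values handed to default_loss were really a slice of the x values, so the curvature loss ignored the function values on the middle interval. The numbers below are made up for illustration.

xs = (0.0, 0.1, 0.2, 0.3)   # neighbor-extended x values, as provided by @uses_nth_neighbors(1)
ys = (1.0, 5.0, 5.2, 1.5)   # corresponding function values (hypothetical)
xs_middle = xs[1:3]         # (0.1, 0.2) -- the interval being scored
ys_middle = ys[1:3]         # (5.0, 5.2) after the fix; before it this was xs[1:3] == (0.1, 0.2)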
@@ -654,7 +654,8 @@ class Learner1D(BaseLearner):
         return self.data
 
     def _set_data(self, data):
-        self.tell_many(*zip(*data.items()))
+        if data:
+            self.tell_many(*zip(*data.items()))
 
 
 def loss_manager(x_scale):
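A minimal sketch, separate from the diff, of the failure mode the `if data:` guard avoids: unpacking an empty dict through zip() yields zero arguments, so tell_many() would be called with nothing at all.

data = {}
args = list(zip(*data.items()))
print(args)                    # [] -- so tell_many(*args) becomes tell_many(), which fails
data = {0.0: 1.0, 0.5: 2.0}
xs, ys = zip(*data.items())
print(xs, ys)                  # (0.0, 0.5) (1.0, 2.0) -- the intended two arguments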
@@ -670,7 +671,7 @@ def finite_loss(ival, loss, x_scale):
     sort intervals that have infinite loss."""
     # If the loss is infinite we return the
     # distance between the two points.
-    if math.isinf(loss):
+    if math.isinf(loss) or math.isnan(loss):
         loss = (ival[1] - ival[0]) / x_scale
         if len(ival) == 3:
             # Used when constructing quals. Last item is
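A small sketch, outside the diff, of why the extra isnan() check is needed: NaN is neither finite nor infinite, and it never compares as larger or smaller than anything, so a NaN loss slipped past the old check and broke loss-based interval ordering.

import math

nan = float("nan")
print(math.isinf(nan))         # False -- the old check let NaN through
print(math.isnan(nan))         # True  -- the new check catches it
print(nan > 1.0, nan < 1.0)    # False False -- NaN defeats loss comparisons and sorting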
@@ -757,7 +757,8 @@ class LearnerND(BaseLearner):
         return self.data
 
     def _set_data(self, data):
-        self.tell_many(*zip(*data.items()))
+        if data:
+            self.tell_many(*zip(*data.items()))
 
     def _get_iso(self, level=0.0, which='surface'):
         if which == 'surface':
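The LearnerND change mirrors the Learner1D one above. A usage sketch, assuming the adaptive package from this repository is importable: restoring state into a learner that has no points yet is now a no-op instead of an error.

import adaptive  # assumption: adaptive is installed

def f(x):
    return x**2

learner = adaptive.Learner1D(f, bounds=(-1, 1))
learner._set_data({})   # empty data: previously this reached tell_many() with no arguments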
@@ -363,3 +363,15 @@ def test_curvature_loss_vectors():
     learner = Learner1D(f, (-1, 1), loss_per_interval=loss)
     simple(learner, goal=lambda l: l.npoints > 100)
     assert learner.npoints > 100
+
+
+def test_NaN_loss():
+    # see https://github.com/python-adaptive/adaptive/issues/145
+    def f(x):
+        a = 0.01
+        if random.random() < 0.2:
+            return np.NaN
+        return x + a**2 / (a**2 + x**2)
+
+    learner = Learner1D(f, bounds=(-1, 1))
+    simple(learner, lambda l: l.npoints > 100)
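A side note on what the regression test exercises (sketch only, not part of the test file; the test itself presumably relies on the module's existing `random` and `np` imports): a NaN function value makes any distance-style loss NaN, which is exactly the case the isnan() branch above now handles.

import math

y_left, y_right = 1.0, float("nan")       # one sample came back as NaN
loss = math.hypot(0.1, y_right - y_left)  # an illustrative Euclidean-style interval loss
print(loss, math.isnan(loss))             # nan True -- caught by the isnan() check above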