Browse code

Merge branch 'stable-0.7'

Joseph Weston authored on 17/12/2018 15:45:33
Showing 4 changed files
... ...
@@ -435,6 +435,13 @@ class Learner2D(BaseLearner):
435 435
             triangle = ip.tri.points[ip.tri.vertices[jsimplex]]
436 436
             point_new = choose_point_in_triangle(triangle, max_badness=5)
437 437
             point_new = tuple(self._unscale(point_new))
438
+
439
+            # np.clip results in numerical precision problems
440
+            # https://gitlab.kwant-project.org/qt/adaptive/issues/132
441
+            clip = lambda x, l, u: max(l, min(u, x))
442
+            point_new = (clip(point_new[0], *self.bounds[0]),
443
+                         clip(point_new[1], *self.bounds[1]))
444
+
438 445
             loss_new = losses[jsimplex]
439 446
 
440 447
             points_new.append(point_new)
... ...
@@ -9,6 +9,7 @@ import random
9 9
 import numpy as np
10 10
 from scipy import interpolate
11 11
 import scipy.spatial
12
+from sortedcontainers import SortedKeyList
12 13
 
13 14
 from adaptive.learner.base_learner import BaseLearner
14 15
 from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
... ...
@@ -91,7 +92,6 @@ def choose_point_in_simplex(simplex, transform=None):
91 92
         distance_matrix = scipy.spatial.distance.squareform(distances)
92 93
         i, j = np.unravel_index(np.argmax(distance_matrix),
93 94
                                 distance_matrix.shape)
94
-
95 95
         point = (simplex[i, :] + simplex[j, :]) / 2
96 96
 
97 97
     if transform is not None:
... ...
@@ -100,6 +100,15 @@ def choose_point_in_simplex(simplex, transform=None):
100 100
     return point
101 101
 
102 102
 
103
+def _simplex_evaluation_priority(key):
104
+    # We round the loss to 8 digits such that losses
105
+    # are equal up to numerical precision will be considered
106
+    # to be equal. This is needed because we want the learner
107
+    # to behave in a deterministic fashion.
108
+    loss, simplex, subsimplex = key
109
+    return -round(loss, ndigits=8), simplex, subsimplex or (0,)
110
+
111
+
103 112
 class LearnerND(BaseLearner):
104 113
     """Learns and predicts a function 'f: ℝ^N → ℝ^M'.
105 114
 
... ...
@@ -200,7 +209,7 @@ class LearnerND(BaseLearner):
200 209
         # so when popping an item, you should check that the simplex that has
201 210
         # been returned has not been deleted. This checking is done by
202 211
         # _pop_highest_existing_simplex
203
-        self._simplex_queue = []  # heap
212
+        self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
204 213
 
205 214
     @property
206 215
     def npoints(self):
... ...
@@ -344,9 +353,7 @@ class LearnerND(BaseLearner):
344 353
         subtriangulation = self._subtriangulations[simplex]
345 354
         for subsimplex in new_subsimplices:
346 355
             subloss = subtriangulation.volume(subsimplex) * loss_density
347
-            subloss = round(subloss, ndigits=8)
348
-            heapq.heappush(self._simplex_queue,
349
-                           (-subloss, simplex, subsimplex))
356
+            self._simplex_queue.add((subloss, simplex, subsimplex))
350 357
 
351 358
     def _ask_and_tell_pending(self, n=1):
352 359
         xs, losses = zip(*(self._ask() for _ in range(n)))
... ...
@@ -386,7 +393,7 @@ class LearnerND(BaseLearner):
386 393
         # find the simplex with the highest loss, we do need to check that the
387 394
         # simplex hasn't been deleted yet
388 395
         while len(self._simplex_queue):
389
-            loss, simplex, subsimplex = heapq.heappop(self._simplex_queue)
396
+            loss, simplex, subsimplex = self._simplex_queue.pop(0)
390 397
             if (subsimplex is None
391 398
                 and simplex in self.tri.simplices
392 399
                 and simplex not in self._subtriangulations):
... ...
@@ -462,8 +469,7 @@ class LearnerND(BaseLearner):
462 469
                 self._try_adding_pending_point_to_simplex(p, simplex)
463 470
 
464 471
             if simplex not in self._subtriangulations:
465
-                loss = round(loss, ndigits=8)
466
-                heapq.heappush(self._simplex_queue, (-loss, simplex, None))
472
+                self._simplex_queue.add((loss, simplex, None))
467 473
                 continue
468 474
 
469 475
             self._update_subsimplex_losses(
... ...
@@ -488,7 +494,7 @@ class LearnerND(BaseLearner):
488 494
             return
489 495
 
490 496
         # reset the _simplex_queue
491
-        self._simplex_queue = []
497
+        self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
492 498
 
493 499
         # recompute all losses
494 500
         for simplex in self.tri.simplices:
... ...
@@ -497,8 +503,7 @@ class LearnerND(BaseLearner):
497 503
 
498 504
             # now distribute it around the the children if they are present
499 505
             if simplex not in self._subtriangulations:
500
-                loss = round(loss, ndigits=8)
501
-                heapq.heappush(self._simplex_queue, (-loss, simplex, None))
506
+                self._simplex_queue.add((loss, simplex, None))
502 507
                 continue
503 508
 
504 509
             self._update_subsimplex_losses(
... ...
@@ -362,9 +362,7 @@ def test_expected_loss_improvement_is_less_than_total_loss(learner_type, f, lear
362 362
 
363 363
 # XXX: This *should* pass (https://gitlab.kwant-project.org/qt/adaptive/issues/84)
364 364
 #      but we xfail it now, as Learner2D will be deprecated anyway
365
-# The LearnerND fails sometimes, see
366
-# https://gitlab.kwant-project.org/qt/adaptive/merge_requests/128#note_21807
367
-@run_with(Learner1D, xfail(Learner2D), xfail(LearnerND))
365
+@run_with(Learner1D, xfail(Learner2D), LearnerND)
368 366
 def test_learner_performance_is_invariant_under_scaling(learner_type, f, learner_kwargs):
369 367
     """Learners behave identically under transformations that leave
370 368
        the loss invariant.
... ...
@@ -392,6 +390,10 @@ def test_learner_performance_is_invariant_under_scaling(learner_type, f, learner
392 390
 
393 391
     npoints = random.randrange(300, 500)
394 392
 
393
+    if learner_type is LearnerND:
394
+        # Because the LearnerND is slow
395
+        npoints //= 10
396
+
395 397
     for n in range(npoints):
396 398
         cxs, _ = control.ask(1)
397 399
         xs, _ = learner.ask(1)
... ...
@@ -27,7 +27,7 @@ version, cmdclass = get_version_and_cmdclass('adaptive')
27 27
 install_requires = [
28 28
     'scipy',
29 29
     'sortedcollections',
30
-    'sortedcontainers',
30
+    'sortedcontainers >= 2.0',
31 31
 ]
32 32
 
33 33
 extras_require = {