@@ -435,6 +435,13 @@ class Learner2D(BaseLearner):
             triangle = ip.tri.points[ip.tri.vertices[jsimplex]]
             point_new = choose_point_in_triangle(triangle, max_badness=5)
             point_new = tuple(self._unscale(point_new))
+
+            # np.clip results in numerical precision problems
+            # https://gitlab.kwant-project.org/qt/adaptive/issues/132
+            clip = lambda x, l, u: max(l, min(u, x))
+            point_new = (clip(point_new[0], *self.bounds[0]),
+                         clip(point_new[1], *self.bounds[1]))
+
             loss_new = losses[jsimplex]
 
             points_new.append(point_new)
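As an aside on the hunk above (not part of the patch): the pure-Python helper replaces np.clip because of the numerical-precision problems tracked in issue 132. The sketch below only shows what the helper does; the bounds and point values are invented for illustration.

# Illustrative values only; the real bounds come from the learner.
bounds = ((-1.0, 1.0), (-1.0, 1.0))
point_new = (1.0000000000000002, 0.25)    # nudged just past the bound by round-off

clip = lambda x, l, u: max(l, min(u, x))  # same helper as in the patch

point_new = (clip(point_new[0], *bounds[0]),
             clip(point_new[1], *bounds[1]))
# -> (1.0, 0.25): each coordinate is clamped into its interval and stays a
# plain Python float, whereas np.clip would return NumPy scalars.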
@@ -9,6 +9,7 @@ import random
 import numpy as np
 from scipy import interpolate
 import scipy.spatial
+from sortedcontainers import SortedKeyList
 
 from adaptive.learner.base_learner import BaseLearner
 from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
@@ -91,7 +92,6 @@ def choose_point_in_simplex(simplex, transform=None):
     distance_matrix = scipy.spatial.distance.squareform(distances)
     i, j = np.unravel_index(np.argmax(distance_matrix),
                            distance_matrix.shape)
-
     point = (simplex[i, :] + simplex[j, :]) / 2
 
     if transform is not None:
@@ -100,6 +100,15 @@ def choose_point_in_simplex(simplex, transform=None):
     return point
 
 
+def _simplex_evaluation_priority(key):
+    # We round the loss to 8 digits so that losses that
+    # are equal up to numerical precision will be considered
+    # to be equal. This is needed because we want the learner
+    # to behave in a deterministic fashion.
+    loss, simplex, subsimplex = key
+    return -round(loss, ndigits=8), simplex, subsimplex or (0,)
+
+
 class LearnerND(BaseLearner):
     """Learns and predicts a function 'f: ℝ^N → ℝ^M'.
 
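A short aside on the new priority key (not part of the patch): queue entries are (loss, simplex, subsimplex) tuples, and the queue keeps them sorted by this key in ascending order, so negating the rounded loss puts the largest loss first. The "subsimplex or (0,)" fallback presumably exists because Python 3 cannot compare None with a tuple, which would break tie-breaking for entries that have no subsimplex. The values below are made up for illustration.

def _simplex_evaluation_priority(key):
    loss, simplex, subsimplex = key
    return -round(loss, ndigits=8), simplex, subsimplex or (0,)

a = _simplex_evaluation_priority((2.0, (1, 2, 3), None))
b = _simplex_evaluation_priority((0.5, (0, 1, 2), (0, 1)))
c = _simplex_evaluation_priority((0.5 + 1e-12, (0, 1, 2), None))

assert a < b          # the larger loss sorts first (its key is more negative)
assert b[0] == c[0]   # losses that agree to 8 decimal places share the same rounded key
assert c[2] == (0,)   # None is replaced by (0,) so it can be compared with tuples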
@@ -200,7 +209,7 @@ class LearnerND(BaseLearner):
         # so when popping an item, you should check that the simplex that has
         # been returned has not been deleted. This checking is done by
         # _pop_highest_existing_simplex
-        self._simplex_queue = []  # heap
+        self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
 
     @property
     def npoints(self):
@@ -344,9 +353,7 @@ class LearnerND(BaseLearner):
         subtriangulation = self._subtriangulations[simplex]
         for subsimplex in new_subsimplices:
             subloss = subtriangulation.volume(subsimplex) * loss_density
-            subloss = round(subloss, ndigits=8)
-            heapq.heappush(self._simplex_queue,
-                           (-subloss, simplex, subsimplex))
+            self._simplex_queue.add((subloss, simplex, subsimplex))
 
     def _ask_and_tell_pending(self, n=1):
         xs, losses = zip(*(self._ask() for _ in range(n)))
@@ -386,7 +393,7 @@ class LearnerND(BaseLearner):
         # find the simplex with the highest loss, we do need to check that the
         # simplex hasn't been deleted yet
         while len(self._simplex_queue):
-            loss, simplex, subsimplex = heapq.heappop(self._simplex_queue)
+            loss, simplex, subsimplex = self._simplex_queue.pop(0)
             if (subsimplex is None
                 and simplex in self.tri.simplices
                 and simplex not in self._subtriangulations):
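Putting the previous hunks together, here is a small self-contained sketch (not part of the patch) of the queue behaviour that replaces the heapq push/pop of negated, pre-rounded losses; it assumes the sortedcontainers package is installed and uses invented losses and simplices.

from sortedcontainers import SortedKeyList

def _simplex_evaluation_priority(key):
    loss, simplex, subsimplex = key
    return -round(loss, ndigits=8), simplex, subsimplex or (0,)

queue = SortedKeyList(key=_simplex_evaluation_priority)
queue.add((0.5, (3, 4, 5), None))
queue.add((2.0, (1, 2, 3), None))
queue.add((0.5 + 1e-12, (0, 1, 2), (0, 1)))  # ties with 0.5 after rounding

loss, simplex, subsimplex = queue.pop(0)
# -> (2.0, (1, 2, 3), None): pop(0) always returns the entry with the highest
# loss; entries whose losses agree to 8 decimal places fall back to comparing
# simplex and subsimplex, which keeps the pop order deterministic.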
@@ -462,8 +469,7 @@ class LearnerND(BaseLearner):
                 self._try_adding_pending_point_to_simplex(p, simplex)
 
             if simplex not in self._subtriangulations:
-                loss = round(loss, ndigits=8)
-                heapq.heappush(self._simplex_queue, (-loss, simplex, None))
+                self._simplex_queue.add((loss, simplex, None))
                 continue
 
             self._update_subsimplex_losses(
@@ -488,7 +494,7 @@ class LearnerND(BaseLearner):
             return
 
         # reset the _simplex_queue
-        self._simplex_queue = []
+        self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
 
         # recompute all losses
         for simplex in self.tri.simplices:
@@ -497,8 +503,7 @@ class LearnerND(BaseLearner):
 
             # now distribute it around the the children if they are present
             if simplex not in self._subtriangulations:
-                loss = round(loss, ndigits=8)
-                heapq.heappush(self._simplex_queue, (-loss, simplex, None))
+                self._simplex_queue.add((loss, simplex, None))
                 continue
 
             self._update_subsimplex_losses(
@@ -362,9 +362,7 @@ def test_expected_loss_improvement_is_less_than_total_loss(learner_type, f, learner_kwargs):
 
 # XXX: This *should* pass (https://gitlab.kwant-project.org/qt/adaptive/issues/84)
 # but we xfail it now, as Learner2D will be deprecated anyway
-# The LearnerND fails sometimes, see
-# https://gitlab.kwant-project.org/qt/adaptive/merge_requests/128#note_21807
-@run_with(Learner1D, xfail(Learner2D), xfail(LearnerND))
+@run_with(Learner1D, xfail(Learner2D), LearnerND)
 def test_learner_performance_is_invariant_under_scaling(learner_type, f, learner_kwargs):
     """Learners behave identically under transformations that leave
     the loss invariant.
@@ -392,6 +390,10 @@ def test_learner_performance_is_invariant_under_scaling(learner_type, f, learner_kwargs):
 
     npoints = random.randrange(300, 500)
 
+    if learner_type is LearnerND:
+        # Because the LearnerND is slow
+        npoints //= 10
+
     for n in range(npoints):
         cxs, _ = control.ask(1)
         xs, _ = learner.ask(1)