...
|
...
|
@@ -8,6 +8,7 @@ from collections import Iterable
|
8
|
8
|
|
9
|
9
|
import numpy as np
|
10
|
10
|
import sortedcontainers
|
|
11
|
+import sortedcollections
|
11
|
12
|
|
12
|
13
|
from adaptive.learner.base_learner import BaseLearner
|
13
|
14
|
from adaptive.learner.learnerND import volume
|
...
|
...
|
@@ -225,9 +226,6 @@ class Learner1D(BaseLearner):
|
225
|
226
|
|
226
|
227
|
self.loss_per_interval = loss_per_interval or default_loss
|
227
|
228
|
|
228
|
|
- # A dict storing the loss function for each interval x_n.
|
229
|
|
- self.losses = {}
|
230
|
|
- self.losses_combined = {}
|
231
|
229
|
|
232
|
230
|
# When the scale changes by a factor 2, the losses are
|
233
|
231
|
# recomputed. This is tunable such that we can test
|
...
|
...
|
@@ -249,6 +247,10 @@ class Learner1D(BaseLearner):
|
249
|
247
|
self._scale = [bounds[1] - bounds[0], 0]
|
250
|
248
|
self._oldscale = deepcopy(self._scale)
|
251
|
249
|
|
|
250
|
+ # A LossManager storing the loss function for each interval x_n.
|
|
251
|
+ self.losses = loss_manager(self._scale[0])
|
|
252
|
+ self.losses_combined = loss_manager(self._scale[0])
|
|
253
|
+
|
252
|
254
|
# The precision in 'x' below which we set losses to 0.
|
253
|
255
|
self._dx_eps = 2 * max(np.abs(bounds)) * np.finfo(float).eps
|
254
|
256
|
|
...
|
...
|
@@ -284,7 +286,10 @@ class Learner1D(BaseLearner):
|
284
|
286
|
@cache_latest
|
285
|
287
|
def loss(self, real=True):
|
286
|
288
|
losses = self.losses if real else self.losses_combined
|
287
|
|
- return max(losses.values()) if len(losses) > 0 else float('inf')
|
|
289
|
+ if not losses:
|
|
290
|
+ return np.inf
|
|
291
|
+ max_interval, max_loss = losses.peekitem(0)
|
|
292
|
+ return max_loss
|
288
|
293
|
|
289
|
294
|
def _scale_x(self, x):
|
290
|
295
|
if x is None:
|
...
|
...
|
@@ -454,8 +459,7 @@ class Learner1D(BaseLearner):
|
454
|
459
|
|
455
|
460
|
# If the scale has increased enough, recompute all losses.
|
456
|
461
|
if self._scale[1] > self._recompute_losses_factor * self._oldscale[1]:
|
457
|
|
-
|
458
|
|
- for interval in self.losses:
|
|
462
|
+ for interval in reversed(self.losses):
|
459
|
463
|
self._update_interpolated_loss_in_interval(*interval)
|
460
|
464
|
|
461
|
465
|
self._oldscale = deepcopy(self._scale)
|
...
|
...
|
@@ -504,18 +508,18 @@ class Learner1D(BaseLearner):
|
504
|
508
|
for neighbors in (self.neighbors, self.neighbors_combined)]
|
505
|
509
|
|
506
|
510
|
# The the losses for the "real" intervals.
|
507
|
|
- self.losses = {}
|
|
511
|
+ self.losses = loss_manager(self._scale[0])
|
508
|
512
|
for ival in intervals:
|
509
|
513
|
self.losses[ival] = self._get_loss_in_interval(*ival)
|
510
|
514
|
|
511
|
515
|
# List with "real" intervals that have interpolated intervals inside
|
512
|
516
|
to_interpolate = []
|
513
|
517
|
|
514
|
|
- self.losses_combined = {}
|
|
518
|
+ self.losses_combined = loss_manager(self._scale[0])
|
515
|
519
|
for ival in intervals_combined:
|
516
|
520
|
# If this interval exists in 'losses' then copy it otherwise
|
517
|
521
|
# calculate it.
|
518
|
|
- if ival in self.losses:
|
|
522
|
+ if ival in reversed(self.losses):
|
519
|
523
|
self.losses_combined[ival] = self.losses[ival]
|
520
|
524
|
else:
|
521
|
525
|
# Set all losses to inf now, later they might be udpdated if the
|
...
|
...
|
@@ -530,7 +534,7 @@ class Learner1D(BaseLearner):
|
530
|
534
|
to_interpolate.append((x_left, x_right))
|
531
|
535
|
|
532
|
536
|
for ival in to_interpolate:
|
533
|
|
- if ival in self.losses:
|
|
537
|
+ if ival in reversed(self.losses):
|
534
|
538
|
# If this interval does not exist it should already
|
535
|
539
|
# have an inf loss.
|
536
|
540
|
self._update_interpolated_loss_in_interval(*ival)
|
...
|
...
|
@@ -566,64 +570,57 @@ class Learner1D(BaseLearner):
|
566
|
570
|
if len(missing_bounds) >= n:
|
567
|
571
|
return missing_bounds[:n], [np.inf] * n
|
568
|
572
|
|
569
|
|
- def finite_loss(loss, xs):
|
570
|
|
- # If the loss is infinite we return the
|
571
|
|
- # distance between the two points.
|
572
|
|
- if math.isinf(loss):
|
573
|
|
- loss = (xs[1] - xs[0]) / self._scale[0]
|
574
|
|
-
|
575
|
|
- # We round the loss to 12 digits such that losses
|
576
|
|
- # are equal up to numerical precision will be considered
|
577
|
|
- # equal.
|
578
|
|
- return round(loss, ndigits=12)
|
579
|
|
-
|
580
|
|
- quals = [(-finite_loss(loss, x), x, 1)
|
581
|
|
- for x, loss in self.losses_combined.items()]
|
582
|
|
-
|
583
|
573
|
# Add bound intervals to quals if bounds were missing.
|
584
|
574
|
if len(self.data) + len(self.pending_points) == 0:
|
585
|
575
|
# We don't have any points, so return a linspace with 'n' points.
|
586
|
576
|
return np.linspace(*self.bounds, n).tolist(), [np.inf] * n
|
587
|
|
- elif len(missing_bounds) > 0:
|
|
577
|
+
|
|
578
|
+ quals = loss_manager(self._scale[0])
|
|
579
|
+ if len(missing_bounds) > 0:
|
588
|
580
|
# There is at least one point in between the bounds.
|
589
|
581
|
all_points = list(self.data.keys()) + list(self.pending_points)
|
590
|
582
|
intervals = [(self.bounds[0], min(all_points)),
|
591
|
583
|
(max(all_points), self.bounds[1])]
|
592
|
584
|
for interval, bound in zip(intervals, self.bounds):
|
593
|
585
|
if bound in missing_bounds:
|
594
|
|
- qual = (-finite_loss(np.inf, interval), interval, 1)
|
595
|
|
- quals.append(qual)
|
596
|
|
-
|
597
|
|
- # Calculate how many points belong to each interval.
|
598
|
|
- points, loss_improvements = self._subdivide_quals(
|
599
|
|
- quals, n - len(missing_bounds))
|
600
|
|
-
|
601
|
|
- points = missing_bounds + points
|
602
|
|
- loss_improvements = [np.inf] * len(missing_bounds) + loss_improvements
|
|
586
|
+ quals[(*interval, 1)] = np.inf
|
603
|
587
|
|
604
|
|
- return points, loss_improvements
|
|
588
|
+ points_to_go = n - len(missing_bounds)
|
605
|
589
|
|
606
|
|
- def _subdivide_quals(self, quals, n):
|
607
|
590
|
# Calculate how many points belong to each interval.
|
608
|
|
- heapq.heapify(quals)
|
609
|
|
-
|
610
|
|
- for _ in range(n):
|
611
|
|
- quality, x, n = quals[0]
|
612
|
|
- if abs(x[1] - x[0]) / (n + 1) <= self._dx_eps:
|
613
|
|
- # The interval is too small and should not be subdivided.
|
614
|
|
- quality = np.inf
|
615
|
|
- # XXX: see https://gitlab.kwant-project.org/qt/adaptive/issues/104
|
616
|
|
- heapq.heapreplace(quals, (quality * n / (n + 1), x, n + 1))
|
|
591
|
+ i, i_max = 0, len(self.losses_combined)
|
|
592
|
+ for _ in range(points_to_go):
|
|
593
|
+ qual, loss_qual = quals.peekitem(0) if quals else (None, 0)
|
|
594
|
+ ival, loss_ival = self.losses_combined.peekitem(i) if i < i_max else (None, 0)
|
|
595
|
+
|
|
596
|
+ if (qual is None
|
|
597
|
+ or (ival is not None
|
|
598
|
+ and self._loss(self.losses_combined, ival)
|
|
599
|
+ >= self._loss(quals, qual))):
|
|
600
|
+ i += 1
|
|
601
|
+ quals[(*ival, 2)] = loss_ival / 2
|
|
602
|
+ else:
|
|
603
|
+ quals.pop(qual, None)
|
|
604
|
+ *xs, n = qual
|
|
605
|
+ quals[(*xs, n+1)] = loss_qual * n / (n+1)
|
617
|
606
|
|
618
|
607
|
points = list(itertools.chain.from_iterable(
|
619
|
|
- linspace(*interval, n) for quality, interval, n in quals))
|
|
608
|
+ linspace(*ival, n) for (*ival, n) in quals))
|
620
|
609
|
|
621
|
610
|
loss_improvements = list(itertools.chain.from_iterable(
|
622
|
|
- itertools.repeat(-quality, n - 1)
|
623
|
|
- for quality, interval, n in quals))
|
|
611
|
+ itertools.repeat(quals[x0, x1, n], n - 1)
|
|
612
|
+ for (x0, x1, n) in quals))
|
|
613
|
+
|
|
614
|
+ # add the missing bounds
|
|
615
|
+ points = missing_bounds + points
|
|
616
|
+ loss_improvements = [np.inf] * len(missing_bounds) + loss_improvements
|
624
|
617
|
|
625
|
618
|
return points, loss_improvements
|
626
|
619
|
|
|
620
|
+ def _loss(self, mapping, ival):
|
|
621
|
+ loss = mapping[ival]
|
|
622
|
+ return finite_loss(ival, loss, self._scale[0])
|
|
623
|
+
|
627
|
624
|
def plot(self):
|
628
|
625
|
"""Returns a plot of the evaluated data.
|
629
|
626
|
|
...
|
...
|
@@ -658,3 +655,42 @@ class Learner1D(BaseLearner):
|
658
|
655
|
|
659
|
656
|
def _set_data(self, data):
|
660
|
657
|
self.tell_many(*zip(*data.items()))
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+def _fix_deepcopy(sorted_dict, x_scale):
|
|
661
|
+ # XXX: until https://github.com/grantjenks/sortedcollections/issues/5 is fixed
|
|
662
|
+ import types
|
|
663
|
+ def __deepcopy__(self, memo):
|
|
664
|
+ items = deepcopy(list(self.items()))
|
|
665
|
+ lm = loss_manager(self.x_scale)
|
|
666
|
+ lm.update(items)
|
|
667
|
+ return lm
|
|
668
|
+ sorted_dict.x_scale = x_scale
|
|
669
|
+ sorted_dict.__deepcopy__ = types.MethodType(__deepcopy__, sorted_dict)
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+def loss_manager(x_scale):
|
|
673
|
+ def sort_key(ival, loss):
|
|
674
|
+ loss, ival = finite_loss(ival, loss, x_scale)
|
|
675
|
+ return -loss, ival
|
|
676
|
+ sorted_dict = sortedcollections.ItemSortedDict(sort_key)
|
|
677
|
+ _fix_deepcopy(sorted_dict, x_scale)
|
|
678
|
+ return sorted_dict
|
|
679
|
+
|
|
680
|
+
|
|
681
|
+def finite_loss(ival, loss, x_scale):
|
|
682
|
+ """Get the socalled finite_loss of an interval in order to be able to
|
|
683
|
+ sort intervals that have infinite loss."""
|
|
684
|
+ # If the loss is infinite we return the
|
|
685
|
+ # distance between the two points.
|
|
686
|
+ if math.isinf(loss):
|
|
687
|
+ loss = (ival[1] - ival[0]) / x_scale
|
|
688
|
+ if len(ival) == 3:
|
|
689
|
+ # Used when constructing quals. Last item is
|
|
690
|
+ # the number of points inside the qual.
|
|
691
|
+ loss /= ival[2]
|
|
692
|
+
|
|
693
|
+ # We round the loss to 12 digits such that losses
|
|
694
|
+ # are equal up to numerical precision will be considered
|
|
695
|
+ # equal.
|
|
696
|
+ return round(loss, ndigits=12), ival
|