
Resolve "(Learner1D) improve time complexity"

Jorn Hoofwijk authored on 07/12/2018 11:42:53 • Bas Nijholt committed on 07/12/2018 11:42:53
Showing 3 changed files
@@ -8,6 +8,7 @@ from collections import Iterable
 
 import numpy as np
 import sortedcontainers
+import sortedcollections
 
 from adaptive.learner.base_learner import BaseLearner
 from adaptive.learner.learnerND import volume
@@ -225,9 +226,6 @@ class Learner1D(BaseLearner):
 
         self.loss_per_interval = loss_per_interval or default_loss
 
-        # A dict storing the loss function for each interval x_n.
-        self.losses = {}
-        self.losses_combined = {}
 
         # When the scale changes by a factor 2, the losses are
         # recomputed. This is tunable such that we can test
@@ -249,6 +247,10 @@ class Learner1D(BaseLearner):
         self._scale = [bounds[1] - bounds[0], 0]
         self._oldscale = deepcopy(self._scale)
 
+        # A LossManager storing the loss function for each interval x_n.
+        self.losses = loss_manager(self._scale[0])
+        self.losses_combined = loss_manager(self._scale[0])
+
         # The precision in 'x' below which we set losses to 0.
         self._dx_eps = 2 * max(np.abs(bounds)) * np.finfo(float).eps
 
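The two loss containers are now created by `loss_manager` (added at the bottom of this file, in the last hunk of the learner below) instead of plain dicts. A minimal sketch, not part of the commit, of the dict-compatible interface the rest of the class keeps relying on:

```python
# Illustrative only; loss_manager is the helper added in the last hunk of this file.
lm = loss_manager(x_scale=1.0)
lm[(0.0, 0.5)] = 0.2      # plain dict-style assignment still works
lm[(0.5, 1.0)] = 0.7
(0.5, 1.0) in lm          # True; membership checks are unchanged
list(lm)                  # [(0.5, 1.0), (0.0, 0.5)]: keys iterate in decreasing-loss order
```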
@@ -284,7 +286,10 @@ class Learner1D(BaseLearner):
     @cache_latest
     def loss(self, real=True):
         losses = self.losses if real else self.losses_combined
-        return max(losses.values()) if len(losses) > 0 else float('inf')
+        if not losses:
+            return np.inf
+        max_interval, max_loss = losses.peekitem(0)
+        return max_loss
 
     def _scale_x(self, x):
         if x is None:
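`loss()` no longer scans all values with `max()`: the `ItemSortedDict` returned by `loss_manager` keeps items ordered by decreasing loss, so `peekitem(0)` yields the largest one directly. A standalone sketch of that ordering (illustrative only, using `sortedcollections` directly):

```python
import sortedcollections

# The key function receives (key, value); negating the loss puts the largest loss first.
def sort_key(interval, loss):
    return -loss, interval

losses = sortedcollections.ItemSortedDict(sort_key)
losses[(0.0, 0.5)] = 0.2
losses[(0.5, 1.0)] = 0.7

interval, max_loss = losses.peekitem(0)   # ((0.5, 1.0), 0.7): the maximum loss
```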
@@ -454,8 +459,7 @@ class Learner1D(BaseLearner):
 
         # If the scale has increased enough, recompute all losses.
         if self._scale[1] > self._recompute_losses_factor * self._oldscale[1]:
-
-            for interval in self.losses:
+            for interval in reversed(self.losses):
                 self._update_interpolated_loss_in_interval(*interval)
 
             self._oldscale = deepcopy(self._scale)
@@ -504,18 +508,18 @@ class Learner1D(BaseLearner):
             for neighbors in (self.neighbors, self.neighbors_combined)]
 
         # The losses for the "real" intervals.
-        self.losses = {}
+        self.losses = loss_manager(self._scale[0])
         for ival in intervals:
             self.losses[ival] = self._get_loss_in_interval(*ival)
 
         # List with "real" intervals that have interpolated intervals inside
         to_interpolate = []
 
-        self.losses_combined = {}
+        self.losses_combined = loss_manager(self._scale[0])
         for ival in intervals_combined:
             # If this interval exists in 'losses' then copy it otherwise
             # calculate it.
-            if ival in self.losses:
+            if ival in reversed(self.losses):
                 self.losses_combined[ival] = self.losses[ival]
             else:
                 # Set all losses to inf now, later they might be updated if the
@@ -530,7 +534,7 @@ class Learner1D(BaseLearner):
                     to_interpolate.append((x_left, x_right))
 
         for ival in to_interpolate:
-            if ival in self.losses:
+            if ival in reversed(self.losses):
                 # If this interval does not exist it should already
                 # have an inf loss.
                 self._update_interpolated_loss_in_interval(*ival)
@@ -566,64 +570,57 @@ class Learner1D(BaseLearner):
         if len(missing_bounds) >= n:
             return missing_bounds[:n], [np.inf] * n
 
-        def finite_loss(loss, xs):
-            # If the loss is infinite we return the
-            # distance between the two points.
-            if math.isinf(loss):
-                loss = (xs[1] - xs[0]) / self._scale[0]
-
-            # We round the loss to 12 digits such that losses
-            # are equal up to numerical precision will be considered
-            # equal.
-            return round(loss, ndigits=12)
-
-        quals = [(-finite_loss(loss, x), x, 1)
-                 for x, loss in self.losses_combined.items()]
-
         # Add bound intervals to quals if bounds were missing.
         if len(self.data) + len(self.pending_points) == 0:
             # We don't have any points, so return a linspace with 'n' points.
             return np.linspace(*self.bounds, n).tolist(), [np.inf] * n
-        elif len(missing_bounds) > 0:
+
+        quals = loss_manager(self._scale[0])
+        if len(missing_bounds) > 0:
             # There is at least one point in between the bounds.
             all_points = list(self.data.keys()) + list(self.pending_points)
             intervals = [(self.bounds[0], min(all_points)),
                          (max(all_points), self.bounds[1])]
             for interval, bound in zip(intervals, self.bounds):
                 if bound in missing_bounds:
-                    qual = (-finite_loss(np.inf, interval), interval, 1)
-                    quals.append(qual)
-
-        # Calculate how many points belong to each interval.
-        points, loss_improvements = self._subdivide_quals(
-            quals, n - len(missing_bounds))
-
-        points = missing_bounds + points
-        loss_improvements = [np.inf] * len(missing_bounds) + loss_improvements
+                    quals[(*interval, 1)] = np.inf
 
-        return points, loss_improvements
+        points_to_go = n - len(missing_bounds)
 
-    def _subdivide_quals(self, quals, n):
         # Calculate how many points belong to each interval.
-        heapq.heapify(quals)
-
-        for _ in range(n):
-            quality, x, n = quals[0]
-            if abs(x[1] - x[0]) / (n + 1) <= self._dx_eps:
-                # The interval is too small and should not be subdivided.
-                quality = np.inf
-                # XXX: see https://gitlab.kwant-project.org/qt/adaptive/issues/104
-            heapq.heapreplace(quals, (quality * n / (n + 1), x, n + 1))
+        i, i_max = 0, len(self.losses_combined)
+        for _ in range(points_to_go):
+            qual, loss_qual = quals.peekitem(0) if quals else (None, 0)
+            ival, loss_ival = self.losses_combined.peekitem(i) if i < i_max else (None, 0)
+
+            if (qual is None
+                or (ival is not None
+                    and self._loss(self.losses_combined, ival)
+                        >= self._loss(quals, qual))):
+                i += 1
+                quals[(*ival, 2)] = loss_ival / 2
+            else:
+                quals.pop(qual, None)
+                *xs, n = qual
+                quals[(*xs, n+1)] = loss_qual * n / (n+1)
 
         points = list(itertools.chain.from_iterable(
-            linspace(*interval, n) for quality, interval, n in quals))
+            linspace(*ival, n) for (*ival, n) in quals))
 
         loss_improvements = list(itertools.chain.from_iterable(
-            itertools.repeat(-quality, n - 1)
-            for quality, interval, n in quals))
+            itertools.repeat(quals[x0, x1, n], n - 1)
+            for (x0, x1, n) in quals))
+
+        # add the missing bounds
+        points = missing_bounds + points
+        loss_improvements = [np.inf] * len(missing_bounds) + loss_improvements
 
         return points, loss_improvements
 
+    def _loss(self, mapping, ival):
+        loss = mapping[ival]
+        return finite_loss(ival, loss, self._scale[0])
+
     def plot(self):
         """Returns a plot of the evaluated data.
 
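The new loop lazily merges two sorted sources, at each step taking whichever candidate has the larger finite per-segment loss: the next untouched interval from `losses_combined` or the best entry already in `quals`. The removed `_subdivide_quals` expressed the same greedy subdivision with a heap; a standalone sketch of that strategy (illustrative only, hypothetical helper name):

```python
import heapq

def subdivide(losses, n_points):
    """Greedily distribute n_points over intervals by repeatedly splitting
    the interval whose per-segment loss is currently the largest."""
    # heapq is a min-heap, so store (-loss_per_segment, interval, n_segments).
    heap = [(-loss, ival, 1) for ival, loss in losses.items()]
    heapq.heapify(heap)
    for _ in range(n_points):
        neg_loss, ival, n = heap[0]
        # Splitting into n + 1 segments scales the per-segment loss by n / (n + 1).
        heapq.heapreplace(heap, (neg_loss * n / (n + 1), ival, n + 1))
    return {ival: n - 1 for _, ival, n in heap if n > 1}   # new points per interval

subdivide({(0.0, 0.5): 0.2, (0.5, 1.0): 0.8}, 3)   # -> {(0.5, 1.0): 3}
```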
@@ -658,3 +655,42 @@ class Learner1D(BaseLearner):
 
     def _set_data(self, data):
         self.tell_many(*zip(*data.items()))
+
+
+def _fix_deepcopy(sorted_dict, x_scale):
+    # XXX: until https://github.com/grantjenks/sortedcollections/issues/5 is fixed
+    import types
+    def __deepcopy__(self, memo):
+        items = deepcopy(list(self.items()))
+        lm = loss_manager(self.x_scale)
+        lm.update(items)
+        return lm
+    sorted_dict.x_scale = x_scale
+    sorted_dict.__deepcopy__ = types.MethodType(__deepcopy__, sorted_dict)
+
+
+def loss_manager(x_scale):
+    def sort_key(ival, loss):
+        loss, ival = finite_loss(ival, loss, x_scale)
+        return -loss, ival
+    sorted_dict = sortedcollections.ItemSortedDict(sort_key)
+    _fix_deepcopy(sorted_dict, x_scale)
+    return sorted_dict
+
+
+def finite_loss(ival, loss, x_scale):
+    """Get the so-called finite loss of an interval, in order to be able to
+    sort intervals that have infinite loss."""
+    # If the loss is infinite we return the
+    # distance between the two points.
+    if math.isinf(loss):
+        loss = (ival[1] - ival[0]) / x_scale
+        if len(ival) == 3:
+            # Used when constructing quals. Last item is
+            # the number of points inside the qual.
+            loss /= ival[2]
+
+    # We round the loss to 12 digits such that losses
+    # that are equal up to numerical precision will be
+    # considered equal.
+    return round(loss, ndigits=12), ival
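Together these helpers make every interval comparable: `finite_loss` replaces an infinite loss with the scaled interval width (so wider unknown intervals rank higher), and `loss_manager` uses that as the sort key. A small usage sketch, assuming the helpers are imported from this module (e.g. `from adaptive.learner.learner1D import loss_manager`):

```python
import math

losses = loss_manager(x_scale=1.0)
losses[(0.0, 0.1)] = math.inf   # infinite losses are ranked by interval width
losses[(0.1, 0.9)] = math.inf
losses[(0.9, 1.0)] = 0.3

losses.peekitem(0)   # ((0.1, 0.9), inf): the widest infinite-loss interval comes first
```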
@@ -6,6 +6,7 @@ channels:
 dependencies:
   - python=3.6
   - sortedcontainers
+  - sortedcollections
   - scipy
   - holoviews
   - ipyparallel
@@ -26,6 +26,7 @@ version, cmdclass = get_version_and_cmdclass('adaptive')
 
 install_requires = [
     'scipy',
+    'sortedcollections',
     'sortedcontainers',
 ]