Browse code

improve leaf counting algorithm

Anton Akhmerov authored on 29/11/2017 09:22:38 • Bas Nijholt committed on 29/11/2017 11:52:49
Showing 1 changed files
... ...
@@ -90,9 +90,6 @@ class Interval:
90 90
         The intervals resulting from a split or refinement.
91 91
     done_points : dict
92 92
         A dictionary with the x-values and y-values: `{x1: y1, x2: y2 ...}`.
93
-    est_err : float
94
-        The sum of the errors of the children, if one of the children is not ready yet,
95
-        the error is infinity.
96 93
     discard : bool
97 94
         If True, the interval and it's children are not participating in the
98 95
         determination of the total integral anymore because its parent had a
... ...
@@ -103,14 +100,16 @@ class Interval:
103 100
         that the integral value has been calculated, see `self.done`.
104 101
     done : bool
105 102
         The integral and the error for the interval has been calculated.
106
-    branch_complete : bool
107
-        The interval can be used to determine the total integral, however if its children are
108
-        also `branch_complete`, they should be used.
109
-
103
+    done_leaves : set or None
104
+        Leaves used for the error and the integral estimation of this
105
+        interval. None means that this information was already propagated to
106
+        the ancestors of this interval.
110 107
     """
111 108
 
112
-    __slots__ = ['a', 'b', 'c', 'depth', 'fx', 'igral', 'err', 'rdepth',
113
-                 'ndiv', 'parent', 'children', 'done_points', 'est_err', 'discard']
109
+    __slots__ = [
110
+        'a', 'b', 'c', 'depth', 'fx', 'igral', 'err', 'rdepth',
111
+        'ndiv', 'parent', 'children', 'done_points', 'discard', 'done_leaves',
112
+    ]
114 113
 
115 114
     def __init__(self, a, b):
116 115
         self.children = []
... ...
@@ -118,9 +117,9 @@ class Interval:
118 117
         self.a = a
119 118
         self.b = b
120 119
         self.c = np.zeros((4, ns[3]))
121
-        self.est_err = np.inf
122 120
         self.discard = False
123 121
         self.igral = None
122
+        self.done_leaves = set()
124 123
 
125 124
     @classmethod
126 125
     def make_first(cls, a, b, depth=2):
... ...
@@ -142,13 +141,6 @@ class Interval:
142 141
         """The interval is complete and has the intergral calculated."""
143 142
         return hasattr(self, 'fx') and self.complete
144 143
 
145
-    @property
146
-    def branch_complete(self):
147
-        if not self.children and self.complete:
148
-            return True
149
-        else:
150
-            return np.isfinite(sum(i.est_err for i in self.children))
151
-
152 144
     @property
153 145
     def T(self):
154 146
         """Get the correct shift matrix.
... ...
@@ -195,26 +187,40 @@ class Interval:
195 187
 
196 188
     def complete_process(self):
197 189
         """Calculate the integral contribution and error from this interval,
198
-        and update the estimated error of all ancestor intervals."""
190
+        and update the done leaves of all ancestor intervals."""
199 191
         force_split = False
200 192
         if self.parent is not None and self.rdepth > self.parent.rdepth:
201 193
             self.process_split()
202 194
         else:
203 195
             force_split = self.process_refine()
204 196
 
205
-        # Set the estimated error
206
-        if np.isinf(self.est_err):
207
-            self.est_err = self.err
197
+        if self.done_leaves is not None and not len(self.done_leaves):
198
+            # This interval contributes to the integral estimate.
199
+            self.done_leaves = {self}
200
+
201
+        # Use this interval in the integral estimates of the ancestors while
202
+        # possible.
208 203
         ival = self.parent
204
+        old_leaves = set()
209 205
         while ival is not None:
210
-            # update the error estimates on all ancestor intervals
211
-            children_err = sum(i.est_err for i in ival.children)
212
-            if np.isfinite(children_err):
213
-                ival.est_err = children_err
214
-                ival = ival.parent
215
-            else:
206
+            unused_children = [child for child in ival.children
207
+                               if child.done_leaves is not None]
208
+
209
+            if not all(len(child.done_leaves) for child in unused_children):
216 210
                 break
217 211
 
212
+            if ival.done_leaves is None:
213
+                ival.done_leaves = set()
214
+            old_leaves.add(ival)
215
+            for child in ival.children:
216
+                if child.done_leaves is None:
217
+                    continue
218
+                ival.done_leaves |= child.done_leaves
219
+                child.done_leaves = None
220
+            ival.done_leaves -= old_leaves
221
+            ival = ival.parent
222
+
223
+
218 224
         # Check whether the point spacing is smaller than machine precision
219 225
         # and pop the interval with the largest error and do not split
220 226
         remove = self.err < (abs(self.igral) * eps * Vcond[self.depth])
... ...
@@ -266,7 +272,6 @@ class Interval:
266 272
             'rdepth={}'.format(self.rdepth),
267 273
             'err={:.5E}'.format(self.err),
268 274
             'igral={:.5E}'.format(self.igral if self.igral else 0),
269
-            'est_err={:.5E}'.format(self.est_err),
270 275
             'discard={}'.format(self.discard),
271 276
         ]
272 277
         return ' '.join(lst)
... ...
@@ -458,58 +463,20 @@ class IntegratorLearner(BaseLearner):
458 463
 
459 464
         return self._stack
460 465
 
461
-    @staticmethod
462
-    def deepest_complete_branches(ival):
463
-        """Finds the deepest complete set of intervals starting from `ival`."""
464
-        complete_branches = []
465
-        def _find_deepest(ival):
466
-            children_err = (sum(i.est_err for i in ival.children)
467
-                            if ival.children else np.inf)
468
-            if np.isfinite(ival.est_err) and np.isinf(children_err):
469
-                complete_branches.append(ival)
470
-            else:
471
-                for i in ival.children:
472
-                    _find_deepest(i)
473
-        _find_deepest(ival)
474
-        return complete_branches
475
-
476
-    @property
477
-    def complete_branches(self):
478
-        if not self.first_ival.done:
479
-            return []
480
-
481
-        if not self._complete_branches:
482
-            self._complete_branches.append(self.first_ival)
483
-
484
-        complete_branches = []
485
-        for ival in self._complete_branches:
486
-            if ival.discard:
487
-                complete_branches = self.deepest_complete_branches(self.first_ival)
488
-                break
489
-            if not ival.children:
490
-                # If the interval has no children, than is already is the deepest
491
-                # complete branch.
492
-                complete_branches.append(ival)
493
-            else:
494
-                complete_branches.extend(self.deepest_complete_branches(ival))
495
-        self._complete_branches = complete_branches
496
-        return self._complete_branches
497
-
498 466
     @property
499 467
     def nr_points(self):
500 468
         return len(self.done_points)
501 469
 
502 470
     @property
503 471
     def igral(self):
504
-        return sum(i.igral for i in self.complete_branches)
472
+        return sum(i.igral for i in self.first_ival.done_leaves)
505 473
 
506 474
     @property
507 475
     def err(self):
508
-        complete_branches = self.complete_branches
509
-        if not complete_branches:
510
-            return np.inf
476
+        if self.first_ival.done_leaves:
477
+            return sum(i.err for i in self.first_ival.done_leaves)
511 478
         else:
512
-            return sum(i.err for i in complete_branches)
479
+            return np.inf
513 480
 
514 481
     def done(self):
515 482
         err = self.err