# Based on an adaptive quadrature algorithm by Pedro Gonnet

import sys
from collections import defaultdict
from math import sqrt
from operator import attrgetter

import numpy as np
from scipy.linalg import norm
from sortedcontainers import SortedSet

from adaptive.learner.base_learner import BaseLearner
from adaptive.notebook_integration import ensure_holoviews
from adaptive.utils import cache_latest, restore

from .integrator_coeffs import (
    T_left,
    T_right,
    V_inv,
    Vcond,
    alpha,
    b_def,
    eps,
    gamma,
    hint,
    min_sep,
    ndiv_max,
    ns,
    xi,
)


def _downdate(c, nans, depth):
    """Remove the contribution of the points listed in `nans` from `c`.

    Caution: this function modifies `c` in place and also returns it.

    Parameters
    ----------
    c : numpy array
        Coefficients computed with the non-finite values replaced by zero.
    nans : list of int
        Indices (into the points of `depth`) of the non-finite values.
    depth : int
        The level of refinement that `c` belongs to.
    """
    # This is algorithm 5 from the thesis of Pedro Gonnet.
    b = b_def[depth].copy()
    m = ns[depth] - 1
    for i in nans:
        b[m + 1] /= alpha[m]
        xii = xi[depth][i]
        b[m] = (b[m] + xii * b[m + 1]) / alpha[m - 1]
        for j in range(m - 1, 0, -1):
            b[j] = (b[j] + xii * b[j + 1] - gamma[j + 1] * b[j + 2]) / alpha[j - 1]
        b = b[1:]

        c[:m] -= c[m] / b[m] * b[:m]
        c[m] = 0
        m -= 1
    return c


def _zero_nans(fx):
    """Set the non-finite entries of `fx` to zero and return their indices.

    Caution: this function modifies fx.
    """
    nans = []
    for i, value in enumerate(fx):
        if not np.isfinite(value):
            nans.append(i)
            fx[i] = 0.0
    return nans


def _calc_coeffs(fx, depth):
    """Return the coefficients ``V_inv[depth] @ fx``, downdated for
    non-finite function values.

    Caution: this function modifies fx (non-finite entries end up as NaN).
    """
    nans = _zero_nans(fx)
    c_new = V_inv[depth] @ fx
    if nans:
        # Mark the offending entries again after they were zeroed for the
        # matrix product.
        fx[nans] = np.nan
        c_new = _downdate(c_new, nans, depth)
    return c_new


class DivergentIntegralError(ValueError):
    """Raised when the divergence counter of an interval exceeds its limit."""


class _Interval:
    """
    Attributes
    ----------
    (a, b) : (float, float)
        The left and right boundary of the interval.
    c : numpy array
        Coefficients of the fit, as returned by `_calc_coeffs` for the last
        completed depth.
    c00 : float
        ``c[0]`` as computed at depth 0 (``0.0`` for the first interval);
        used by `calc_ndiv`.
    depth : int
        The level of refinement, `depth=0` means that it has 5 (the minimal
        number of) points and `depth=3` means it has 33 (the maximal number
        of) points.
    fx : numpy array of size `(5, 9, 17, 33)[self.depth]`.
        The function values at the points `self.points(self.depth)`.
    igral : float
        The integral value of the interval.
    err : float
        The error associated with the integral value.
    rdepth : int
        The number of splits that the interval has gone through, starting at 1.
    ndiv : int
        A number that is used to determine whether the interval is divergent.
    parent : _Interval
        The parent interval.
    children : list of `_Interval`s
        The intervals resulting from a split.
    data : dict
        A dictionary with the x-values and y-values: `{x1: y1, x2: y2 ...}`.
    done_leaves : set or None
        Leaves used for the error and the integral estimation of this
        interval. None means that this information was already propagated to
        the ancestors of this interval.
    depth_complete : int or None
        The level of refinement at which the interval has the integral value
        evaluated. If None there is no level at which the integral value is
        known yet.
    removed : bool
        Whether the interval was removed by the learner
        (see `IntegratorLearner.propagate_removed`).

    Methods
    -------
    refinement_complete : depth, optional
        If true, all the function values in the interval are known at `depth`.
        By default the depth is the depth of the interval.
    """

    __slots__ = [
        "a",
        "b",
        "c",
        "c00",
        "depth",
        "igral",
        "err",
        "fx",
        "rdepth",
        "ndiv",
        "parent",
        "children",
        "data",
        "done_leaves",
        "depth_complete",
        "removed",
    ]

    def __init__(self, a, b, depth, rdepth):
        self.children = []
        self.data = {}
        self.a = a
        self.b = b
        self.depth = depth
        self.rdepth = rdepth
        self.done_leaves = set()
        self.depth_complete = None
        self.removed = False

    @classmethod
    def make_first(cls, a, b, depth=2):
        """Create the root interval (no parent, `rdepth=1`, `ndiv=0`)."""
        ival = cls(a, b, depth, rdepth=1)
        ival.ndiv = 0
        ival.parent = None
        ival.err = sys.float_info.max  # needed because inf/2 == inf
        return ival

    @property
    def T(self):
        """Get the correct shift matrix.

        Should only be called on children of a split interval.
        """
        assert self.parent is not None
        left = self.a == self.parent.a
        right = self.b == self.parent.b
        assert left != right
        return T_left if left else T_right

    def refinement_complete(self, depth=None):
        """The interval has all the y-values to calculate the integral.

        If `depth` is None, the depth of the interval is used.
        """
        if depth is None:
            depth = self.depth
        if len(self.data) < ns[depth]:
            return False
        return all(p in self.data for p in self.points(depth))

    def points(self, depth=None):
        """Return the x-values of the interval at `depth`.

        If `depth` is None, the depth of the interval is used.
        """
        if depth is None:
            depth = self.depth
        a = self.a
        b = self.b
        return (a + b) / 2 + (b - a) * xi[depth] / 2

    def refine(self):
        """Increase the depth of the interval by one and return the interval."""
        self.depth += 1
        return self

    def split(self):
        """Split the interval at its middle point into two children of depth 0.

        The children inherit `ndiv` and get half of the error of this
        interval. Returns the list with the two children.
        """
        points = self.points()
        m = points[len(points) // 2]
        ivals = [
            _Interval(self.a, m, 0, self.rdepth + 1),
            _Interval(m, self.b, 0, self.rdepth + 1),
        ]
        self.children = ivals
        for ival in ivals:
            ival.parent = self
            ival.ndiv = self.ndiv
            ival.err = self.err / 2

        return ivals

    def calc_igral(self):
        """Set `igral` from the first coefficient and the interval width."""
        self.igral = (self.b - self.a) * self.c[0] / sqrt(2)

    def update_heuristic_err(self, value):
        """Sets the error of an interval using a heuristic (half the error of
        the parent) when the actual error cannot be calculated due to its
        parents not being finished yet. This error is propagated down to its
        children."""
        self.err = value
        for child in self.children:
            if child.depth_complete or (
                child.depth_complete == 0 and self.depth_complete is not None
            ):
                continue
            child.update_heuristic_err(value / 2)

    def calc_err(self, c_old):
        """Set `err` from the difference between `c_old` and `self.c`.

        Children without a completed depth get a heuristic error.
        Returns the norm of the coefficient difference.
        """
        c_new = self.c
        # The two coefficient vectors may differ in length; pad with zeros.
        c_diff = np.zeros(max(len(c_old), len(c_new)))
        c_diff[: len(c_old)] = c_old
        c_diff[: len(c_new)] -= c_new
        c_diff = norm(c_diff)
        self.err = (self.b - self.a) * c_diff
        for child in self.children:
            if child.depth_complete is None:
                child.update_heuristic_err(self.err / 2)
        return c_diff

    def calc_ndiv(self):
        """Update the divergence counter by comparing `c00` with the parent's.

        Raises
        ------
        DivergentIntegralError
            If `ndiv` exceeds `ndiv_max` and is more than half of `rdepth`.
        """
        parent_c00 = self.parent.c00
        # Coerce to bool such that `ndiv` stays an integer even when
        # `parent_c00` is `0.0`.
        div = bool(parent_c00 and self.c00 / parent_c00 > 2)
        self.ndiv += div

        if self.ndiv > ndiv_max and 2 * self.ndiv > self.rdepth:
            raise DivergentIntegralError

        if div:
            for child in self.children:
                child.update_ndiv_recursively()

    def update_ndiv_recursively(self):
        """Increase `ndiv` by one for this interval and all its descendants.

        Raises
        ------
        DivergentIntegralError
            If `ndiv` exceeds `ndiv_max` and is more than half of `rdepth`.
        """
        self.ndiv += 1
        if self.ndiv > ndiv_max and 2 * self.ndiv > self.rdepth:
            raise DivergentIntegralError

        for child in self.children:
            child.update_ndiv_recursively()

    def complete_process(self, depth):
        """Calculate the integral contribution and error from this interval,
        and update the done leaves of all ancestor intervals.

        Returns
        -------
        force_split : bool
            Whether the interval should be split instead of refined further.
        remove : bool
            Whether the error is below ``abs(igral) * eps * Vcond[depth]``.
        """
        assert self.depth_complete is None or self.depth_complete == depth - 1
        self.depth_complete = depth

        fx = [self.data[k] for k in self.points(depth)]
        self.fx = np.array(fx)
        force_split = False  # This may change when refining

        first_ival = self.parent is None and depth == 2

        if depth and not first_ival:
            # Store for usage in refine
            c_old = self.c

        self.c = _calc_coeffs(self.fx, depth)

        if first_ival:
            self.c00 = 0.0
            return False, False

        self.calc_igral()

        if depth:
            # Refine
            c_diff = self.calc_err(c_old)
            force_split = c_diff > hint * norm(self.c)
        else:
            # Split
            self.c00 = self.c[0]

            if self.parent.depth_complete is not None:
                c_old = self.T[:, : ns[self.parent.depth_complete]] @ self.parent.c
                self.calc_err(c_old)
                self.calc_ndiv()

            for child in self.children:
                if child.depth_complete is not None:
                    child.calc_ndiv()
                if child.depth_complete == 0:
                    c_old = child.T[:, : ns[self.depth_complete]] @ self.c
                    child.calc_err(c_old)

        if self.done_leaves is not None and not len(self.done_leaves):
            # This interval contributes to the integral estimate.
            self.done_leaves = {self}

            # Use this interval in the integral estimates of the ancestors
            # while possible.
            ival = self.parent
            old_leaves = set()
            while ival is not None:
                unused_children = [
                    child for child in ival.children if child.done_leaves is not None
                ]

                if not all(len(child.done_leaves) for child in unused_children):
                    break

                if ival.done_leaves is None:
                    ival.done_leaves = set()
                old_leaves.add(ival)
                for child in ival.children:
                    if child.done_leaves is None:
                        continue
                    ival.done_leaves.update(child.done_leaves)
                    child.done_leaves = None
                ival.done_leaves -= old_leaves
                ival = ival.parent

        remove = self.err < (abs(self.igral) * eps * Vcond[depth])

        return force_split, remove

    def __repr__(self):
        lst = [
            f"(a, b)=({self.a:.5f}, {self.b:.5f})",
            f"depth={self.depth}",
            f"rdepth={self.rdepth}",
            f"err={self.err:.5E}",
            "igral={:.5E}".format(self.igral if hasattr(self, "igral") else np.inf),
        ]
        return " ".join(lst)


class IntegratorLearner(BaseLearner):
    def __init__(self, function, bounds, tol):
        """
        Parameters
        ----------
        function : callable: X → Y
            The function to learn.
        bounds : pair of reals
            The bounds of the interval on which to learn 'function'.
        tol : float
            Relative tolerance of the error to the integral, this means that
            the learner is done when: `tol > err / abs(igral)`.

        Attributes
        ----------
        approximating_intervals : set of intervals
            The intervals that can be used in the determination of the integral.
        n : int
            The total number of evaluated points.
        igral : float
            The integral value in `self.bounds`.
        err : float
            The absolute error associated with `self.igral`.
        max_ivals : int, default: 1000
            Maximum number of intervals that can be present in the calculation
            of the integral. If this amount exceeds max_ivals, the interval
            with the smallest error will be discarded.

        Methods
        -------
        done : bool
            Returns whether the `tol` has been reached.
        plot : hv.Scatter
            Plots all the points that are evaluated.
        """
        self.function = function
        self.bounds = bounds
        self.tol = tol
        self.max_ivals = 1000
        self.priority_split = []
        self.data = {}
        self.pending_points = set()
        self._stack = []
        self.x_mapping = defaultdict(lambda: SortedSet([], key=attrgetter("rdepth")))
        self.ivals = set()
        ival = _Interval.make_first(*self.bounds)
        self.add_ival(ival)
        self.first_ival = ival

    @property
    def approximating_intervals(self):
        return self.first_ival.done_leaves

    def tell(self, point, value):
        """Store `value` for `point` and update every interval containing it.

        Raises
        ------
        ValueError
            If `point` does not belong to any interval.
        """
        if point not in self.x_mapping:
            raise ValueError(f"Point {point} doesn't belong to any interval")
        self.data[point] = value
        self.pending_points.discard(point)

        # Select the intervals that have this point
        ivals = self.x_mapping[point]
        for ival in ivals:
            ival.data[point] = value

            if ival.depth_complete is None:
                from_depth = 0 if ival.parent is not None else 2
            else:
                from_depth = ival.depth_complete + 1

            for depth in range(from_depth, ival.depth + 1):
                if ival.refinement_complete(depth):
                    force_split, remove = ival.complete_process(depth)

                    if remove:
                        # Remove the interval (while remembering the excess
                        # integral and error), since it is either too narrow,
                        # or the estimated relative error is already at the
                        # limit of numerical accuracy and cannot be reduced
                        # further.
                        self.propagate_removed(ival)

                    elif force_split and not ival.children:
                        # If it already has children it has already been split
                        assert ival in self.ivals
                        self.priority_split.append(ival)

    def tell_pending(self):
        """No-op for this learner."""

    def propagate_removed(self, ival):
        """Mark `ival` and all its descendants as removed and drop them from
        `self.ivals`."""

        def _propagate_removed_down(ival):
            ival.removed = True
            self.ivals.discard(ival)

            for child in ival.children:
                _propagate_removed_down(child)

        _propagate_removed_down(ival)

    def add_ival(self, ival):
        """Register `ival` and queue those of its points that are neither
        known nor pending."""
        for x in ival.points():
            # Update the mappings
            self.x_mapping[x].add(ival)
            if x in self.data:
                self.tell(x, self.data[x])
            elif x not in self.pending_points:
                self.pending_points.add(x)
                self._stack.append(x)
        self.ivals.add(ival)

    def ask(self, n, tell_pending=True):
        """Choose points for learners."""
        if not tell_pending:
            with restore(self):
                return self._ask_and_tell_pending(n)
        else:
            return self._ask_and_tell_pending(n)

    def _ask_and_tell_pending(self, n):
        """Return `n` points from the stack, refilling the stack when needed.

        Raises
        ------
        RuntimeError
            If `_fill_stack` raises a `ValueError`.
        """
        points, loss_improvements = self.pop_from_stack(n)
        n_left = n - len(points)
        while n_left > 0:
            try:
                self._fill_stack()
            except ValueError as err:
                # `max()` on an empty interval set raises ValueError. Note
                # that DivergentIntegralError is a ValueError as well; the
                # original exception is kept as the cause.
                raise RuntimeError("No way to improve the integral estimate.") from err
            new_points, new_loss_improvements = self.pop_from_stack(n_left)
            points += new_points
            loss_improvements += new_loss_improvements
            n_left -= len(new_points)

        return points, loss_improvements

    def pop_from_stack(self, n):
        """Take at most `n` points from the front of the stack.

        The loss improvement of a point is the largest error among the
        intervals that contain it.
        """
        points = self._stack[:n]
        self._stack = self._stack[n:]
        loss_improvements = [
            max(ival.err for ival in self.x_mapping[x]) for x in points
        ]
        return points, loss_improvements

    def remove_unfinished(self):
        """No-op for this learner."""

    def _fill_stack(self):
        """Refine or split one interval to put new points on the stack.

        Intervals in `priority_split` go first; otherwise the interval with
        the largest error is taken. Returns the stack.
        """
        # XXX: to-do if all the ivals have err=inf, take the interval
        # with the lowest rdepth and no children.
        force_split = bool(self.priority_split)
        if force_split:
            ival = self.priority_split.pop()
        else:
            ival = max(self.ivals, key=lambda x: (x.err, x.a))

        assert not ival.children

        # If the interval points are smaller than machine precision, then
        # don't continue with splitting or refining.
        points = ival.points()

        # Compare against the magnitude of the points: with a negative
        # abscissa the right-hand side would otherwise be negative and the
        # check could never trigger.
        if (
            points[1] - points[0] < abs(points[0]) * min_sep
            or points[-1] - points[-2] < abs(points[-2]) * min_sep
        ):
            self.ivals.remove(ival)
        elif ival.depth == 3 or force_split:
            # Always split when depth is maximal or if refining didn't help
            self.ivals.remove(ival)
            for child in ival.split():
                self.add_ival(child)
        else:
            self.add_ival(ival.refine())

        # Remove the interval with the smallest error
        # if number of intervals is larger than max_ivals
        if len(self.ivals) > self.max_ivals:
            self.ivals.remove(min(self.ivals, key=lambda x: (x.err, x.a)))

        return self._stack

    @property
    def npoints(self):
        """Number of evaluated points."""
        return len(self.data)

    @property
    def igral(self):
        return sum(i.igral for i in self.approximating_intervals)

    @property
    def err(self):
        if self.approximating_intervals:
            err = sum(i.err for i in self.approximating_intervals)
            if err > sys.float_info.max:
                err = np.inf
        else:
            err = np.inf
        return err

    def done(self):
        """Return whether the tolerance is reached or nothing is left to do."""
        err = self.err
        igral = self.igral
        err_excess = sum(i.err for i in self.approximating_intervals if i.removed)
        return (
            err == 0
            or err < abs(igral) * self.tol
            or (err - err_excess < abs(igral) * self.tol < err_excess)
            or not self.ivals
        )

    @cache_latest
    def loss(self, real=True):
        return abs(abs(self.igral) * self.tol - self.err)

    def plot(self):
        """Return an `hv.Path` through the data of the current intervals."""
        hv = ensure_holoviews()
        ivals = sorted(self.ivals, key=attrgetter("a"))
        pairs = [(x, y) for ival in ivals for x, y in sorted(ival.data.items())]
        if not pairs:
            # Either nothing is evaluated yet or no interval is left;
            # `zip(*[])` cannot be unpacked into two sequences.
            return hv.Path([])
        xs, ys = zip(*pairs)
        return hv.Path((xs, ys))

    def _get_data(self):
        # Change the defaultdict of SortedSets to a normal dict of sets.
        x_mapping = {k: set(v) for k, v in self.x_mapping.items()}

        return (
            self.priority_split,
            self.data,
            self.pending_points,
            self._stack,
            x_mapping,
            self.ivals,
            self.first_ival,
        )

    def _set_data(self, data):
        (
            self.priority_split,
            self.data,
            self.pending_points,
            self._stack,
            x_mapping,
            self.ivals,
            self.first_ival,
        ) = data

        # Add the pending_points to the _stack such that they are evaluated again
        stacked = set(self._stack)
        for x in self.pending_points:
            if x not in stacked:
                self._stack.append(x)

        # x_mapping is a data structure that can't easily be saved
        # so we recreate it here
        self.x_mapping = defaultdict(lambda: SortedSet([], key=attrgetter("rdepth")))
        for k, _set in x_mapping.items():
            self.x_mapping[k].update(_set)

    def __getstate__(self):
        return (
            self.function,
            self.bounds,
            self.tol,
            self._get_data(),
        )

    def __setstate__(self, state):
        function, bounds, tol, data = state
        self.__init__(function, bounds, tol)
        self._set_data(data)