... | ... |
@@ -256,25 +256,6 @@ class Interval: |
256 | 256 |
] |
257 | 257 |
return ' '.join(lst) |
258 | 258 |
|
259 |
- def equal(self, other, *, verbose=False): |
|
260 |
- """Note: Implementing __eq__ breaks SortedContainers in some way.""" |
|
261 |
- if not self.complete: |
|
262 |
- if verbose: |
|
263 |
- print('Interval {} is not complete.'.format(self)) |
|
264 |
- return False |
|
265 |
- |
|
266 |
- slots = set(self.__slots__).intersection(other.__slots__) |
|
267 |
- same_slots = [] |
|
268 |
- for s in slots: |
|
269 |
- a = getattr(self, s) |
|
270 |
- b = getattr(other, s) |
|
271 |
- is_equal = np.allclose(a, b, rtol=0, atol=eps, equal_nan=True) |
|
272 |
- if verbose and not is_equal: |
|
273 |
- print('self.{} - other.{} = {}'.format(s, s, a - b)) |
|
274 |
- same_slots.append(is_equal) |
|
275 |
- |
|
276 |
- return all(same_slots) |
|
277 |
- |
|
278 | 259 |
|
279 | 260 |
class IntegratorLearner(BaseLearner): |
280 | 261 |
|
... | ... |
@@ -470,17 +451,5 @@ class IntegratorLearner(BaseLearner): |
470 | 451 |
def loss(self, real=True): |
471 | 452 |
return abs(abs(self.igral) * self.tol - self.err) |
472 | 453 |
|
473 |
- def equal(self, other, *, verbose=False): |
|
474 |
- """Note: `other` is a list of ivals.""" |
|
475 |
- if len(self.ivals) != len(other): |
|
476 |
- if verbose: |
|
477 |
- print('len(self.ivals)={} != len(other)={}'.format( |
|
478 |
- len(self.ivals), len(other))) |
|
479 |
- return False |
|
480 |
- |
|
481 |
- ivals = [sorted(i, key=attrgetter('a')) for i in [self.ivals, other]] |
|
482 |
- return all(ival.equal(other_ival, verbose=verbose) |
|
483 |
- for ival, other_ival in zip(*ivals)) |
|
484 |
- |
|
485 | 454 |
def plot(self): |
486 | 455 |
return hv.Scatter(self.done_points) |
... | ... |
@@ -1,5 +1,6 @@ |
1 | 1 |
# -*- coding: utf-8 -*- |
2 | 2 |
from functools import partial |
3 |
+from operator import attrgetter |
|
3 | 4 |
|
4 | 5 |
import numpy as np |
5 | 6 |
import pytest |
... | ... |
@@ -17,6 +18,37 @@ def run_integrator_learner(f, a, b, tol, nr_points): |
17 | 18 |
return learner |
18 | 19 |
|
19 | 20 |
|
21 |
+def equal_ival(ival, other, *, verbose=False): |
|
22 |
+ """Note: Implementing __eq__ breaks SortedContainers in some way.""" |
|
23 |
+ if not ival.complete: |
|
24 |
+ if verbose: |
|
25 |
+ print('Interval {} is not complete.'.format(ival)) |
|
26 |
+ return False |
|
27 |
+ |
|
28 |
+ slots = set(ival.__slots__).intersection(other.__slots__) |
|
29 |
+ same_slots = [] |
|
30 |
+ for s in slots: |
|
31 |
+ a = getattr(ival, s) |
|
32 |
+ b = getattr(other, s) |
|
33 |
+ is_equal = np.allclose(a, b, rtol=0, atol=eps, equal_nan=True) |
|
34 |
+ if verbose and not is_equal: |
|
35 |
+ print('ival.{} - other.{} = {}'.format(s, s, a - b)) |
|
36 |
+ same_slots.append(is_equal) |
|
37 |
+ |
|
38 |
+ return all(same_slots) |
|
39 |
+ |
|
40 |
+def equal_ivals(ivals, other, *, verbose=False): |
|
41 |
+ """Note: `other` is a list of ivals.""" |
|
42 |
+ if len(ivals) != len(other): |
|
43 |
+ if verbose: |
|
44 |
+ print('len(ivals)={} != len(other)={}'.format( |
|
45 |
+ len(ivals), len(other))) |
|
46 |
+ return False |
|
47 |
+ |
|
48 |
+ ivals = [sorted(i, key=attrgetter('a')) for i in [ivals, other]] |
|
49 |
+ return all(equal_ival(ival, other_ival, verbose=verbose) |
|
50 |
+ for ival, other_ival in zip(*ivals)) |
|
51 |
+ |
|
20 | 52 |
def same_ivals(f, a, b, tol): |
21 | 53 |
igral, err, nr_points, ivals = algorithm_4(f, a, b, tol) |
22 | 54 |
|
... | ... |
@@ -26,7 +58,7 @@ def same_ivals(f, a, b, tol): |
26 | 58 |
print('igral difference', learner.igral-igral, |
27 | 59 |
'err difference', learner.err - err) |
28 | 60 |
|
29 |
- return learner.equal(ivals, verbose=True) |
|
61 |
+ return equal_ivals(learner.ivals, ivals, verbose=True) |
|
30 | 62 |
|
31 | 63 |
|
32 | 64 |
def test_cquad(): |
... | ... |
@@ -47,7 +79,7 @@ def test_machine_precision(): |
47 | 79 |
print('igral difference', learner.igral-igral, |
48 | 80 |
'err difference', learner.err - err) |
49 | 81 |
|
50 |
- assert learner.equal(ivals, verbose=True) |
|
82 |
+ assert equal_ivals(learner.ivals, ivals, verbose=True) |
|
51 | 83 |
|
52 | 84 |
|
53 | 85 |
def test_machine_precision2(): |