Browse code

move the equal interval check to the tests

Bas Nijholt authored on 29/11/2017 16:51:50
Showing 2 changed files
... ...
@@ -256,25 +256,6 @@ class Interval:
256 256
         ]
257 257
         return ' '.join(lst)
258 258
 
259
-    def equal(self, other, *, verbose=False):
260
-        """Note: Implementing __eq__ breaks SortedContainers in some way."""
261
-        if not self.complete:
262
-            if verbose:
263
-                print('Interval {} is not complete.'.format(self))
264
-            return False
265
-
266
-        slots = set(self.__slots__).intersection(other.__slots__)
267
-        same_slots = []
268
-        for s in slots:
269
-            a = getattr(self, s)
270
-            b = getattr(other, s)
271
-            is_equal = np.allclose(a, b, rtol=0, atol=eps, equal_nan=True)
272
-            if verbose and not is_equal:
273
-                print('self.{} - other.{} = {}'.format(s, s, a - b))
274
-            same_slots.append(is_equal)
275
-
276
-        return all(same_slots)
277
-
278 259
 
279 260
 class IntegratorLearner(BaseLearner):
280 261
 
... ...
@@ -470,17 +451,5 @@ class IntegratorLearner(BaseLearner):
470 451
     def loss(self, real=True):
471 452
         return abs(abs(self.igral) * self.tol - self.err)
472 453
 
473
-    def equal(self, other, *, verbose=False):
474
-        """Note: `other` is a list of ivals."""
475
-        if len(self.ivals) != len(other):
476
-            if verbose:
477
-                print('len(self.ivals)={} != len(other)={}'.format(
478
-                    len(self.ivals), len(other)))
479
-            return False
480
-
481
-        ivals = [sorted(i, key=attrgetter('a')) for i in [self.ivals, other]]
482
-        return all(ival.equal(other_ival, verbose=verbose)
483
-                   for ival, other_ival in zip(*ivals))
484
-
485 454
     def plot(self):
486 455
         return hv.Scatter(self.done_points)
... ...
@@ -1,5 +1,6 @@
1 1
 # -*- coding: utf-8 -*-
2 2
 from functools import partial
3
+from operator import attrgetter
3 4
 
4 5
 import numpy as np
5 6
 import pytest
... ...
@@ -17,6 +18,37 @@ def run_integrator_learner(f, a, b, tol, nr_points):
17 18
     return learner
18 19
 
19 20
 
21
+def equal_ival(ival, other, *, verbose=False):
22
+    """Note: Implementing __eq__ breaks SortedContainers in some way."""
23
+    if not ival.complete:
24
+        if verbose:
25
+            print('Interval {} is not complete.'.format(ival))
26
+        return False
27
+
28
+    slots = set(ival.__slots__).intersection(other.__slots__)
29
+    same_slots = []
30
+    for s in slots:
31
+        a = getattr(ival, s)
32
+        b = getattr(other, s)
33
+        is_equal = np.allclose(a, b, rtol=0, atol=eps, equal_nan=True)
34
+        if verbose and not is_equal:
35
+            print('ival.{} - other.{} = {}'.format(s, s, a - b))
36
+        same_slots.append(is_equal)
37
+
38
+    return all(same_slots)
39
+
40
+def equal_ivals(ivals, other, *, verbose=False):
41
+    """Note: `other` is a list of ivals."""
42
+    if len(ivals) != len(other):
43
+        if verbose:
44
+            print('len(ivals)={} != len(other)={}'.format(
45
+                len(ivals), len(other)))
46
+        return False
47
+
48
+    ivals = [sorted(i, key=attrgetter('a')) for i in [ivals, other]]
49
+    return all(equal_ival(ival, other_ival, verbose=verbose)
50
+               for ival, other_ival in zip(*ivals))
51
+
20 52
 def same_ivals(f, a, b, tol):
21 53
         igral, err, nr_points, ivals = algorithm_4(f, a, b, tol)
22 54
 
... ...
@@ -26,7 +58,7 @@ def same_ivals(f, a, b, tol):
26 58
         print('igral difference', learner.igral-igral,
27 59
               'err difference', learner.err - err)
28 60
 
29
-        return learner.equal(ivals, verbose=True)
61
+        return equal_ivals(learner.ivals, ivals, verbose=True)
30 62
 
31 63
 
32 64
 def test_cquad():
... ...
@@ -47,7 +79,7 @@ def test_machine_precision():
47 79
     print('igral difference', learner.igral-igral,
48 80
           'err difference', learner.err - err)
49 81
 
50
-    assert learner.equal(ivals, verbose=True)
82
+    assert equal_ivals(learner.ivals, ivals, verbose=True)
51 83
 
52 84
 
53 85
 def test_machine_precision2():