Browse code

refactor tests

Bas Nijholt authored on 15/04/2020 23:06:17
Showing 1 changed files
... ...
@@ -1,6 +1,4 @@
1
-import operator
2 1
 import pickle
3
-import random
4 2
 
5 3
 import pytest
6 4
 
... ...
@@ -11,7 +9,6 @@ from adaptive.learner import (
11 9
     IntegratorLearner,
12 10
     Learner1D,
13 11
     Learner2D,
14
-    LearnerND,
15 12
     SequenceLearner,
16 13
 )
17 14
 from adaptive.runner import simple
... ...
@@ -39,49 +36,62 @@ def goal_2(learner):
39 36
     return learner.npoints == 20
40 37
 
41 38
 
39
+def pickleable_f(x):
40
+    return hash(str(x)) / 2 ** 63
41
+
42
+
43
+nonpickleable_f = lambda x: hash(str(x)) / 2 ** 63  # noqa: E731
44
+
45
+
46
+def identity_function(x):
47
+    return x
48
+
49
+
50
+def datasaver(f, learner_type, learner_kwargs):
51
+    return DataSaver(
52
+        learner=learner_type(f, **learner_kwargs), arg_picker=identity_function
53
+    )
54
+
55
+
56
+def balancing_learner(f, learner_type, learner_kwargs):
57
+    learner_1 = learner_type(f, **learner_kwargs)
58
+    learner_2 = learner_type(f, **learner_kwargs)
59
+    return BalancingLearner([learner_1, learner_2])
60
+
61
+
42 62
 learners_pairs = [
43 63
     (Learner1D, dict(bounds=(-1, 1))),
44 64
     (Learner2D, dict(bounds=[(-1, 1), (-1, 1)])),
45
-    (LearnerND, dict(bounds=[(-1, 1), (-1, 1), (-1, 1)])),
46 65
     (SequenceLearner, dict(sequence=list(range(100)))),
47 66
     (IntegratorLearner, dict(bounds=(0, 1), tol=1e-3)),
48 67
     (AverageLearner, dict(atol=0.1)),
68
+    (datasaver, dict(learner_type=Learner1D, learner_kwargs=dict(bounds=(-1, 1)))),
69
+    (
70
+        balancing_learner,
71
+        dict(learner_type=Learner1D, learner_kwargs=dict(bounds=(-1, 1))),
72
+    ),
49 73
 ]
50 74
 
51
-serializers = [pickle]
75
+serializers = [(pickle, pickleable_f)]
52 76
 if with_cloudpickle:
53
-    serializers.append(cloudpickle)
77
+    serializers.append((cloudpickle, nonpickleable_f))
54 78
 if with_dill:
55
-    serializers.append(dill)
79
+    serializers.append((dill, nonpickleable_f))
80
+
56 81
 
57 82
 learners = [
58
-    (learner_type, learner_kwargs, serializer)
59
-    for serializer in serializers
83
+    (learner_type, learner_kwargs, serializer, f)
84
+    for serializer, f in serializers
60 85
     for learner_type, learner_kwargs in learners_pairs
61 86
 ]
62 87
 
63 88
 
64
-def f_for_pickle(x):
65
-    return 1
66
-
67
-
68
-def f_for_pickle_datasaver(x):
69
-    return dict(x=x, y=x)
70
-
71
-
72 89
 @pytest.mark.parametrize(
73
-    "learner_type, learner_kwargs, serializer", learners,
90
+    "learner_type, learner_kwargs, serializer, f", learners,
74 91
 )
75
-def test_serialization_for(learner_type, learner_kwargs, serializer):
92
+def test_serialization_for(learner_type, learner_kwargs, serializer, f):
76 93
     """Test serializing a learner using different serializers."""
77 94
 
78
-    def f(x):
79
-        return random.random()
80
-
81
-    if serializer is pickle:
82
-        # f from the local scope cannot be pickled
83
-        f = f_for_pickle  # noqa: F811
84
-
85 95
     learner = learner_type(f, **learner_kwargs)
86 96
 
87 97
     simple(learner, goal_1)
... ...
@@ -90,10 +100,8 @@ def test_serialization_for(learner_type, learner_kwargs, serializer):
90 100
     asked = learner.ask(10)
91 101
     data = learner.data
92 102
 
93
-    if serializer is not pickle:
94
-        # With pickle the functions are only pickled by reference
95
-        del f
96
-        del learner
103
+    del f
104
+    del learner
97 105
 
98 106
     learner_loaded = serializer.loads(learner_bytes)
99 107
     assert learner_loaded.npoints == 10
... ...
@@ -107,63 +115,3 @@ def test_serialization_for(learner_type, learner_kwargs, serializer):
107 115
 
108 116
     simple(learner_loaded, goal_2)
109 117
     assert learner_loaded.npoints == 20
110
-
111
-
112
-@pytest.mark.parametrize(
113
-    "serializer", serializers,
114
-)
115
-def test_serialization_for_datasaver(serializer):
116
-    def f(x):
117
-        return dict(x=1, y=x ** 2)
118
-
119
-    if serializer is pickle:
120
-        # f from the local scope cannot be pickled
121
-        f = f_for_pickle_datasaver  # noqa: F811
122
-
123
-    _learner = Learner1D(f, bounds=(-1, 1))
124
-    learner = DataSaver(_learner, arg_picker=operator.itemgetter("y"))
125
-
126
-    simple(learner, goal_1)
127
-    learner_bytes = serializer.dumps(learner)
128
-
129
-    if serializer is not pickle:
130
-        # With pickle the functions are only pickled by reference
131
-        del f
132
-        del _learner
133
-        del learner
134
-
135
-    learner_loaded = serializer.loads(learner_bytes)
136
-    assert learner_loaded.npoints == 10
137
-    simple(learner_loaded, goal_2)
138
-    assert learner_loaded.npoints == 20
139
-
140
-
141
-@pytest.mark.parametrize(
142
-    "serializer", serializers,
143
-)
144
-def test_serialization_for_balancing_learner(serializer):
145
-    def f(x):
146
-        return x ** 2
147
-
148
-    if serializer is pickle:
149
-        # f from the local scope cannot be pickled
150
-        f = f_for_pickle  # noqa: F811
151
-
152
-    learner_1 = Learner1D(f, bounds=(-1, 1))
153
-    learner_2 = Learner1D(f, bounds=(-2, 2))
154
-    learner = BalancingLearner([learner_1, learner_2])
155
-
156
-    simple(learner, goal_1)
157
-    learner_bytes = serializer.dumps(learner)
158
-
159
-    if serializer is not pickle:
160
-        # With pickle the functions are only pickled by reference
161
-        del f
162
-        del learner_1
163
-        del learner_2
164
-        del learner
165
-
166
-    learner_loaded = serializer.loads(learner_bytes)
167
-    assert learner_loaded.npoints == 10
168
-    simple(learner_loaded, goal_2)
169
-    assert learner_loaded.npoints == 20