... | ... |
@@ -1,6 +1,4 @@ |
1 |
-import operator |
|
2 | 1 |
import pickle |
3 |
-import random |
|
4 | 2 |
|
5 | 3 |
import pytest |
6 | 4 |
|
... | ... |
@@ -11,7 +9,6 @@ from adaptive.learner import ( |
11 | 9 |
IntegratorLearner, |
12 | 10 |
Learner1D, |
13 | 11 |
Learner2D, |
14 |
- LearnerND, |
|
15 | 12 |
SequenceLearner, |
16 | 13 |
) |
17 | 14 |
from adaptive.runner import simple |
... | ... |
@@ -39,49 +36,62 @@ def goal_2(learner): |
39 | 36 |
return learner.npoints == 20 |
40 | 37 |
|
41 | 38 |
|
39 |
+def pickleable_f(x): |
|
40 |
+ return hash(str(x)) / 2 ** 63 |
|
41 |
+ |
|
42 |
+ |
|
43 |
+nonpickleable_f = lambda x: hash(str(x)) / 2 ** 63 # noqa: E731 |
|
44 |
+ |
|
45 |
+ |
|
46 |
+def identity_function(x): |
|
47 |
+ return x |
|
48 |
+ |
|
49 |
+ |
|
50 |
+def datasaver(f, learner_type, learner_kwargs): |
|
51 |
+ return DataSaver( |
|
52 |
+ learner=learner_type(f, **learner_kwargs), arg_picker=identity_function |
|
53 |
+ ) |
|
54 |
+ |
|
55 |
+ |
|
56 |
+def balancing_learner(f, learner_type, learner_kwargs): |
|
57 |
+ learner_1 = learner_type(f, **learner_kwargs) |
|
58 |
+ learner_2 = learner_type(f, **learner_kwargs) |
|
59 |
+ return BalancingLearner([learner_1, learner_2]) |
|
60 |
+ |
|
61 |
+ |
|
42 | 62 |
learners_pairs = [ |
43 | 63 |
(Learner1D, dict(bounds=(-1, 1))), |
44 | 64 |
(Learner2D, dict(bounds=[(-1, 1), (-1, 1)])), |
45 |
- (LearnerND, dict(bounds=[(-1, 1), (-1, 1), (-1, 1)])), |
|
46 | 65 |
(SequenceLearner, dict(sequence=list(range(100)))), |
47 | 66 |
(IntegratorLearner, dict(bounds=(0, 1), tol=1e-3)), |
48 | 67 |
(AverageLearner, dict(atol=0.1)), |
68 |
+ (datasaver, dict(learner_type=Learner1D, learner_kwargs=dict(bounds=(-1, 1)))), |
|
69 |
+ ( |
|
70 |
+ balancing_learner, |
|
71 |
+ dict(learner_type=Learner1D, learner_kwargs=dict(bounds=(-1, 1))), |
|
72 |
+ ), |
|
49 | 73 |
] |
50 | 74 |
|
51 |
-serializers = [pickle] |
|
75 |
+serializers = [(pickle, pickleable_f)] |
|
52 | 76 |
if with_cloudpickle: |
53 |
- serializers.append(cloudpickle) |
|
77 |
+ serializers.append((cloudpickle, nonpickleable_f)) |
|
54 | 78 |
if with_dill: |
55 |
- serializers.append(dill) |
|
79 |
+ serializers.append((dill, nonpickleable_f)) |
|
80 |
+ |
|
56 | 81 |
|
57 | 82 |
learners = [ |
58 |
- (learner_type, learner_kwargs, serializer) |
|
59 |
- for serializer in serializers |
|
83 |
+ (learner_type, learner_kwargs, serializer, f) |
|
84 |
+ for serializer, f in serializers |
|
60 | 85 |
for learner_type, learner_kwargs in learners_pairs |
61 | 86 |
] |
62 | 87 |
|
63 | 88 |
|
64 |
-def f_for_pickle(x): |
|
65 |
- return 1 |
|
66 |
- |
|
67 |
- |
|
68 |
-def f_for_pickle_datasaver(x): |
|
69 |
- return dict(x=x, y=x) |
|
70 |
- |
|
71 |
- |
|
72 | 89 |
@pytest.mark.parametrize( |
73 |
- "learner_type, learner_kwargs, serializer", learners, |
|
90 |
+ "learner_type, learner_kwargs, serializer, f", learners, |
|
74 | 91 |
) |
75 |
-def test_serialization_for(learner_type, learner_kwargs, serializer): |
|
92 |
+def test_serialization_for(learner_type, learner_kwargs, serializer, f): |
|
76 | 93 |
"""Test serializing a learner using different serializers.""" |
77 | 94 |
|
78 |
- def f(x): |
|
79 |
- return random.random() |
|
80 |
- |
|
81 |
- if serializer is pickle: |
|
82 |
- # f from the local scope cannot be pickled |
|
83 |
- f = f_for_pickle # noqa: F811 |
|
84 |
- |
|
85 | 95 |
learner = learner_type(f, **learner_kwargs) |
86 | 96 |
|
87 | 97 |
simple(learner, goal_1) |
... | ... |
@@ -90,10 +100,8 @@ def test_serialization_for(learner_type, learner_kwargs, serializer): |
90 | 100 |
asked = learner.ask(10) |
91 | 101 |
data = learner.data |
92 | 102 |
|
93 |
- if serializer is not pickle: |
|
94 |
- # With pickle the functions are only pickled by reference |
|
95 |
- del f |
|
96 |
- del learner |
|
103 |
+ del f |
|
104 |
+ del learner |
|
97 | 105 |
|
98 | 106 |
learner_loaded = serializer.loads(learner_bytes) |
99 | 107 |
assert learner_loaded.npoints == 10 |
... | ... |
@@ -107,63 +115,3 @@ def test_serialization_for(learner_type, learner_kwargs, serializer): |
107 | 115 |
|
108 | 116 |
simple(learner_loaded, goal_2) |
109 | 117 |
assert learner_loaded.npoints == 20 |
110 |
- |
|
111 |
- |
|
112 |
-@pytest.mark.parametrize( |
|
113 |
- "serializer", serializers, |
|
114 |
-) |
|
115 |
-def test_serialization_for_datasaver(serializer): |
|
116 |
- def f(x): |
|
117 |
- return dict(x=1, y=x ** 2) |
|
118 |
- |
|
119 |
- if serializer is pickle: |
|
120 |
- # f from the local scope cannot be pickled |
|
121 |
- f = f_for_pickle_datasaver # noqa: F811 |
|
122 |
- |
|
123 |
- _learner = Learner1D(f, bounds=(-1, 1)) |
|
124 |
- learner = DataSaver(_learner, arg_picker=operator.itemgetter("y")) |
|
125 |
- |
|
126 |
- simple(learner, goal_1) |
|
127 |
- learner_bytes = serializer.dumps(learner) |
|
128 |
- |
|
129 |
- if serializer is not pickle: |
|
130 |
- # With pickle the functions are only pickled by reference |
|
131 |
- del f |
|
132 |
- del _learner |
|
133 |
- del learner |
|
134 |
- |
|
135 |
- learner_loaded = serializer.loads(learner_bytes) |
|
136 |
- assert learner_loaded.npoints == 10 |
|
137 |
- simple(learner_loaded, goal_2) |
|
138 |
- assert learner_loaded.npoints == 20 |
|
139 |
- |
|
140 |
- |
|
141 |
-@pytest.mark.parametrize( |
|
142 |
- "serializer", serializers, |
|
143 |
-) |
|
144 |
-def test_serialization_for_balancing_learner(serializer): |
|
145 |
- def f(x): |
|
146 |
- return x ** 2 |
|
147 |
- |
|
148 |
- if serializer is pickle: |
|
149 |
- # f from the local scope cannot be pickled |
|
150 |
- f = f_for_pickle # noqa: F811 |
|
151 |
- |
|
152 |
- learner_1 = Learner1D(f, bounds=(-1, 1)) |
|
153 |
- learner_2 = Learner1D(f, bounds=(-2, 2)) |
|
154 |
- learner = BalancingLearner([learner_1, learner_2]) |
|
155 |
- |
|
156 |
- simple(learner, goal_1) |
|
157 |
- learner_bytes = serializer.dumps(learner) |
|
158 |
- |
|
159 |
- if serializer is not pickle: |
|
160 |
- # With pickle the functions are only pickled by reference |
|
161 |
- del f |
|
162 |
- del learner_1 |
|
163 |
- del learner_2 |
|
164 |
- del learner |
|
165 |
- |
|
166 |
- learner_loaded = serializer.loads(learner_bytes) |
|
167 |
- assert learner_loaded.npoints == 10 |
|
168 |
- simple(learner_loaded, goal_2) |
|
169 |
- assert learner_loaded.npoints == 20 |