... | ... |
@@ -1,6 +1,9 @@ |
1 |
+import operator |
|
2 |
+import pickle |
|
1 | 3 |
import random |
2 | 4 |
|
3 | 5 |
import cloudpickle |
6 |
+import dill |
|
4 | 7 |
import pytest |
5 | 8 |
|
6 | 9 |
from adaptive.learner import ( |
... | ... |
@@ -24,22 +27,37 @@ def goal_2(learner): |
24 | 27 |
return learner.npoints >= 20 |
25 | 28 |
|
26 | 29 |
|
30 |
+learners_pairs = [ |
|
31 |
+ (Learner1D, dict(bounds=(-1, 1))), |
|
32 |
+ (Learner2D, dict(bounds=[(-1, 1), (-1, 1)])), |
|
33 |
+ (LearnerND, dict(bounds=[(-1, 1), (-1, 1), (-1, 1)])), |
|
34 |
+ (SequenceLearner, dict(sequence=list(range(100)))), |
|
35 |
+ (IntegratorLearner, dict(bounds=(0, 1), tol=1e-3)), |
|
36 |
+ (AverageLearner, dict(atol=0.1)), |
|
37 |
+] |
|
38 |
+ |
|
39 |
+serializers = (pickle, dill, cloudpickle) |
|
40 |
+ |
|
41 |
+learners = [ |
|
42 |
+ (learner_type, learner_kwargs, serializer) |
|
43 |
+ for serializer in serializers |
|
44 |
+ for learner_type, learner_kwargs in learners_pairs |
|
45 |
+] |
|
46 |
+ |
|
47 |
+ |
|
48 |
+def f_for_pickle_balancing_learner(x): |
|
49 |
+ return 1 |
|
50 |
+ |
|
51 |
+ |
|
52 |
+def f_for_pickle_datasaver(x): |
|
53 |
+ return dict(x=x, y=x) |
|
54 |
+ |
|
55 |
+ |
|
27 | 56 |
@pytest.mark.parametrize( |
28 |
- "learner_type, learner_kwargs", |
|
29 |
- [ |
|
30 |
- (Learner1D, dict(bounds=(-1, 1))), |
|
31 |
- (Learner2D, dict(bounds=[(-1, 1), (-1, 1)])), |
|
32 |
- (LearnerND, dict(bounds=[(-1, 1), (-1, 1), (-1, 1)])), |
|
33 |
- (SequenceLearner, dict(sequence=list(range(100)))), |
|
34 |
- (IntegratorLearner, dict(bounds=(0, 1), tol=1e-3)), |
|
35 |
- (AverageLearner, dict(atol=0.1)), |
|
36 |
- ], |
|
57 |
+ "learner_type, learner_kwargs, serializer", learners, |
|
37 | 58 |
) |
38 |
-def test_cloudpickle_for(learner_type, learner_kwargs): |
|
39 |
- """Test serializing a learner using cloudpickle. |
|
40 |
- |
|
41 |
- We use cloudpickle because with pickle the functions are only |
|
42 |
- pickled by reference.""" |
|
59 |
+def test_serialization_for(learner_type, learner_kwargs, serializer): |
|
60 |
+ """Test serializing a learner using different serializers.""" |
|
43 | 61 |
|
44 | 62 |
def f(x): |
45 | 63 |
return random.random() |
... | ... |
@@ -49,9 +67,10 @@ def test_cloudpickle_for(learner_type, learner_kwargs): |
49 | 67 |
simple(learner, goal_1) |
50 | 68 |
learner_bytes = cloudpickle.dumps(learner) |
51 | 69 |
|
52 |
- # Delete references |
|
53 |
- del f |
|
54 |
- del learner |
|
70 |
+ if serializer is not pickle: |
|
71 |
+ # With pickle the functions are only pickled by reference |
|
72 |
+ del f |
|
73 |
+ del learner |
|
55 | 74 |
|
56 | 75 |
learner_loaded = cloudpickle.loads(learner_bytes) |
57 | 76 |
assert learner_loaded.npoints >= 10 |
... | ... |
@@ -59,45 +78,61 @@ def test_cloudpickle_for(learner_type, learner_kwargs): |
59 | 78 |
assert learner_loaded.npoints >= 20 |
60 | 79 |
|
61 | 80 |
|
62 |
-def test_cloudpickle_for_datasaver(): |
|
81 |
+@pytest.mark.parametrize( |
|
82 |
+ "serializer", serializers, |
|
83 |
+) |
|
84 |
+def test_serialization_for_datasaver(serializer): |
|
63 | 85 |
def f(x): |
64 | 86 |
return dict(x=1, y=x ** 2) |
65 | 87 |
|
88 |
+ if serializer is pickle: |
|
89 |
+ # f from the local scope cannot be pickled |
|
90 |
+ f = f_for_pickle_datasaver # noqa: F811 |
|
91 |
+ |
|
66 | 92 |
_learner = Learner1D(f, bounds=(-1, 1)) |
67 |
- learner = DataSaver(_learner, arg_picker=lambda x: x["y"]) |
|
93 |
+ learner = DataSaver(_learner, arg_picker=operator.itemgetter("y")) |
|
68 | 94 |
|
69 | 95 |
simple(learner, goal_1) |
70 |
- learner_bytes = cloudpickle.dumps(learner) |
|
96 |
+ learner_bytes = serializer.dumps(learner) |
|
71 | 97 |
|
72 |
- # Delete references |
|
73 |
- del f |
|
74 |
- del _learner |
|
75 |
- del learner |
|
98 |
+ if serializer is not pickle: |
|
99 |
+ # With pickle the functions are only pickled by reference |
|
100 |
+ del f |
|
101 |
+ del _learner |
|
102 |
+ del learner |
|
76 | 103 |
|
77 |
- learner_loaded = cloudpickle.loads(learner_bytes) |
|
104 |
+ learner_loaded = serializer.loads(learner_bytes) |
|
78 | 105 |
assert learner_loaded.npoints >= 10 |
79 | 106 |
simple(learner_loaded, goal_2) |
80 | 107 |
assert learner_loaded.npoints >= 20 |
81 | 108 |
|
82 | 109 |
|
83 |
-def test_cloudpickle_for_balancing_learner(): |
|
110 |
+@pytest.mark.parametrize( |
|
111 |
+ "serializer", serializers, |
|
112 |
+) |
|
113 |
+def test_serialization_for_balancing_learner(serializer): |
|
84 | 114 |
def f(x): |
85 | 115 |
return x ** 2 |
86 | 116 |
|
117 |
+ if serializer is pickle: |
|
118 |
+ # f from the local scope cannot be pickled |
|
119 |
+ f = f_for_pickle_balancing_learner # noqa: F811 |
|
120 |
+ |
|
87 | 121 |
learner_1 = Learner1D(f, bounds=(-1, 1)) |
88 | 122 |
learner_2 = Learner1D(f, bounds=(-2, 2)) |
89 | 123 |
learner = BalancingLearner([learner_1, learner_2]) |
90 | 124 |
|
91 | 125 |
simple(learner, goal_1) |
92 |
- learner_bytes = cloudpickle.dumps(learner) |
|
126 |
+ learner_bytes = serializer.dumps(learner) |
|
93 | 127 |
|
94 |
- # Delete references |
|
95 |
- del f |
|
96 |
- del learner_1 |
|
97 |
- del learner_2 |
|
98 |
- del learner |
|
128 |
+ if serializer is not pickle: |
|
129 |
+ # With pickle the functions are only pickled by reference |
|
130 |
+ del f |
|
131 |
+ del learner_1 |
|
132 |
+ del learner_2 |
|
133 |
+ del learner |
|
99 | 134 |
|
100 |
- learner_loaded = cloudpickle.loads(learner_bytes) |
|
135 |
+ learner_loaded = serializer.loads(learner_bytes) |
|
101 | 136 |
assert learner_loaded.npoints >= 10 |
102 | 137 |
simple(learner_loaded, goal_2) |
103 | 138 |
assert learner_loaded.npoints >= 20 |
... | ... |
@@ -65,4 +65,4 @@ include_trailing_comma=True |
65 | 65 |
force_grid_wrap=0 |
66 | 66 |
use_parentheses=True |
67 | 67 |
line_length=88 |
68 |
-known_third_party=PIL,atomicwrites,cloudpickle,flaky,holoviews,matplotlib,nbconvert,numpy,pytest,scipy,setuptools,skopt,sortedcollections,sortedcontainers |
|
68 |
+known_third_party=PIL,atomicwrites,cloudpickle,dill,flaky,holoviews,matplotlib,nbconvert,numpy,pytest,scipy,setuptools,skopt,sortedcollections,sortedcontainers |