Browse code

test serialization with pickle, cloudpickle, and dill

Bas Nijholt authored on 10/04/2020 15:18:14
Showing 3 changed files
... ...
@@ -1,6 +1,9 @@
1
+import operator
2
+import pickle
1 3
 import random
2 4
 
3 5
 import cloudpickle
6
+import dill
4 7
 import pytest
5 8
 
6 9
 from adaptive.learner import (
... ...
@@ -24,22 +27,37 @@ def goal_2(learner):
24 27
     return learner.npoints >= 20
25 28
 
26 29
 
30
+learners_pairs = [
31
+    (Learner1D, dict(bounds=(-1, 1))),
32
+    (Learner2D, dict(bounds=[(-1, 1), (-1, 1)])),
33
+    (LearnerND, dict(bounds=[(-1, 1), (-1, 1), (-1, 1)])),
34
+    (SequenceLearner, dict(sequence=list(range(100)))),
35
+    (IntegratorLearner, dict(bounds=(0, 1), tol=1e-3)),
36
+    (AverageLearner, dict(atol=0.1)),
37
+]
38
+
39
+serializers = (pickle, dill, cloudpickle)
40
+
41
+learners = [
42
+    (learner_type, learner_kwargs, serializer)
43
+    for serializer in serializers
44
+    for learner_type, learner_kwargs in learners_pairs
45
+]
46
+
47
+
48
+def f_for_pickle_balancing_learner(x):
49
+    return 1
50
+
51
+
52
+def f_for_pickle_datasaver(x):
53
+    return dict(x=x, y=x)
54
+
55
+
27 56
 @pytest.mark.parametrize(
28
-    "learner_type, learner_kwargs",
29
-    [
30
-        (Learner1D, dict(bounds=(-1, 1))),
31
-        (Learner2D, dict(bounds=[(-1, 1), (-1, 1)])),
32
-        (LearnerND, dict(bounds=[(-1, 1), (-1, 1), (-1, 1)])),
33
-        (SequenceLearner, dict(sequence=list(range(100)))),
34
-        (IntegratorLearner, dict(bounds=(0, 1), tol=1e-3)),
35
-        (AverageLearner, dict(atol=0.1)),
36
-    ],
57
+    "learner_type, learner_kwargs, serializer", learners,
37 58
 )
38
-def test_cloudpickle_for(learner_type, learner_kwargs):
39
-    """Test serializing a learner using cloudpickle.
40
-
41
-    We use cloudpickle because with pickle the functions are only
42
-    pickled by reference."""
59
+def test_serialization_for(learner_type, learner_kwargs, serializer):
60
+    """Test serializing a learner using different serializers."""
43 61
 
44 62
     def f(x):
45 63
         return random.random()
... ...
@@ -49,9 +67,10 @@ def test_cloudpickle_for(learner_type, learner_kwargs):
49 67
     simple(learner, goal_1)
50 68
     learner_bytes = cloudpickle.dumps(learner)
51 69
 
52
-    # Delete references
53
-    del f
54
-    del learner
70
+    if serializer is not pickle:
71
+        # With pickle the functions are only pickled by reference
72
+        del f
73
+        del learner
55 74
 
56 75
     learner_loaded = cloudpickle.loads(learner_bytes)
57 76
     assert learner_loaded.npoints >= 10
... ...
@@ -59,45 +78,61 @@ def test_cloudpickle_for(learner_type, learner_kwargs):
59 78
     assert learner_loaded.npoints >= 20
60 79
 
61 80
 
62
-def test_cloudpickle_for_datasaver():
81
+@pytest.mark.parametrize(
82
+    "serializer", serializers,
83
+)
84
+def test_serialization_for_datasaver(serializer):
63 85
     def f(x):
64 86
         return dict(x=1, y=x ** 2)
65 87
 
88
+    if serializer is pickle:
89
+        # f from the local scope cannot be pickled
90
+        f = f_for_pickle_datasaver  # noqa: F811
91
+
66 92
     _learner = Learner1D(f, bounds=(-1, 1))
67
-    learner = DataSaver(_learner, arg_picker=lambda x: x["y"])
93
+    learner = DataSaver(_learner, arg_picker=operator.itemgetter("y"))
68 94
 
69 95
     simple(learner, goal_1)
70
-    learner_bytes = cloudpickle.dumps(learner)
96
+    learner_bytes = serializer.dumps(learner)
71 97
 
72
-    # Delete references
73
-    del f
74
-    del _learner
75
-    del learner
98
+    if serializer is not pickle:
99
+        # With pickle the functions are only pickled by reference
100
+        del f
101
+        del _learner
102
+        del learner
76 103
 
77
-    learner_loaded = cloudpickle.loads(learner_bytes)
104
+    learner_loaded = serializer.loads(learner_bytes)
78 105
     assert learner_loaded.npoints >= 10
79 106
     simple(learner_loaded, goal_2)
80 107
     assert learner_loaded.npoints >= 20
81 108
 
82 109
 
83
-def test_cloudpickle_for_balancing_learner():
110
+@pytest.mark.parametrize(
111
+    "serializer", serializers,
112
+)
113
+def test_serialization_for_balancing_learner(serializer):
84 114
     def f(x):
85 115
         return x ** 2
86 116
 
117
+    if serializer is pickle:
118
+        # f from the local scope cannot be pickled
119
+        f = f_for_pickle_balancing_learner  # noqa: F811
120
+
87 121
     learner_1 = Learner1D(f, bounds=(-1, 1))
88 122
     learner_2 = Learner1D(f, bounds=(-2, 2))
89 123
     learner = BalancingLearner([learner_1, learner_2])
90 124
 
91 125
     simple(learner, goal_1)
92
-    learner_bytes = cloudpickle.dumps(learner)
126
+    learner_bytes = serializer.dumps(learner)
93 127
 
94
-    # Delete references
95
-    del f
96
-    del learner_1
97
-    del learner_2
98
-    del learner
128
+    if serializer is not pickle:
129
+        # With pickle the functions are only pickled by reference
130
+        del f
131
+        del learner_1
132
+        del learner_2
133
+        del learner
99 134
 
100
-    learner_loaded = cloudpickle.loads(learner_bytes)
135
+    learner_loaded = serializer.loads(learner_bytes)
101 136
     assert learner_loaded.npoints >= 10
102 137
     simple(learner_loaded, goal_2)
103 138
     assert learner_loaded.npoints >= 20
... ...
@@ -44,6 +44,7 @@ extras_require = {
44 44
     ],
45 45
     "testing": [
46 46
         "cloudpickle",
47
+        "dill",
47 48
         "flaky",
48 49
         "pytest",
49 50
         "pytest-cov",
... ...
@@ -65,4 +65,4 @@ include_trailing_comma=True
65 65
 force_grid_wrap=0
66 66
 use_parentheses=True
67 67
 line_length=88
68
-known_third_party=PIL,atomicwrites,cloudpickle,flaky,holoviews,matplotlib,nbconvert,numpy,pytest,scipy,setuptools,skopt,sortedcollections,sortedcontainers
68
+known_third_party=PIL,atomicwrites,cloudpickle,dill,flaky,holoviews,matplotlib,nbconvert,numpy,pytest,scipy,setuptools,skopt,sortedcollections,sortedcontainers