Browse code

add tests for pickling

Bas Nijholt authored on 10/04/2020 14:23:07
Showing 2 changed files
1 1
new file mode 100644
... ...
@@ -0,0 +1,103 @@
1
+import random
2
+
3
+import cloudpickle
4
+import pytest
5
+
6
+from adaptive.learner import (
7
+    AverageLearner,
8
+    BalancingLearner,
9
+    DataSaver,
10
+    IntegratorLearner,
11
+    Learner1D,
12
+    Learner2D,
13
+    LearnerND,
14
+    SequenceLearner,
15
+)
16
+from adaptive.runner import simple
17
+
18
+
19
+def goal_1(learner):
20
+    return learner.npoints >= 10
21
+
22
+
23
+def goal_2(learner):
24
+    return learner.npoints >= 20
25
+
26
+
27
+@pytest.mark.parametrize(
28
+    "learner_type, learner_kwargs",
29
+    [
30
+        (Learner1D, dict(bounds=(-1, 1))),
31
+        (Learner2D, dict(bounds=[(-1, 1), (-1, 1)])),
32
+        (LearnerND, dict(bounds=[(-1, 1), (-1, 1), (-1, 1)])),
33
+        (SequenceLearner, dict(sequence=list(range(100)))),
34
+        (IntegratorLearner, dict(bounds=(0, 1), tol=1e-3)),
35
+        (AverageLearner, dict(atol=0.1)),
36
+    ],
37
+)
38
+def test_cloudpickle_for(learner_type, learner_kwargs):
39
+    """Test serializing a learner using cloudpickle.
40
+
41
+    We use cloudpickle because with pickle the functions are only
42
+    pickled by reference."""
43
+
44
+    def f(x):
45
+        return random.random()
46
+
47
+    learner = learner_type(f, **learner_kwargs)
48
+
49
+    simple(learner, goal_1)
50
+    learner_bytes = cloudpickle.dumps(learner)
51
+
52
+    # Delete references
53
+    del f
54
+    del learner
55
+
56
+    learner_loaded = cloudpickle.loads(learner_bytes)
57
+    assert learner_loaded.npoints >= 10
58
+    simple(learner_loaded, goal_2)
59
+    assert learner_loaded.npoints >= 20
60
+
61
+
62
+def test_cloudpickle_for_datasaver():
63
+    def f(x):
64
+        return dict(x=1, y=x ** 2)
65
+
66
+    _learner = Learner1D(f, bounds=(-1, 1))
67
+    learner = DataSaver(_learner, arg_picker=lambda x: x["y"])
68
+
69
+    simple(learner, goal_1)
70
+    learner_bytes = cloudpickle.dumps(learner)
71
+
72
+    # Delete references
73
+    del f
74
+    del _learner
75
+    del learner
76
+
77
+    learner_loaded = cloudpickle.loads(learner_bytes)
78
+    assert learner_loaded.npoints >= 10
79
+    simple(learner_loaded, goal_2)
80
+    assert learner_loaded.npoints >= 20
81
+
82
+
83
+def test_cloudpickle_for_balancing_learner():
84
+    def f(x):
85
+        return x ** 2
86
+
87
+    learner_1 = Learner1D(f, bounds=(-1, 1))
88
+    learner_2 = Learner1D(f, bounds=(-2, 2))
89
+    learner = BalancingLearner([learner_1, learner_2])
90
+
91
+    simple(learner, goal_1)
92
+    learner_bytes = cloudpickle.dumps(learner)
93
+
94
+    # Delete references
95
+    del f
96
+    del learner_1
97
+    del learner_2
98
+    del learner
99
+
100
+    learner_loaded = cloudpickle.loads(learner_bytes)
101
+    assert learner_loaded.npoints >= 10
102
+    simple(learner_loaded, goal_2)
103
+    assert learner_loaded.npoints >= 20
... ...
@@ -65,4 +65,4 @@ include_trailing_comma=True
65 65
 force_grid_wrap=0
66 66
 use_parentheses=True
67 67
 line_length=88
68
-known_third_party=PIL,atomicwrites,flaky,holoviews,matplotlib,nbconvert,numpy,pytest,scipy,setuptools,skopt,sortedcollections,sortedcontainers
68
+known_third_party=PIL,atomicwrites,cloudpickle,flaky,holoviews,matplotlib,nbconvert,numpy,pytest,scipy,setuptools,skopt,sortedcollections,sortedcontainers