1 | 1 |
new file mode 100644 |
... | ... |
@@ -0,0 +1,103 @@ |
1 |
+import random |
|
2 |
+ |
|
3 |
+import cloudpickle |
|
4 |
+import pytest |
|
5 |
+ |
|
6 |
+from adaptive.learner import ( |
|
7 |
+ AverageLearner, |
|
8 |
+ BalancingLearner, |
|
9 |
+ DataSaver, |
|
10 |
+ IntegratorLearner, |
|
11 |
+ Learner1D, |
|
12 |
+ Learner2D, |
|
13 |
+ LearnerND, |
|
14 |
+ SequenceLearner, |
|
15 |
+) |
|
16 |
+from adaptive.runner import simple |
|
17 |
+ |
|
18 |
+ |
|
19 |
+def goal_1(learner): |
|
20 |
+ return learner.npoints >= 10 |
|
21 |
+ |
|
22 |
+ |
|
23 |
+def goal_2(learner): |
|
24 |
+ return learner.npoints >= 20 |
|
25 |
+ |
|
26 |
+ |
|
27 |
+@pytest.mark.parametrize( |
|
28 |
+ "learner_type, learner_kwargs", |
|
29 |
+ [ |
|
30 |
+ (Learner1D, dict(bounds=(-1, 1))), |
|
31 |
+ (Learner2D, dict(bounds=[(-1, 1), (-1, 1)])), |
|
32 |
+ (LearnerND, dict(bounds=[(-1, 1), (-1, 1), (-1, 1)])), |
|
33 |
+ (SequenceLearner, dict(sequence=list(range(100)))), |
|
34 |
+ (IntegratorLearner, dict(bounds=(0, 1), tol=1e-3)), |
|
35 |
+ (AverageLearner, dict(atol=0.1)), |
|
36 |
+ ], |
|
37 |
+) |
|
38 |
+def test_cloudpickle_for(learner_type, learner_kwargs): |
|
39 |
+ """Test serializing a learner using cloudpickle. |
|
40 |
+ |
|
41 |
+ We use cloudpickle because with pickle the functions are only |
|
42 |
+ pickled by reference.""" |
|
43 |
+ |
|
44 |
+ def f(x): |
|
45 |
+ return random.random() |
|
46 |
+ |
|
47 |
+ learner = learner_type(f, **learner_kwargs) |
|
48 |
+ |
|
49 |
+ simple(learner, goal_1) |
|
50 |
+ learner_bytes = cloudpickle.dumps(learner) |
|
51 |
+ |
|
52 |
+ # Delete references |
|
53 |
+ del f |
|
54 |
+ del learner |
|
55 |
+ |
|
56 |
+ learner_loaded = cloudpickle.loads(learner_bytes) |
|
57 |
+ assert learner_loaded.npoints >= 10 |
|
58 |
+ simple(learner_loaded, goal_2) |
|
59 |
+ assert learner_loaded.npoints >= 20 |
|
60 |
+ |
|
61 |
+ |
|
62 |
+def test_cloudpickle_for_datasaver(): |
|
63 |
+ def f(x): |
|
64 |
+ return dict(x=1, y=x ** 2) |
|
65 |
+ |
|
66 |
+ _learner = Learner1D(f, bounds=(-1, 1)) |
|
67 |
+ learner = DataSaver(_learner, arg_picker=lambda x: x["y"]) |
|
68 |
+ |
|
69 |
+ simple(learner, goal_1) |
|
70 |
+ learner_bytes = cloudpickle.dumps(learner) |
|
71 |
+ |
|
72 |
+ # Delete references |
|
73 |
+ del f |
|
74 |
+ del _learner |
|
75 |
+ del learner |
|
76 |
+ |
|
77 |
+ learner_loaded = cloudpickle.loads(learner_bytes) |
|
78 |
+ assert learner_loaded.npoints >= 10 |
|
79 |
+ simple(learner_loaded, goal_2) |
|
80 |
+ assert learner_loaded.npoints >= 20 |
|
81 |
+ |
|
82 |
+ |
|
83 |
+def test_cloudpickle_for_balancing_learner(): |
|
84 |
+ def f(x): |
|
85 |
+ return x ** 2 |
|
86 |
+ |
|
87 |
+ learner_1 = Learner1D(f, bounds=(-1, 1)) |
|
88 |
+ learner_2 = Learner1D(f, bounds=(-2, 2)) |
|
89 |
+ learner = BalancingLearner([learner_1, learner_2]) |
|
90 |
+ |
|
91 |
+ simple(learner, goal_1) |
|
92 |
+ learner_bytes = cloudpickle.dumps(learner) |
|
93 |
+ |
|
94 |
+ # Delete references |
|
95 |
+ del f |
|
96 |
+ del learner_1 |
|
97 |
+ del learner_2 |
|
98 |
+ del learner |
|
99 |
+ |
|
100 |
+ learner_loaded = cloudpickle.loads(learner_bytes) |
|
101 |
+ assert learner_loaded.npoints >= 10 |
|
102 |
+ simple(learner_loaded, goal_2) |
|
103 |
+ assert learner_loaded.npoints >= 20 |
... | ... |
@@ -65,4 +65,4 @@ include_trailing_comma=True |
65 | 65 |
force_grid_wrap=0 |
66 | 66 |
use_parentheses=True |
67 | 67 |
line_length=88 |
68 |
-known_third_party=PIL,atomicwrites,flaky,holoviews,matplotlib,nbconvert,numpy,pytest,scipy,setuptools,skopt,sortedcollections,sortedcontainers |
|
68 |
+known_third_party=PIL,atomicwrites,cloudpickle,flaky,holoviews,matplotlib,nbconvert,numpy,pytest,scipy,setuptools,skopt,sortedcollections,sortedcontainers |