Browse code

only test cloudpickle and dill if installed

Bas Nijholt authored on 10/04/2020 16:22:24
Showing 3 changed files
... ...
@@ -2,8 +2,6 @@ import operator
2 2
 import pickle
3 3
 import random
4 4
 
5
-import cloudpickle
6
-import dill
7 5
 import pytest
8 6
 
9 7
 from adaptive.learner import (
... ...
@@ -18,6 +16,20 @@ from adaptive.learner import (
18 16
 )
19 17
 from adaptive.runner import simple
20 18
 
19
+try:
20
+    import cloudpickle
21
+
22
+    with_cloudpickle = True
23
+except ModuleNotFoundError:
24
+    with_cloudpickle = False
25
+
26
+try:
27
+    import dill
28
+
29
+    with_dill = True
30
+except ModuleNotFoundError:
31
+    with_dill = False
32
+
21 33
 
22 34
 def goal_1(learner):
23 35
     return learner.npoints >= 10
... ...
@@ -36,7 +48,11 @@ learners_pairs = [
36 48
     (AverageLearner, dict(atol=0.1)),
37 49
 ]
38 50
 
39
-serializers = (pickle, dill, cloudpickle)
51
+serializers = [pickle]
52
+if with_cloudpickle:
53
+    serializers.append(cloudpickle)
54
+if with_dill:
55
+    serializers.append(dill)
40 56
 
41 57
 learners = [
42 58
     (learner_type, learner_kwargs, serializer)
... ...
@@ -45,7 +61,7 @@ learners = [
45 61
 ]
46 62
 
47 63
 
48
-def f_for_pickle_balancing_learner(x):
64
+def f_for_pickle(x):
49 65
     return 1
50 66
 
51 67
 
... ...
@@ -62,17 +78,21 @@ def test_serialization_for(learner_type, learner_kwargs, serializer):
62 78
     def f(x):
63 79
         return random.random()
64 80
 
81
+    if serializer is pickle:
82
+        # f from the local scope cannot be pickled
83
+        f = f_for_pickle  # noqa: F811
84
+
65 85
     learner = learner_type(f, **learner_kwargs)
66 86
 
67 87
     simple(learner, goal_1)
68
-    learner_bytes = cloudpickle.dumps(learner)
88
+    learner_bytes = serializer.dumps(learner)
69 89
 
70 90
     if serializer is not pickle:
71 91
         # With pickle the functions are only pickled by reference
72 92
         del f
73 93
         del learner
74 94
 
75
-    learner_loaded = cloudpickle.loads(learner_bytes)
95
+    learner_loaded = serializer.loads(learner_bytes)
76 96
     assert learner_loaded.npoints >= 10
77 97
     simple(learner_loaded, goal_2)
78 98
     assert learner_loaded.npoints >= 20
... ...
@@ -116,7 +136,7 @@ def test_serialization_for_balancing_learner(serializer):
116 136
 
117 137
     if serializer is pickle:
118 138
         # f from the local scope cannot be pickled
119
-        f = f_for_pickle_balancing_learner  # noqa: F811
139
+        f = f_for_pickle  # noqa: F811
120 140
 
121 141
     learner_1 = Learner1D(f, bounds=(-1, 1))
122 142
     learner_2 = Learner1D(f, bounds=(-2, 2))
... ...
@@ -43,8 +43,6 @@ extras_require = {
43 43
         "plotly",
44 44
     ],
45 45
     "testing": [
46
-        "cloudpickle",
47
-        "dill",
48 46
         "flaky",
49 47
         "pytest",
50 48
         "pytest-cov",
... ...
@@ -53,6 +51,8 @@ extras_require = {
53 51
         "pre_commit",
54 52
     ],
55 53
     "other": [
54
+        "cloudpickle",
55
+        "dill",
56 56
         "distributed",
57 57
         "ipyparallel>=6.2.5",  # because of https://github.com/ipython/ipyparallel/issues/404
58 58
         "loky",
... ...
@@ -65,4 +65,4 @@ include_trailing_comma=True
65 65
 force_grid_wrap=0
66 66
 use_parentheses=True
67 67
 line_length=88
68
-known_third_party=PIL,atomicwrites,cloudpickle,dill,flaky,holoviews,matplotlib,nbconvert,numpy,pytest,scipy,setuptools,skopt,sortedcollections,sortedcontainers
68
+known_third_party=PIL,atomicwrites,flaky,holoviews,matplotlib,nbconvert,numpy,pytest,scipy,setuptools,skopt,sortedcollections,sortedcontainers