Browse code

Merge pull request #264 from python-adaptive/pickle

make learners picklable

Bas Nijholt authored on 24/04/2020 18:03:45 • GitHub committed on 24/04/2020 18:03:45
Showing 9 changed files
... ...
@@ -144,3 +144,16 @@ class AverageLearner(BaseLearner):
144 144
 
145 145
     def _set_data(self, data):
146 146
         self.data, self.npoints, self.sum_f, self.sum_f_sq = data
147
+
148
+    def __getstate__(self):
149
+        return (
150
+            self.function,
151
+            self.atol,
152
+            self.rtol,
153
+            self._get_data(),
154
+        )
155
+
156
+    def __setstate__(self, state):
157
+        function, atol, rtol, data = state
158
+        self.__init__(function, atol, rtol)
159
+        self._set_data(data)
... ...
@@ -440,3 +440,14 @@ class BalancingLearner(BaseLearner):
440 440
     def _set_data(self, data):
441 441
         for l, _data in zip(self.learners, data):
442 442
             l._set_data(_data)
443
+
444
+    def __getstate__(self):
445
+        return (
446
+            self.learners,
447
+            self._cdims_default,
448
+            self.strategy,
449
+        )
450
+
451
+    def __setstate__(self, state):
452
+        learners, cdims, strategy = state
453
+        self.__init__(learners, cdims=cdims, strategy=strategy)
... ...
@@ -51,6 +51,18 @@ class DataSaver:
51 51
         learner_data, self.extra_data = data
52 52
         self.learner._set_data(learner_data)
53 53
 
54
+    def __getstate__(self):
55
+        return (
56
+            self.learner,
57
+            self.arg_picker,
58
+            self.extra_data,
59
+        )
60
+
61
+    def __setstate__(self, state):
62
+        learner, arg_picker, extra_data = state
63
+        self.__init__(learner, arg_picker)
64
+        self.extra_data = extra_data
65
+
54 66
     @copy_docstring_from(BaseLearner.save)
55 67
     def save(self, fname, compress=True):
56 68
         # We copy this method because the 'DataSaver' is not a
... ...
@@ -591,3 +591,16 @@ class IntegratorLearner(BaseLearner):
591 591
         self.x_mapping = defaultdict(lambda: SortedSet([], key=attrgetter("rdepth")))
592 592
         for k, _set in x_mapping.items():
593 593
             self.x_mapping[k].update(_set)
594
+
595
+    def __getstate__(self):
596
+        return (
597
+            self.function,
598
+            self.bounds,
599
+            self.tol,
600
+            self._get_data(),
601
+        )
602
+
603
+    def __setstate__(self, state):
604
+        function, bounds, tol, data = state
605
+        self.__init__(function, bounds, tol)
606
+        self._set_data(data)
... ...
@@ -625,6 +625,23 @@ class Learner1D(BaseLearner):
625 625
         if data:
626 626
             self.tell_many(*zip(*data.items()))
627 627
 
628
+    def __getstate__(self):
629
+        return (
630
+            self.function,
631
+            self.bounds,
632
+            self.loss_per_interval,
633
+            dict(self.losses),  # SortedDict cannot be pickled
634
+            dict(self.losses_combined),  # ItemSortedDict cannot be pickled
635
+            self._get_data(),
636
+        )
637
+
638
+    def __setstate__(self, state):
639
+        function, bounds, loss_per_interval, losses, losses_combined, data = state
640
+        self.__init__(function, bounds, loss_per_interval)
641
+        self._set_data(data)
642
+        self.losses.update(losses)
643
+        self.losses_combined.update(losses_combined)
644
+
628 645
 
629 646
 def loss_manager(x_scale):
630 647
     def sort_key(ival, loss):
... ...
@@ -706,3 +706,18 @@ class Learner2D(BaseLearner):
706 706
         for point in copy(self._stack):
707 707
             if point in self.data:
708 708
                 self._stack.pop(point)
709
+
710
+    def __getstate__(self):
711
+        return (
712
+            self.function,
713
+            self.bounds,
714
+            self.loss_per_triangle,
715
+            self._stack,
716
+            self._get_data(),
717
+        )
718
+
719
+    def __setstate__(self, state):
720
+        function, bounds, loss_per_triangle, _stack, data = state
721
+        self.__init__(function, bounds, loss_per_triangle)
722
+        self._set_data(data)
723
+        self._stack = _stack
... ...
@@ -83,16 +83,6 @@ class SequenceLearner(BaseLearner):
83 83
 
84 84
         return points, loss_improvements
85 85
 
86
-    def _get_data(self):
87
-        return self.data
88
-
89
-    def _set_data(self, data):
90
-        if data:
91
-            indices, values = zip(*data.items())
92
-            # the points aren't used by tell, so we can safely pass None
93
-            points = [(i, None) for i in indices]
94
-            self.tell_many(points, values)
95
-
96 86
     def loss(self, real=True):
97 87
         if not (self._to_do_indices or self.pending_points):
98 88
             return 0
... ...
@@ -128,3 +118,25 @@ class SequenceLearner(BaseLearner):
128 118
     @property
129 119
     def npoints(self):
130 120
         return len(self.data)
121
+
122
+    def _get_data(self):
123
+        return self.data
124
+
125
+    def _set_data(self, data):
126
+        if data:
127
+            indices, values = zip(*data.items())
128
+            # the points aren't used by tell, so we can safely pass None
129
+            points = [(i, None) for i in indices]
130
+            self.tell_many(points, values)
131
+
132
+    def __getstate__(self):
133
+        return (
134
+            self._original_function,
135
+            self.sequence,
136
+            self._get_data(),
137
+        )
138
+
139
+    def __setstate__(self, state):
140
+        function, sequence, data = state
141
+        self.__init__(function, sequence)
142
+        self._set_data(data)
131 143
new file mode 100644
... ...
@@ -0,0 +1,117 @@
1
+import pickle
2
+
3
+import pytest
4
+
5
+from adaptive.learner import (
6
+    AverageLearner,
7
+    BalancingLearner,
8
+    DataSaver,
9
+    IntegratorLearner,
10
+    Learner1D,
11
+    Learner2D,
12
+    SequenceLearner,
13
+)
14
+from adaptive.runner import simple
15
+
16
+try:
17
+    import cloudpickle
18
+
19
+    with_cloudpickle = True
20
+except ModuleNotFoundError:
21
+    with_cloudpickle = False
22
+
23
+try:
24
+    import dill
25
+
26
+    with_dill = True
27
+except ModuleNotFoundError:
28
+    with_dill = False
29
+
30
+
31
+def goal_1(learner):
32
+    return learner.npoints == 10
33
+
34
+
35
+def goal_2(learner):
36
+    return learner.npoints == 20
37
+
38
+
39
+def pickleable_f(x):
40
+    return hash(str(x)) / 2 ** 63
41
+
42
+
43
+nonpickleable_f = lambda x: hash(str(x)) / 2 ** 63  # noqa: E731
44
+
45
+
46
+def identity_function(x):
47
+    return x
48
+
49
+
50
+def datasaver(f, learner_type, learner_kwargs):
51
+    return DataSaver(
52
+        learner=learner_type(f, **learner_kwargs), arg_picker=identity_function
53
+    )
54
+
55
+
56
+def balancing_learner(f, learner_type, learner_kwargs):
57
+    learner_1 = learner_type(f, **learner_kwargs)
58
+    learner_2 = learner_type(f, **learner_kwargs)
59
+    return BalancingLearner([learner_1, learner_2])
60
+
61
+
62
+learners_pairs = [
63
+    (Learner1D, dict(bounds=(-1, 1))),
64
+    (Learner2D, dict(bounds=[(-1, 1), (-1, 1)])),
65
+    (SequenceLearner, dict(sequence=list(range(100)))),
66
+    (IntegratorLearner, dict(bounds=(0, 1), tol=1e-3)),
67
+    (AverageLearner, dict(atol=0.1)),
68
+    (datasaver, dict(learner_type=Learner1D, learner_kwargs=dict(bounds=(-1, 1)))),
69
+    (
70
+        balancing_learner,
71
+        dict(learner_type=Learner1D, learner_kwargs=dict(bounds=(-1, 1))),
72
+    ),
73
+]
74
+
75
+serializers = [(pickle, pickleable_f)]
76
+if with_cloudpickle:
77
+    serializers.append((cloudpickle, nonpickleable_f))
78
+if with_dill:
79
+    serializers.append((dill, nonpickleable_f))
80
+
81
+
82
+learners = [
83
+    (learner_type, learner_kwargs, serializer, f)
84
+    for serializer, f in serializers
85
+    for learner_type, learner_kwargs in learners_pairs
86
+]
87
+
88
+
89
+@pytest.mark.parametrize(
90
+    "learner_type, learner_kwargs, serializer, f", learners,
91
+)
92
+def test_serialization_for(learner_type, learner_kwargs, serializer, f):
93
+    """Test serializing a learner using different serializers."""
94
+
95
+    learner = learner_type(f, **learner_kwargs)
96
+
97
+    simple(learner, goal_1)
98
+    learner_bytes = serializer.dumps(learner)
99
+    loss = learner.loss()
100
+    asked = learner.ask(10)
101
+    data = learner.data
102
+
103
+    del f
104
+    del learner
105
+
106
+    learner_loaded = serializer.loads(learner_bytes)
107
+    assert learner_loaded.npoints == 10
108
+    assert loss == learner_loaded.loss()
109
+    assert data == learner_loaded.data
110
+
111
+    assert asked == learner_loaded.ask(10)
112
+
113
+    # load again to undo the ask
114
+    learner_loaded = serializer.loads(learner_bytes)
115
+
116
+    simple(learner_loaded, goal_2)
117
+    assert learner_loaded.npoints == 20
... ...
@@ -51,8 +51,10 @@ extras_require = {
51 51
         "pre_commit",
52 52
     ],
53 53
     "other": [
54
-        "ipyparallel>=6.2.5",  # because of https://github.com/ipython/ipyparallel/issues/404
54
+        "cloudpickle",
55
+        "dill",
55 56
         "distributed",
57
+        "ipyparallel>=6.2.5",  # because of https://github.com/ipython/ipyparallel/issues/404
56 58
         "loky",
57 59
         "scikit-optimize",
58 60
         "wexpect" if os.name == "nt" else "pexpect",