Browse code

make SequenceLearner picklable

Bas Nijholt authored on 09/04/2020 22:34:35
Showing 1 changed files
... ...
@@ -83,16 +83,6 @@ class SequenceLearner(BaseLearner):
83 83
 
84 84
         return points, loss_improvements
85 85
 
86
-    def _get_data(self):
87
-        return self.data
88
-
89
-    def _set_data(self, data):
90
-        if data:
91
-            indices, values = zip(*data.items())
92
-            # the points aren't used by tell, so we can safely pass None
93
-            points = [(i, None) for i in indices]
94
-            self.tell_many(points, values)
95
-
96 86
     def loss(self, real=True):
97 87
         if not (self._to_do_indices or self.pending_points):
98 88
             return 0
... ...
@@ -128,3 +118,25 @@ class SequenceLearner(BaseLearner):
128 118
     @property
129 119
     def npoints(self):
130 120
         return len(self.data)
121
+
122
+    def _get_data(self):
123
+        return self.data
124
+
125
+    def _set_data(self, data):
126
+        if data:
127
+            indices, values = zip(*data.items())
128
+            # the points aren't used by tell, so we can safely pass None
129
+            points = [(i, None) for i in indices]
130
+            self.tell_many(points, values)
131
+
132
+    def __getstate__(self):
133
+        return (
134
+            self._original_function,
135
+            self.sequence,
136
+            self._get_data(),
137
+        )
138
+
139
+    def __setstate__(self, state):
140
+        function, sequence, data = state
141
+        self.__init__(function, sequence)
142
+        self._set_data(data)