Browse code

make the Runner work with unhashable points

Bas Nijholt authored on 16/04/2020 18:38:03
Showing 1 changed files
... ...
@@ -54,6 +54,12 @@ _default_executor = (
54 54
 )
55 55
 
56 56
 
57
+def _key_by_value(dct, value):
58
+    for k, v in dct.items():
59
+        if v == value:
60
+            return k
61
+
62
+
57 63
 class BaseRunner(metaclass=abc.ABCMeta):
58 64
     r"""Base class for runners that use `concurrent.futures.Executors`.
59 65
 
... ...
@@ -146,11 +152,16 @@ class BaseRunner(metaclass=abc.ABCMeta):
146 152
         self.to_retry = {}
147 153
         self.tracebacks = {}
148 154
 
155
+        # Keeping track of index -> point
156
+        self.index_to_point = {}
157
+        self._i = 0  # some unique index to be associated with each point
158
+
149 159
     def _get_max_tasks(self):
150 160
         return self._max_tasks or _get_ncores(self.executor)
151 161
 
152
-    def _do_raise(self, e, x):
153
-        tb = self.tracebacks[x]
162
+    def _do_raise(self, e, i):
163
+        tb = self.tracebacks[i]
164
+        x = self.index_to_point[i]
154 165
         raise RuntimeError(
155 166
             "An error occured while evaluating "
156 167
             f'"learner.function({x})". '
... ...
@@ -162,9 +173,14 @@ class BaseRunner(metaclass=abc.ABCMeta):
162 173
         return self.log is not None
163 174
 
164 175
     def _ask(self, n):
165
-        points = [
166
-            p for p in self.to_retry.keys() if p not in self.pending_points.values()
167
-        ][:n]
176
+        points = []
177
+        for i, index in enumerate(self.to_retry.keys()):
178
+            if i == n:
179
+                break
180
+            point = self.index_to_point[index]
181
+            if point not in self.pending_points.values():
182
+                points.append(point)
183
+
168 184
         loss_improvements = len(points) * [float("inf")]
169 185
         if len(points) < n:
170 186
             new_points, new_losses = self.learner.ask(n - len(points))
... ...
@@ -198,20 +214,22 @@ class BaseRunner(metaclass=abc.ABCMeta):
198 214
     def _process_futures(self, done_futs):
199 215
         for fut in done_futs:
200 216
             x = self.pending_points.pop(fut)
217
+            i = _key_by_value(self.index_to_point, x)  # O(N)
201 218
             try:
202 219
                 y = fut.result()
203 220
                 t = time.time() - fut.start_time  # total execution time
204 221
             except Exception as e:
205
-                self.tracebacks[x] = traceback.format_exc()
206
-                self.to_retry[x] = self.to_retry.get(x, 0) + 1
207
-                if self.to_retry[x] > self.retries:
208
-                    self.to_retry.pop(x)
222
+                self.tracebacks[i] = traceback.format_exc()
223
+                self.to_retry[i] = self.to_retry.get(i, 0) + 1
224
+                if self.to_retry[i] > self.retries:
225
+                    self.to_retry.pop(i)
209 226
                     if self.raise_if_retries_exceeded:
210
-                        self._do_raise(e, x)
227
+                        self._do_raise(e, i)
211 228
             else:
212 229
                 self._elapsed_function_time += t / self._get_max_tasks()
213
-                self.to_retry.pop(x, None)
214
-                self.tracebacks.pop(x, None)
230
+                self.to_retry.pop(i, None)
231
+                self.tracebacks.pop(i, None)
232
+                self.index_to_point.pop(i)
215 233
                 if self.do_log:
216 234
                     self.log.append(("tell", x, y))
217 235
                 self.learner.tell(x, y)
... ...
@@ -232,6 +250,12 @@ class BaseRunner(metaclass=abc.ABCMeta):
232 250
             fut = self._submit(x)
233 251
             fut.start_time = start_time
234 252
             self.pending_points[fut] = x
253
+            i = _key_by_value(self.index_to_point, x)  # O(N)
254
+            if i is None:
255
+                # `x` is not a value in `self.index_to_point`
256
+                self._i += 1
257
+                i = self._i
258
+            self.index_to_point[i] = x
235 259
 
236 260
         # Collect and results and add them to the learner
237 261
         futures = list(self.pending_points.keys())