Browse code

make pending_points a mapping of future -> pid

Bas Nijholt authored on 22/04/2020 22:52:17
Showing 1 changed files
... ...
@@ -174,22 +174,21 @@ class BaseRunner(metaclass=abc.ABCMeta):
174 174
         return self.log is not None
175 175
 
176 176
     def _ask(self, n):
177
-        points = []
178
-        for i, pid in enumerate(self._to_retry.keys()):
179
-            if i == n:
180
-                break
181
-            point = self._id_to_point[pid]
182
-            if point not in self.pending_points.values():
183
-                points.append(point)
184
-
185
-        loss_improvements = len(points) * [float("inf")]
186
-        if len(points) < n:
187
-            new_points, new_losses = self.learner.ask(n - len(points))
188
-            points += new_points
177
+        pids = [
178
+            pid
179
+            for pid in self._to_retry.keys()
180
+            if pid not in self.pending_points.values()
181
+        ][:n]
182
+        loss_improvements = len(pids) * [float("inf")]
183
+
184
+        if len(pids) < n:
185
+            new_points, new_losses = self.learner.ask(n - len(pids))
189 186
             loss_improvements += new_losses
190
-            for p in new_points:
191
-                self._id_to_point[self._next_id()] = p
192
-        return points, loss_improvements
187
+            for point in new_points:
188
+                pid = self._next_id()
189
+                self._id_to_point[pid] = point
190
+                pids.append(pid)
191
+        return pids, loss_improvements
193 192
 
194 193
     def overhead(self):
195 194
         """Overhead of using Adaptive and the executor in percent.
... ...
@@ -216,23 +215,22 @@ class BaseRunner(metaclass=abc.ABCMeta):
216 215
 
217 216
     def _process_futures(self, done_futs):
218 217
         for fut in done_futs:
219
-            x = self.pending_points.pop(fut)
220
-            i = _key_by_value(self._id_to_point, x)  # O(N)
218
+            pid = self.pending_points.pop(fut)
221 219
             try:
222 220
                 y = fut.result()
223 221
                 t = time.time() - fut.start_time  # total execution time
224 222
             except Exception as e:
225
-                self._tracebacks[i] = traceback.format_exc()
226
-                self._to_retry[i] = self._to_retry.get(i, 0) + 1
227
-                if self._to_retry[i] > self.retries:
228
-                    self._to_retry.pop(i)
223
+                self._tracebacks[pid] = traceback.format_exc()
224
+                self._to_retry[pid] = self._to_retry.get(pid, 0) + 1
225
+                if self._to_retry[pid] > self.retries:
226
+                    self._to_retry.pop(pid)
229 227
                     if self.raise_if_retries_exceeded:
230
-                        self._do_raise(e, i)
228
+                        self._do_raise(e, pid)
231 229
             else:
232 230
                 self._elapsed_function_time += t / self._get_max_tasks()
233
-                self._to_retry.pop(i, None)
234
-                self._tracebacks.pop(i, None)
235
-                self._id_to_point.pop(i)
231
+                self._to_retry.pop(pid, None)
232
+                self._tracebacks.pop(pid, None)
233
+                x = self._id_to_point.pop(pid)
236 234
                 if self.do_log:
237 235
                     self.log.append(("tell", x, y))
238 236
                 self.learner.tell(x, y)
... ...
@@ -246,13 +244,14 @@ class BaseRunner(metaclass=abc.ABCMeta):
246 244
         if self.do_log:
247 245
             self.log.append(("ask", n_new_tasks))
248 246
 
249
-        points, _ = self._ask(n_new_tasks)
247
+        pids, _ = self._ask(n_new_tasks)
250 248
 
251
-        for x in points:
249
+        for pid in pids:
252 250
             start_time = time.time()  # so we can measure execution time
253
-            fut = self._submit(x)
251
+            point = self._id_to_point[pid]
252
+            fut = self._submit(point)
254 253
             fut.start_time = start_time
255
-            self.pending_points[fut] = x
254
+            self.pending_points[fut] = pid
256 255
 
257 256
         # Collect and results and add them to the learner
258 257
         futures = list(self.pending_points.keys())