... | ... |
@@ -174,22 +174,21 @@ class BaseRunner(metaclass=abc.ABCMeta): |
174 | 174 |
return self.log is not None |
175 | 175 |
|
176 | 176 |
def _ask(self, n): |
177 |
- points = [] |
|
178 |
- for i, pid in enumerate(self._to_retry.keys()): |
|
179 |
- if i == n: |
|
180 |
- break |
|
181 |
- point = self._id_to_point[pid] |
|
182 |
- if point not in self.pending_points.values(): |
|
183 |
- points.append(point) |
|
184 |
- |
|
185 |
- loss_improvements = len(points) * [float("inf")] |
|
186 |
- if len(points) < n: |
|
187 |
- new_points, new_losses = self.learner.ask(n - len(points)) |
|
188 |
- points += new_points |
|
177 |
+ pids = [ |
|
178 |
+ pid |
|
179 |
+ for pid in self._to_retry.keys() |
|
180 |
+ if pid not in self.pending_points.values() |
|
181 |
+ ][:n] |
|
182 |
+ loss_improvements = len(pids) * [float("inf")] |
|
183 |
+ |
|
184 |
+ if len(pids) < n: |
|
185 |
+ new_points, new_losses = self.learner.ask(n - len(pids)) |
|
189 | 186 |
loss_improvements += new_losses |
190 |
- for p in new_points: |
|
191 |
- self._id_to_point[self._next_id()] = p |
|
192 |
- return points, loss_improvements |
|
187 |
+ for point in new_points: |
|
188 |
+ pid = self._next_id() |
|
189 |
+ self._id_to_point[pid] = point |
|
190 |
+ pids.append(pid) |
|
191 |
+ return pids, loss_improvements |
|
193 | 192 |
|
194 | 193 |
def overhead(self): |
195 | 194 |
"""Overhead of using Adaptive and the executor in percent. |
... | ... |
@@ -216,23 +215,22 @@ class BaseRunner(metaclass=abc.ABCMeta): |
216 | 215 |
|
217 | 216 |
def _process_futures(self, done_futs): |
218 | 217 |
for fut in done_futs: |
219 |
- x = self.pending_points.pop(fut) |
|
220 |
- i = _key_by_value(self._id_to_point, x) # O(N) |
|
218 |
+ pid = self.pending_points.pop(fut) |
|
221 | 219 |
try: |
222 | 220 |
y = fut.result() |
223 | 221 |
t = time.time() - fut.start_time # total execution time |
224 | 222 |
except Exception as e: |
225 |
- self._tracebacks[i] = traceback.format_exc() |
|
226 |
- self._to_retry[i] = self._to_retry.get(i, 0) + 1 |
|
227 |
- if self._to_retry[i] > self.retries: |
|
228 |
- self._to_retry.pop(i) |
|
223 |
+ self._tracebacks[pid] = traceback.format_exc() |
|
224 |
+ self._to_retry[pid] = self._to_retry.get(pid, 0) + 1 |
|
225 |
+ if self._to_retry[pid] > self.retries: |
|
226 |
+ self._to_retry.pop(pid) |
|
229 | 227 |
if self.raise_if_retries_exceeded: |
230 |
- self._do_raise(e, i) |
|
228 |
+ self._do_raise(e, pid) |
|
231 | 229 |
else: |
232 | 230 |
self._elapsed_function_time += t / self._get_max_tasks() |
233 |
- self._to_retry.pop(i, None) |
|
234 |
- self._tracebacks.pop(i, None) |
|
235 |
- self._id_to_point.pop(i) |
|
231 |
+ self._to_retry.pop(pid, None) |
|
232 |
+ self._tracebacks.pop(pid, None) |
|
233 |
+ x = self._id_to_point.pop(pid) |
|
236 | 234 |
if self.do_log: |
237 | 235 |
self.log.append(("tell", x, y)) |
238 | 236 |
self.learner.tell(x, y) |
... | ... |
@@ -246,13 +244,14 @@ class BaseRunner(metaclass=abc.ABCMeta): |
246 | 244 |
if self.do_log: |
247 | 245 |
self.log.append(("ask", n_new_tasks)) |
248 | 246 |
|
249 |
- points, _ = self._ask(n_new_tasks) |
|
247 |
+ pids, _ = self._ask(n_new_tasks) |
|
250 | 248 |
|
251 |
- for x in points: |
|
249 |
+ for pid in pids: |
|
252 | 250 |
start_time = time.time() # so we can measure execution time |
253 |
- fut = self._submit(x) |
|
251 |
+ point = self._id_to_point[pid] |
|
252 |
+ fut = self._submit(point) |
|
254 | 253 |
fut.start_time = start_time |
255 |
- self.pending_points[fut] = x |
|
254 |
+ self.pending_points[fut] = pid |
|
256 | 255 |
|
257 | 256 |
# Collect and results and add them to the learner |
258 | 257 |
futures = list(self.pending_points.keys()) |