... | ... |
@@ -54,6 +54,12 @@ _default_executor = ( |
54 | 54 |
) |
55 | 55 |
|
56 | 56 |
|
57 |
+def _key_by_value(dct, value): |
|
58 |
+ for k, v in dct.items(): |
|
59 |
+ if v == value: |
|
60 |
+ return k |
|
61 |
+ |
|
62 |
+ |
|
57 | 63 |
class BaseRunner(metaclass=abc.ABCMeta): |
58 | 64 |
r"""Base class for runners that use `concurrent.futures.Executors`. |
59 | 65 |
|
... | ... |
@@ -146,11 +152,16 @@ class BaseRunner(metaclass=abc.ABCMeta): |
146 | 152 |
self.to_retry = {} |
147 | 153 |
self.tracebacks = {} |
148 | 154 |
|
155 |
+ # Keeping track of index -> point |
|
156 |
+ self.index_to_point = {} |
|
157 |
+ self._i = 0 # some unique index to be associated with each point |
|
158 |
+ |
|
149 | 159 |
def _get_max_tasks(self): |
150 | 160 |
return self._max_tasks or _get_ncores(self.executor) |
151 | 161 |
|
152 |
- def _do_raise(self, e, x): |
|
153 |
- tb = self.tracebacks[x] |
|
162 |
+ def _do_raise(self, e, i): |
|
163 |
+ tb = self.tracebacks[i] |
|
164 |
+ x = self.index_to_point[i] |
|
154 | 165 |
raise RuntimeError( |
155 | 166 |
"An error occured while evaluating " |
156 | 167 |
f'"learner.function({x})". ' |
... | ... |
@@ -162,9 +173,14 @@ class BaseRunner(metaclass=abc.ABCMeta): |
162 | 173 |
return self.log is not None |
163 | 174 |
|
164 | 175 |
def _ask(self, n): |
165 |
- points = [ |
|
166 |
- p for p in self.to_retry.keys() if p not in self.pending_points.values() |
|
167 |
- ][:n] |
|
176 |
+ points = [] |
|
177 |
+ for i, index in enumerate(self.to_retry.keys()): |
|
178 |
+ if i == n: |
|
179 |
+ break |
|
180 |
+ point = self.index_to_point[index] |
|
181 |
+ if point not in self.pending_points.values(): |
|
182 |
+ points.append(point) |
|
183 |
+ |
|
168 | 184 |
loss_improvements = len(points) * [float("inf")] |
169 | 185 |
if len(points) < n: |
170 | 186 |
new_points, new_losses = self.learner.ask(n - len(points)) |
... | ... |
@@ -198,20 +214,22 @@ class BaseRunner(metaclass=abc.ABCMeta): |
198 | 214 |
def _process_futures(self, done_futs): |
199 | 215 |
for fut in done_futs: |
200 | 216 |
x = self.pending_points.pop(fut) |
217 |
+ i = _key_by_value(self.index_to_point, x) # O(N) |
|
201 | 218 |
try: |
202 | 219 |
y = fut.result() |
203 | 220 |
t = time.time() - fut.start_time # total execution time |
204 | 221 |
except Exception as e: |
205 |
- self.tracebacks[x] = traceback.format_exc() |
|
206 |
- self.to_retry[x] = self.to_retry.get(x, 0) + 1 |
|
207 |
- if self.to_retry[x] > self.retries: |
|
208 |
- self.to_retry.pop(x) |
|
222 |
+ self.tracebacks[i] = traceback.format_exc() |
|
223 |
+ self.to_retry[i] = self.to_retry.get(i, 0) + 1 |
|
224 |
+ if self.to_retry[i] > self.retries: |
|
225 |
+ self.to_retry.pop(i) |
|
209 | 226 |
if self.raise_if_retries_exceeded: |
210 |
- self._do_raise(e, x) |
|
227 |
+ self._do_raise(e, i) |
|
211 | 228 |
else: |
212 | 229 |
self._elapsed_function_time += t / self._get_max_tasks() |
213 |
- self.to_retry.pop(x, None) |
|
214 |
- self.tracebacks.pop(x, None) |
|
230 |
+ self.to_retry.pop(i, None) |
|
231 |
+ self.tracebacks.pop(i, None) |
|
232 |
+ self.index_to_point.pop(i) |
|
215 | 233 |
if self.do_log: |
216 | 234 |
self.log.append(("tell", x, y)) |
217 | 235 |
self.learner.tell(x, y) |
... | ... |
@@ -232,6 +250,12 @@ class BaseRunner(metaclass=abc.ABCMeta): |
232 | 250 |
fut = self._submit(x) |
233 | 251 |
fut.start_time = start_time |
234 | 252 |
self.pending_points[fut] = x |
253 |
+ i = _key_by_value(self.index_to_point, x) # O(N) |
|
254 |
+ if i is None: |
|
255 |
+ # `x` is not a value in `self.index_to_point` |
|
256 |
+ self._i += 1 |
|
257 |
+ i = self._i |
|
258 |
+ self.index_to_point[i] = x |
|
235 | 259 |
|
236 | 260 |
# Collect and results and add them to the learner |
237 | 261 |
futures = list(self.pending_points.keys()) |