...
|
...
|
@@ -175,10 +175,10 @@ class BaseRunner(metaclass=abc.ABCMeta):
|
175
|
175
|
|
176
|
176
|
def _ask(self, n):
|
177
|
177
|
points = []
|
178
|
|
- for i, _id in enumerate(self._to_retry.keys()):
|
|
178
|
+ for i, pid in enumerate(self._to_retry.keys()):
|
179
|
179
|
if i == n:
|
180
|
180
|
break
|
181
|
|
- point = self._id_to_point[_id]
|
|
181
|
+ point = self._id_to_point[pid]
|
182
|
182
|
if point not in self.pending_points.values():
|
183
|
183
|
points.append(point)
|
184
|
184
|
|
...
|
...
|
@@ -187,6 +187,8 @@ class BaseRunner(metaclass=abc.ABCMeta):
|
187
|
187
|
new_points, new_losses = self.learner.ask(n - len(points))
|
188
|
188
|
points += new_points
|
189
|
189
|
loss_improvements += new_losses
|
|
190
|
+ for p in new_points:
|
|
191
|
+ self._id_to_point[self._next_id()] = p
|
190
|
192
|
return points, loss_improvements
|
191
|
193
|
|
192
|
194
|
def overhead(self):
|
...
|
...
|
@@ -251,11 +253,6 @@ class BaseRunner(metaclass=abc.ABCMeta):
|
251
|
253
|
fut = self._submit(x)
|
252
|
254
|
fut.start_time = start_time
|
253
|
255
|
self.pending_points[fut] = x
|
254
|
|
- try:
|
255
|
|
- _id = _key_by_value(self._id_to_point, x) # O(N)
|
256
|
|
- except StopIteration: # `x` is not a value in `self._id_to_point`
|
257
|
|
- _id = self._next_id()
|
258
|
|
- self._id_to_point[_id] = x
|
259
|
256
|
|
260
|
257
|
# Collect and results and add them to the learner
|
261
|
258
|
futures = list(self.pending_points.keys())
|