Browse code

move pid logic to _ask

Bas Nijholt authored on 22/04/2020 20:36:56
Showing 1 changed files
... ...
@@ -175,10 +175,10 @@ class BaseRunner(metaclass=abc.ABCMeta):
175 175
 
176 176
     def _ask(self, n):
177 177
         points = []
178
-        for i, _id in enumerate(self._to_retry.keys()):
178
+        for i, pid in enumerate(self._to_retry.keys()):
179 179
             if i == n:
180 180
                 break
181
-            point = self._id_to_point[_id]
181
+            point = self._id_to_point[pid]
182 182
             if point not in self.pending_points.values():
183 183
                 points.append(point)
184 184
 
... ...
@@ -187,6 +187,8 @@ class BaseRunner(metaclass=abc.ABCMeta):
187 187
             new_points, new_losses = self.learner.ask(n - len(points))
188 188
             points += new_points
189 189
             loss_improvements += new_losses
190
+            for p in new_points:
191
+                self._id_to_point[self._next_id()] = p
190 192
         return points, loss_improvements
191 193
 
192 194
     def overhead(self):
... ...
@@ -251,11 +253,6 @@ class BaseRunner(metaclass=abc.ABCMeta):
251 253
             fut = self._submit(x)
252 254
             fut.start_time = start_time
253 255
             self.pending_points[fut] = x
254
-            try:
255
-                _id = _key_by_value(self._id_to_point, x)  # O(N)
256
-            except StopIteration:  # `x` is not a value in `self._id_to_point`
257
-                _id = self._next_id()
258
-            self._id_to_point[_id] = x
259 256
 
260 257
         # Collect and results and add them to the learner
261 258
         futures = list(self.pending_points.keys())