... | ... |
@@ -1,4 +1,5 @@ |
1 | 1 |
# -*- coding: utf-8 -*- |
2 |
+import collections |
|
2 | 3 |
import itertools |
3 | 4 |
|
4 | 5 |
import holoviews as hv |
... | ... |
@@ -114,7 +115,7 @@ class Learner2D(BaseLearner): |
114 | 115 |
self.bounds = tuple((float(a), float(b)) for a, b in bounds) |
115 | 116 |
self._points = np.zeros([100, self.ndim]) |
116 | 117 |
self._values = np.zeros([100, self.vdim], dtype=float) |
117 |
- self._stack = [] |
|
118 |
+ self._stack = collections.OrderedDict() |
|
118 | 119 |
self._interp = {} |
119 | 120 |
|
120 | 121 |
xy_mean = np.mean(self.bounds, axis=1) |
... | ... |
@@ -135,7 +136,7 @@ class Learner2D(BaseLearner): |
135 | 136 |
self._bounds_points = list(itertools.product(*bounds)) |
136 | 137 |
|
137 | 138 |
# Add the loss improvement to the bounds in the stack |
138 |
- self._stack = [(*p, np.inf) for p in self._bounds_points] |
|
139 |
+ self._stack.update({p: np.inf for p in self._bounds_points}) |
|
139 | 140 |
|
140 | 141 |
self.function = function |
141 | 142 |
|
... | ... |
@@ -223,10 +224,7 @@ class Learner2D(BaseLearner): |
223 | 224 |
self._values[n] = value |
224 | 225 |
|
225 | 226 |
# Remove the point if in the stack. |
226 |
- for i, (*_point, _) in enumerate(self._stack): |
|
227 |
- if point == tuple(_point): |
|
228 |
- self._stack.pop(i) |
|
229 |
- break |
|
227 |
+ self._stack.pop(point, None) |
|
230 | 228 |
|
231 | 229 |
def _fill_stack(self, stack_till=None): |
232 | 230 |
if stack_till is None: |
... | ... |
@@ -267,7 +265,7 @@ class Learner2D(BaseLearner): |
267 | 265 |
continue |
268 | 266 |
|
269 | 267 |
# Add to stack |
270 |
- self._stack.append((*point_new, losses[jsimplex])) |
|
268 |
+ self._stack[tuple(point_new)] = losses[jsimplex] |
|
271 | 269 |
|
272 | 270 |
if len(self._stack) >= stack_till: |
273 | 271 |
break |
... | ... |
@@ -275,12 +273,8 @@ class Learner2D(BaseLearner): |
275 | 273 |
losses[jsimplex] = 0 |
276 | 274 |
|
277 | 275 |
def _split_stack(self, n=None): |
278 |
- points = [] |
|
279 |
- loss_improvements = [] |
|
280 |
- for *point, loss_improvement in self._stack[:n]: |
|
281 |
- points.append(tuple(point)) |
|
282 |
- loss_improvements.append(loss_improvement) |
|
283 |
- return points, loss_improvements |
|
276 |
+ points, loss_improvements = zip(*reversed(self._stack.items())) |
|
277 |
+ return points[:n], loss_improvements[:n] |
|
284 | 278 |
|
285 | 279 |
def _choose_and_add_points(self, n): |
286 | 280 |
if n <= len(self._stack): |