Browse code

2D: make self._stack a OrderedDict

Bas Nijholt authored on 15/11/2017 23:56:35
Showing 1 changed files
... ...
@@ -1,4 +1,5 @@
1 1
 # -*- coding: utf-8 -*-
2
+import collections
2 3
 import itertools
3 4
 
4 5
 import holoviews as hv
... ...
@@ -114,7 +115,7 @@ class Learner2D(BaseLearner):
114 115
         self.bounds = tuple((float(a), float(b)) for a, b in bounds)
115 116
         self._points = np.zeros([100, self.ndim])
116 117
         self._values = np.zeros([100, self.vdim], dtype=float)
117
-        self._stack = []
118
+        self._stack = collections.OrderedDict()
118 119
         self._interp = {}
119 120
 
120 121
         xy_mean = np.mean(self.bounds, axis=1)
... ...
@@ -135,7 +136,7 @@ class Learner2D(BaseLearner):
135 136
         self._bounds_points = list(itertools.product(*bounds))
136 137
 
137 138
         # Add the loss improvement to the bounds in the stack
138
-        self._stack = [(*p, np.inf) for p in self._bounds_points]
139
+        self._stack.update({p: np.inf for p in self._bounds_points})
139 140
 
140 141
         self.function = function
141 142
 
... ...
@@ -223,10 +224,7 @@ class Learner2D(BaseLearner):
223 224
             self._values[n] = value
224 225
 
225 226
         # Remove the point if in the stack.
226
-        for i, (*_point, _) in enumerate(self._stack):
227
-            if point == tuple(_point):
228
-                self._stack.pop(i)
229
-                break
227
+        self._stack.pop(point, None)
230 228
 
231 229
     def _fill_stack(self, stack_till=None):
232 230
         if stack_till is None:
... ...
@@ -267,7 +265,7 @@ class Learner2D(BaseLearner):
267 265
                 continue
268 266
 
269 267
             # Add to stack
270
-            self._stack.append((*point_new, losses[jsimplex]))
268
+            self._stack[tuple(point_new)] = losses[jsimplex]
271 269
 
272 270
             if len(self._stack) >= stack_till:
273 271
                 break
... ...
@@ -275,12 +273,8 @@ class Learner2D(BaseLearner):
275 273
                 losses[jsimplex] = 0
276 274
 
277 275
     def _split_stack(self, n=None):
278
-        points = []
279
-        loss_improvements = []
280
-        for *point, loss_improvement in self._stack[:n]:
281
-            points.append(tuple(point))
282
-            loss_improvements.append(loss_improvement)
283
-        return points, loss_improvements
276
+        points, loss_improvements = zip(*reversed(self._stack.items()))
277
+        return points[:n], loss_improvements[:n]
284 278
 
285 279
     def _choose_and_add_points(self, n):
286 280
         if n <= len(self._stack):