...
|
...
|
@@ -77,7 +77,7 @@ class BalancingLearner(BaseLearner):
|
77
|
77
|
# pickle the whole learner.
|
78
|
78
|
self.function = partial(dispatch, [l.function for l in self.learners])
|
79
|
79
|
|
80
|
|
- self._points = {}
|
|
80
|
+ self._ask_cache = {}
|
81
|
81
|
self._loss = {}
|
82
|
82
|
self._pending_loss = {}
|
83
|
83
|
self._cdims_default = cdims
|
...
|
...
|
@@ -122,10 +122,10 @@ class BalancingLearner(BaseLearner):
|
122
|
122
|
points_per_learner = []
|
123
|
123
|
for index, learner in enumerate(self.learners):
|
124
|
124
|
# Take the points from the cache
|
125
|
|
- if index not in self._points:
|
126
|
|
- self._points[index] = learner.ask(
|
|
125
|
+ if index not in self._ask_cache:
|
|
126
|
+ self._ask_cache[index] = learner.ask(
|
127
|
127
|
n=1, tell_pending=False)
|
128
|
|
- points, loss_improvements = self._points[index]
|
|
128
|
+ points, loss_improvements = self._ask_cache[index]
|
129
|
129
|
|
130
|
130
|
priority = (loss_improvements[0], -npoints[index])
|
131
|
131
|
improvements_per_learner.append(priority)
|
...
|
...
|
@@ -154,9 +154,9 @@ class BalancingLearner(BaseLearner):
|
154
|
154
|
npoints[index] += 1
|
155
|
155
|
|
156
|
156
|
# Take the points from the cache
|
157
|
|
- if index not in self._points:
|
158
|
|
- self._points[index] = self.learners[index].ask(n=1)
|
159
|
|
- points, loss_improvements = self._points[index]
|
|
157
|
+ if index not in self._ask_cache:
|
|
158
|
+ self._ask_cache[index] = self.learners[index].ask(n=1)
|
|
159
|
+ points, loss_improvements = self._ask_cache[index]
|
160
|
160
|
|
161
|
161
|
chosen_points.append((index, points[0]))
|
162
|
162
|
chosen_loss_improvements.append(loss_improvements[0])
|
...
|
...
|
@@ -171,9 +171,9 @@ class BalancingLearner(BaseLearner):
|
171
|
171
|
while n_left > 0:
|
172
|
172
|
index = np.argmin(npoints)
|
173
|
173
|
# Take the points from the cache
|
174
|
|
- if index not in self._points:
|
175
|
|
- self._points[index] = self.learners[index].ask(n=1)
|
176
|
|
- points, loss_improvements = self._points[index]
|
|
174
|
+ if index not in self._ask_cache:
|
|
175
|
+ self._ask_cache[index] = self.learners[index].ask(n=1)
|
|
176
|
+ points, loss_improvements = self._ask_cache[index]
|
177
|
177
|
npoints[index] += 1
|
178
|
178
|
n_left -= 1
|
179
|
179
|
chosen_points.append((index, points[0]))
|
...
|
...
|
@@ -190,14 +190,14 @@ class BalancingLearner(BaseLearner):
|
190
|
190
|
|
191
|
191
|
def tell(self, x, y):
|
192
|
192
|
index, x = x
|
193
|
|
- self._points.pop(index, None)
|
|
193
|
+ self._ask_cache.pop(index, None)
|
194
|
194
|
self._loss.pop(index, None)
|
195
|
195
|
self._pending_loss.pop(index, None)
|
196
|
196
|
self.learners[index].tell(x, y)
|
197
|
197
|
|
198
|
198
|
def tell_pending(self, x):
|
199
|
199
|
index, x = x
|
200
|
|
- self._points.pop(index, None)
|
|
200
|
+ self._ask_cache.pop(index, None)
|
201
|
201
|
self._loss.pop(index, None)
|
202
|
202
|
self.learners[index].tell_pending(x)
|
203
|
203
|
|