...
|
...
|
@@ -125,12 +125,14 @@ class BalancingLearner(BaseLearner):
|
125
|
125
|
self._points[index] = learner.ask(
|
126
|
126
|
n=1, tell_pending=False)
|
127
|
127
|
points, loss_improvements = self._points[index]
|
128
|
|
- npoints = npoints_per_learner[index] + learner.npoints
|
|
128
|
+ npoints = (npoints_per_learner[index]
|
|
129
|
+ + learner.npoints
|
|
130
|
+ + len(learner.pending_points))
|
129
|
131
|
priority = (loss_improvements[0], -npoints)
|
130
|
132
|
improvements_per_learner.append(priority)
|
131
|
133
|
points_per_learner.append((index, points[0]))
|
132
|
134
|
|
133
|
|
- # Chose the optimal improvement.
|
|
135
|
+ # Choose the optimal improvement.
|
134
|
136
|
(index, point), (loss_improvement, _) = max(
|
135
|
137
|
zip(points_per_learner, improvements_per_learner),
|
136
|
138
|
key=itemgetter(1))
|
...
|
...
|
@@ -142,15 +144,23 @@ class BalancingLearner(BaseLearner):
|
142
|
144
|
return chosen_points, chosen_loss_improvements
|
143
|
145
|
|
144
|
146
|
def _ask_and_tell_based_on_loss(self, n):
|
145
|
|
- points = []
|
146
|
|
- loss_improvements = []
|
|
147
|
+ chosen_points = []
|
|
148
|
+ chosen_loss_improvements = []
|
|
149
|
+ npoints_per_learner = defaultdict(int)
|
|
150
|
+
|
147
|
151
|
for _ in range(n):
|
148
|
152
|
losses = self._losses(real=False)
|
149
|
|
- max_ind = np.argmax(losses)
|
150
|
|
- xs, ls = self.learners[max_ind].ask(1)
|
151
|
|
- points.append((max_ind, xs[0]))
|
152
|
|
- loss_improvements.append(ls[0])
|
153
|
|
- return points, loss_improvements
|
|
153
|
+ npoints = [-(l.npoints
|
|
154
|
+ + npoints_per_learner[i]
|
|
155
|
+ + len(l.pending_points))
|
|
156
|
+ for i, l in enumerate(self.learners)]
|
|
157
|
+ priority = zip(losses, npoints)
|
|
158
|
+ index, (_, _) = max(enumerate(priority), key=itemgetter(1))
|
|
159
|
+ npoints_per_learner[index] += 1
|
|
160
|
+ points, loss_improvements = self.learners[index].ask(1)
|
|
161
|
+ chosen_points.append((index, points[0]))
|
|
162
|
+ chosen_loss_improvements.append(loss_improvements[0])
|
|
163
|
+ return chosen_points, chosen_loss_improvements
|
154
|
164
|
|
155
|
165
|
def _ask_and_tell_based_on_npoints(self, n):
|
156
|
166
|
points = []
|