... | ... |
@@ -71,7 +71,7 @@ class BalancingLearner(BaseLearner): |
71 | 71 |
for index, learner in enumerate(self.learners): |
72 | 72 |
if index not in self._points: |
73 | 73 |
self._points[index] = learner.ask( |
74 |
- n=1, add_data=False) |
|
74 |
+ n=1, tell_pending=False) |
|
75 | 75 |
point, loss_improvement = self._points[index] |
76 | 76 |
improvements_per_learner.append(loss_improvement[0]) |
77 | 77 |
pairs.append((index, point[0])) |
... | ... |
@@ -79,13 +79,13 @@ class BalancingLearner(BaseLearner): |
79 | 79 |
key=itemgetter(1)) |
80 | 80 |
points.append(x) |
81 | 81 |
loss_improvements.append(l) |
82 |
- self.tell(x, None) |
|
82 |
+ self.tell_pending(x) |
|
83 | 83 |
|
84 | 84 |
return points, loss_improvements |
85 | 85 |
|
86 |
- def ask(self, n, add_data=True): |
|
86 |
+ def ask(self, n, tell_pending=True): |
|
87 | 87 |
"""Chose points for learners.""" |
88 |
- if not add_data: |
|
88 |
+ if not tell_pending: |
|
89 | 89 |
with restore(*self.learners): |
90 | 90 |
return self._ask_and_tell(n) |
91 | 91 |
else: |
... | ... |
@@ -97,6 +97,12 @@ class BalancingLearner(BaseLearner): |
97 | 97 |
self._loss.pop(index, None) |
98 | 98 |
self.learners[index].tell(x, y) |
99 | 99 |
|
100 |
+ def tell_pending(self, x): |
|
101 |
+ index, x = x |
|
102 |
+ self._points.pop(index, None) |
|
103 |
+ self._loss.pop(index, None) |
|
104 |
+ self.learners[index].tell_pending(x) |
|
105 |
+ |
|
100 | 106 |
def loss(self, real=True): |
101 | 107 |
losses = [] |
102 | 108 |
for index, learner in enumerate(self.learners): |