Browse code

BalancingLearner: create 'tell_pending' which deprecates 'tell(x, None)'

Bas Nijholt authored on 20/09/2018 13:22:09
Showing 1 changed files
... ...
@@ -71,7 +71,7 @@ class BalancingLearner(BaseLearner):
71 71
             for index, learner in enumerate(self.learners):
72 72
                 if index not in self._points:
73 73
                     self._points[index] = learner.ask(
74
-                        n=1, add_data=False)
74
+                        n=1, tell_pending=False)
75 75
                 point, loss_improvement = self._points[index]
76 76
                 improvements_per_learner.append(loss_improvement[0])
77 77
                 pairs.append((index, point[0]))
... ...
@@ -79,13 +79,13 @@ class BalancingLearner(BaseLearner):
79 79
                        key=itemgetter(1))
80 80
             points.append(x)
81 81
             loss_improvements.append(l)
82
-            self.tell(x, None)
82
+            self.tell_pending(x)
83 83
 
84 84
         return points, loss_improvements
85 85
 
86
-    def ask(self, n, add_data=True):
86
+    def ask(self, n, tell_pending=True):
87 87
         """Chose points for learners."""
88
-        if not add_data:
88
+        if not tell_pending:
89 89
             with restore(*self.learners):
90 90
                 return self._ask_and_tell(n)
91 91
         else:
... ...
@@ -97,6 +97,12 @@ class BalancingLearner(BaseLearner):
97 97
         self._loss.pop(index, None)
98 98
         self.learners[index].tell(x, y)
99 99
 
100
+    def tell_pending(self, x):
101
+        index, x = x
102
+        self._points.pop(index, None)
103
+        self._loss.pop(index, None)
104
+        self.learners[index].tell_pending(x)
105
+
100 106
     def loss(self, real=True):
101 107
         losses = []
102 108
         for index, learner in enumerate(self.learners):