Browse code

fix point distribution for loss strategy

Bas Nijholt authored on 17/03/2019 11:36:15
Showing 1 changed file
... ...
@@ -125,12 +125,14 @@ class BalancingLearner(BaseLearner):
125 125
                     self._points[index] = learner.ask(
126 126
                         n=1, tell_pending=False)
127 127
                 points, loss_improvements = self._points[index]
128
-                npoints = npoints_per_learner[index] + learner.npoints
128
+                npoints = (npoints_per_learner[index]
129
+                           + learner.npoints
130
+                           + len(learner.pending_points))
129 131
                 priority = (loss_improvements[0], -npoints)
130 132
                 improvements_per_learner.append(priority)
131 133
                 points_per_learner.append((index, points[0]))
132 134
 
133
-            # Chose the optimal improvement.
135
+            # Choose the optimal improvement.
134 136
             (index, point), (loss_improvement, _) = max(
135 137
                 zip(points_per_learner, improvements_per_learner),
136 138
                 key=itemgetter(1))
... ...
@@ -142,15 +144,23 @@ class BalancingLearner(BaseLearner):
142 144
         return chosen_points, chosen_loss_improvements
143 145
 
144 146
     def _ask_and_tell_based_on_loss(self, n):
145
-        points = []
146
-        loss_improvements = []
147
+        chosen_points = []
148
+        chosen_loss_improvements = []
149
+        npoints_per_learner = defaultdict(int)
150
+
147 151
         for _ in range(n):
148 152
             losses = self._losses(real=False)
149
-            max_ind = np.argmax(losses)
150
-            xs, ls = self.learners[max_ind].ask(1)
151
-            points.append((max_ind, xs[0]))
152
-            loss_improvements.append(ls[0])
153
-        return points, loss_improvements
153
+            npoints = [-(l.npoints
154
+                         + npoints_per_learner[i]
155
+                         + len(l.pending_points))
156
+                       for i, l in enumerate(self.learners)]
157
+            priority = zip(losses, npoints)
158
+            index, (_, _) = max(enumerate(priority), key=itemgetter(1))
159
+            npoints_per_learner[index] += 1
160
+            points, loss_improvements = self.learners[index].ask(1)
161
+            chosen_points.append((index, points[0]))
162
+            chosen_loss_improvements.append(loss_improvements[0])
163
+        return chosen_points, chosen_loss_improvements
154 164
 
155 165
     def _ask_and_tell_based_on_npoints(self, n):
156 166
         points = []