Browse code

use a list for ((learner_index, point), loss_improvement) tuples

Bas Nijholt authored on 17/03/2019 12:06:09
Showing 1 changed files
... ...
@@ -113,40 +113,34 @@ class BalancingLearner(BaseLearner):
113 113
                 ' strategy="npoints" is implemented.')
114 114
 
115 115
     def _ask_and_tell_based_on_loss_improvements(self, n):
116
-        chosen_points = []
117
-        chosen_loss_improvements = []
118
-        npoints = [l.npoints + len(l.pending_points)
119
-                   for l in self.learners]
116
+        selected = []  # tuples ((learner_index, point), loss_improvement)
117
+        npoints = [l.npoints + len(l.pending_points) for l in self.learners]
120 118
         for _ in range(n):
121
-            improvements_per_learner = []
122
-            points_per_learner = []
119
+            to_select = []
123 120
             for index, learner in enumerate(self.learners):
124 121
                 # Take the points from the cache
125 122
                 if index not in self._ask_cache:
126 123
                     self._ask_cache[index] = learner.ask(
127 124
                         n=1, tell_pending=False)
128 125
                 points, loss_improvements = self._ask_cache[index]
129
-
130
-                priority = (loss_improvements[0], -npoints[index])
131
-                improvements_per_learner.append(priority)
132
-                points_per_learner.append((index, points[0]))
126
+                to_select.append(
127
+                    ((index, points[0]),
128
+                     (loss_improvements[0], -npoints[index]))
129
+                )
133 130
 
134 131
             # Choose the optimal improvement.
135 132
             (index, point), (loss_improvement, _) = max(
136
-                zip(points_per_learner, improvements_per_learner),
137
-                key=itemgetter(1))
133
+                to_select, key=itemgetter(1))
138 134
             npoints[index] += 1
139
-            chosen_points.append((index, point))
140
-            chosen_loss_improvements.append(loss_improvement)
135
+            selected.append(((index, point), loss_improvement))
141 136
             self.tell_pending((index, point))
142 137
 
143
-        return chosen_points, chosen_loss_improvements
138
+        points, loss_improvements = map(list, zip(*selected))
139
+        return points, loss_improvements
144 140
 
145 141
     def _ask_and_tell_based_on_loss(self, n):
146
-        chosen_points = []
147
-        chosen_loss_improvements = []
148
-        npoints = [l.npoints + len(l.pending_points)
149
-                   for l in self.learners]
142
+        selected = []   # tuples ((learner_index, point), loss_improvement)
143
+        npoints = [l.npoints + len(l.pending_points) for l in self.learners]
150 144
         for _ in range(n):
151 145
             losses = self._losses(real=False)
152 146
             priority = zip(losses, (-n for n in npoints))
... ...
@@ -158,15 +152,14 @@ class BalancingLearner(BaseLearner):
158 152
                 self._ask_cache[index] = self.learners[index].ask(n=1)
159 153
             points, loss_improvements = self._ask_cache[index]
160 154
 
161
-            chosen_points.append((index, points[0]))
162
-            chosen_loss_improvements.append(loss_improvements[0])
163
-        return chosen_points, chosen_loss_improvements
155
+            selected.append(((index, points[0]), loss_improvements[0]))
156
+
157
+        points, loss_improvements = map(list, zip(*selected))
158
+        return points, loss_improvements
164 159
 
165 160
     def _ask_and_tell_based_on_npoints(self, n):
166
-        chosen_points = []
167
-        chosen_loss_improvements = []
168
-        npoints = [l.npoints + len(l.pending_points)
169
-                   for l in self.learners]
161
+        selected = []  # tuples ((learner_index, point), loss_improvement)
162
+        npoints = [l.npoints + len(l.pending_points) for l in self.learners]
170 163
         n_left = n
171 164
         while n_left > 0:
172 165
             index = np.argmin(npoints)
... ...
@@ -176,9 +169,10 @@ class BalancingLearner(BaseLearner):
176 169
             points, loss_improvements = self._ask_cache[index]
177 170
             npoints[index] += 1
178 171
             n_left -= 1
179
-            chosen_points.append((index, points[0]))
180
-            chosen_loss_improvements.append(loss_improvements[0])
181
-        return chosen_points, chosen_loss_improvements
172
+            selected.append(((index, points[0]), loss_improvements[0]))
173
+
174
+        points, loss_improvements = map(list, zip(*selected))
175
+        return points, loss_improvements
182 176
 
183 177
     def ask(self, n, tell_pending=True):
184 178
         """Chose points for learners."""