Browse code

improve code homogeneity

Bas Nijholt authored on 20/06/2018 02:22:40
Showing 1 changed files
... ...
@@ -130,33 +130,34 @@ class Learner1D(BaseLearner):
130 130
         else:
131 131
             return max(losses.values())
132 132
 
133
-    def update_interpolated_losses_in_interval(self, x_lower, x_upper):
134
-        if x_lower is not None and x_upper is not None:
135
-            dx = x_upper - x_lower
136
-            loss = self.loss_per_interval((x_lower, x_upper), self._scale, self.data)
137
-            self.losses[x_lower, x_upper] = loss
138
-
139
-            start = self.neighbors_combined.bisect_right(x_lower)
140
-            end = self.neighbors_combined.bisect_left(x_upper)
133
+    def update_interpolated_loss_in_interval(self, x_left, x_right):
134
+        if x_left is not None and x_right is not None:
135
+            dx = x_right - x_left
136
+            loss = self.loss_per_interval((x_left, x_right), self._scale, self.data)
137
+            self.losses[x_left, x_right] = loss
138
+
139
+            start = self.neighbors_combined.bisect_right(x_left)
140
+            end = self.neighbors_combined.bisect_left(x_right)
141 141
             for i in range(start, end):
142
-                a, b = self.neighbors_combined.iloc[i], self.neighbors_combined.iloc[i + 1]
143
-                self.losses_combined[a, b] = (b - a) * self.losses[x_lower, x_upper] / dx
142
+                a, b = (self.neighbors_combined.iloc[i],
143
+                        self.neighbors_combined.iloc[i + 1])
144
+                self.losses_combined[a, b] = (b - a) * loss / dx
144 145
             if start == end:
145
-                self.losses_combined[x_lower, x_upper] = self.losses[x_lower, x_upper]
146
+                self.losses_combined[x_left, x_right] = loss
146 147
 
147 148
     def update_losses(self, x, real=True):
148 149
         if real:
149
-            x_lower, x_upper = self.get_neighbors(x, self.neighbors)
150
-            self.update_interpolated_losses_in_interval(x_lower, x)
151
-            self.update_interpolated_losses_in_interval(x, x_upper)
152
-            self.losses.pop((x_lower, x_upper), None)
150
+            x_left, x_right = self.find_neighbors(x, self.neighbors)
151
+            self.update_interpolated_loss_in_interval(x_left, x)
152
+            self.update_interpolated_loss_in_interval(x, x_right)
153
+            self.losses.pop((x_left, x_right), None)
153 154
         else:
154 155
             losses_combined = self.losses_combined
155
-            x_lower, x_upper = self.get_neighbors(x, self.neighbors)
156
-            a, b = self.get_neighbors(x, self.neighbors_combined)
157
-            if x_lower is not None and x_upper is not None:
158
-                dx = x_upper - x_lower
159
-                loss = self.losses[x_lower, x_upper]
156
+            x_left, x_right = self.find_neighbors(x, self.neighbors)
157
+            a, b = self.find_neighbors(x, self.neighbors_combined)
158
+            if x_left is not None and x_right is not None:
159
+                dx = x_right - x_left
160
+                loss = self.losses[x_left, x_right]
160 161
                 losses_combined[a, x] = (x - a) * loss / dx
161 162
                 losses_combined[x, b] = (b - x) * loss / dx
162 163
             else:
... ...
@@ -167,23 +168,20 @@ class Learner1D(BaseLearner):
167 168
 
168 169
             losses_combined.pop((a, b), None)
169 170
 
170
-    def get_neighbors(self, x, neighbors):
171
+    def find_neighbors(self, x, neighbors):
171 172
         if x in neighbors:
172 173
             return neighbors[x]
173
-        return self.find_neighbors(x, neighbors)
174
-
175
-    def find_neighbors(self, x, neighbors):
176 174
         pos = neighbors.bisect_left(x)
177
-        x_lower = neighbors.iloc[pos-1] if pos != 0 else None
178
-        x_upper = neighbors.iloc[pos] if pos != len(neighbors) else None
179
-        return x_lower, x_upper
175
+        x_left = neighbors.iloc[pos-1] if pos != 0 else None
176
+        x_right = neighbors.iloc[pos] if pos != len(neighbors) else None
177
+        return x_left, x_right
180 178
 
181 179
     def update_neighbors(self, x, neighbors):
182 180
         if x not in neighbors:  # The point is new
183
-            x_lower, x_upper = self.find_neighbors(x, neighbors)
184
-            neighbors[x] = [x_lower, x_upper]
185
-            neighbors.get(x_lower, [None, None])[1] = x
186
-            neighbors.get(x_upper, [None, None])[0] = x
181
+            x_left, x_right = self.find_neighbors(x, neighbors)
182
+            neighbors[x] = [x_left, x_right]
183
+            neighbors.get(x_left, [None, None])[1] = x
184
+            neighbors.get(x_right, [None, None])[0] = x
187 185
 
188 186
     def update_scale(self, x, y):
189 187
         """Update the scale with which the x and y-values are scaled.
... ...
@@ -245,7 +243,7 @@ class Learner1D(BaseLearner):
245 243
         if self._scale[1] > self._oldscale[1] * 2:
246 244
 
247 245
             for interval in self.losses:
248
-                self.update_interpolated_losses_in_interval(*interval)
246
+                self.update_interpolated_loss_in_interval(*interval)
249 247
 
250 248
             self._oldscale = deepcopy(self._scale)
251 249
 
... ...
@@ -288,8 +286,8 @@ class Learner1D(BaseLearner):
288 286
             x_scale = self._scale[0]
289 287
 
290 288
             quals = []
291
-            for ((x0, x1), loss) in self.losses_combined.items():
292
-                dx = x1 - x0
289
+            for ((x_left, x_right), loss) in self.losses_combined.items():
290
+                dx = x_right - x_left
293 291
                 if abs(dx) < self._dx_eps:
294 292
                     # The interval is too small and should not be subdivided
295 293
                     quality = 0
... ...
@@ -297,7 +295,7 @@ class Learner1D(BaseLearner):
297 295
                     quality = -loss
298 296
                 else:
299 297
                     quality = -dx / x_scale
300
-                quals.append((quality, (x0, x1), 1))
298
+                quals.append((quality, (x_left, x_right), 1))
301 299
 
302 300
             heapq.heapify(quals)
303 301