Browse code

refactor almost duplicated code into new function: update_interpolated_losses_in_interval

Jorn Hoofwijk authored on 27/04/2018 19:29:36
Showing 1 changed files
... ...
@@ -131,30 +131,23 @@ class Learner1D(BaseLearner):
131 131
         else:
132 132
             return max(losses.values())
133 133
 
134
+    def update_interpolated_losses_in_interval(self, x_lower, x_upper):
135
+        if x_lower is not None and x_upper is not None:
136
+            self.losses[x_lower, x_upper] = self.loss_per_interval((x_lower, x_upper),
137
+                                                             self._scale, self.data)
138
+            start = self.neighbors_combined.bisect_right(x_lower)
139
+            end = self.neighbors_combined.bisect_left(x_upper)
140
+            for i in range(start, end):
141
+                a, b = self.neighbors_combined.iloc[i], self.neighbors_combined.iloc[i + 1]
142
+                self.losses_combined[a, b] = (b - a) * self.losses[x_lower, x_upper] / (x_upper - x_lower)
143
+            if start == end:
144
+                self.losses_combined[x_lower, x_upper] = self.losses[x_lower, x_upper]
145
+
134 146
     def update_losses(self, x, real=True):
135 147
         if real:
136 148
             x_lower, x_upper = self.get_neighbors(x, self.neighbors)
137
-            if x_lower is not None:
138
-                self.losses[x_lower, x] = self.loss_per_interval((x_lower, x),
139
-                                                                 self._scale, self.data)
140
-                start = self.neighbors_combined.bisect_right(x_lower)
141
-                end = self.neighbors_combined.bisect_left(x)
142
-                for i in range(start, end):
143
-                    a, b = self.neighbors_combined.iloc[i], self.neighbors_combined.iloc[i + 1]
144
-                    self.losses_combined[a, b] = (b - a) * self.losses[x_lower, x] / (x - x_lower)
145
-                if start == end:
146
-                    self.losses_combined[x_lower, x] = self.losses[x_lower, x]
147
-
148
-            if x_upper is not None:
149
-                self.losses[x, x_upper] = self.loss_per_interval((x, x_upper),
150
-                                                                 self._scale, self.data)
151
-                start = self.neighbors_combined.bisect_right(x)
152
-                end = self.neighbors_combined.bisect_left(x_upper)
153
-                for i in range(start, end):
154
-                    a, b = self.neighbors_combined.iloc[i], self.neighbors_combined.iloc[i + 1]
155
-                    self.losses_combined[a, b] = (b - a) * self.losses[x, x_upper] / (x_upper - x)
156
-                if start == end:
157
-                    self.losses_combined[x, x_upper] = self.losses[x, x_upper]
149
+            self.update_interpolated_losses_in_interval(x_lower, x)
150
+            self.update_interpolated_losses_in_interval(x, x_upper)
158 151
 
159 152
             try:
160 153
                 del self.losses[x_lower, x_upper]