Browse code

remove 'set_loss_if_interval_is_greater_than_dx_eps' and '_update_loss' functions

Bas Nijholt authored on 03/05/2018 14:01:30
Showing 1 changed files
... ...
@@ -135,29 +135,22 @@ class Learner1D(BaseLearner):
135 135
         else:
136 136
             return max(losses.values())
137 137
 
138
-    def _update_loss(self, interval):
139
-        value = self.loss_per_interval(interval, self._scale, self.data)
140
-        self.set_loss_if_interval_is_greater_than_dx_eps(interval, value, self.losses)
141
-
142
-    def set_loss_if_interval_is_greater_than_dx_eps(self, interval, value, losses):
143
-        a, b = interval
144
-        if abs(a - b) > self._dx_eps:
145
-            losses[interval] = value
146
-        else:
147
-            losses[interval] = 0
148 138
 
149 139
     def update_interpolated_losses_in_interval(self, x_lower, x_upper):
150 140
         if x_lower is not None and x_upper is not None:
151
-            self._update_loss((x_lower, x_upper))
141
+            dx = x_upper - x_lower
142
+            loss = self.loss_per_interval((x_lower, x_upper), self._scale, self.data)
143
+            self.losses[x_lower, x_upper] = loss if abs(dx) > self._dx_eps else 0
152 144
 
153 145
             start = self.neighbors_combined.bisect_right(x_lower)
154 146
             end = self.neighbors_combined.bisect_left(x_upper)
155 147
             for i in range(start, end):
156 148
                 a, b = self.neighbors_combined.iloc[i], self.neighbors_combined.iloc[i + 1]
157
-                self.losses_combined[a, b] = (b - a) * self.losses[x_lower, x_upper] / (x_upper - x_lower)
149
+                self.losses_combined[a, b] = (b - a) * self.losses[x_lower, x_upper] / dx
158 150
             if start == end:
159 151
                 self.losses_combined[x_lower, x_upper] = self.losses[x_lower, x_upper]
160 152
 
153
+
161 154
     def update_losses(self, x, real=True):
162 155
         if real:
163 156
             x_lower, x_upper = self.get_neighbors(x, self.neighbors)
... ...
@@ -169,26 +162,24 @@ class Learner1D(BaseLearner):
169 162
             except KeyError:
170 163
                 pass
171 164
         else:
165
+            losses_combined = self.losses_combined
172 166
             x_lower, x_upper = self.get_neighbors(x, self.neighbors)
167
+            dx = x_upper - x_lower
173 168
             a, b = self.get_neighbors(x, self.neighbors_combined)
174 169
             if x_lower is not None and x_upper is not None:
175
-                val = (x - a) * self.losses[x_lower, x_upper]\
176
-                      / (x_upper - x_lower)
177
-                self.set_loss_if_interval_is_greater_than_dx_eps((a, x),
178
-                                                    val, self.losses_combined)
179
-
180
-                val = (b - x) * self.losses[x_lower, x_upper] \
181
-                      / (x_upper - x_lower)
182
-                self.set_loss_if_interval_is_greater_than_dx_eps((x, b),
183
-                                                    val, self.losses_combined)
170
+                loss = self.losses[x_lower, x_upper]
171
+                losses_combined[a, x] = ((x - a) * loss / dx
172
+                                         if abs(x - a) > self._dx_eps else 0)
173
+                losses_combined[x, b] = ((b - x) * loss  / dx
174
+                                         if abs(b - x) > self._dx_eps else 0)
184 175
             else:
185 176
                 if a is not None:
186
-                    self.losses_combined[a, x] = float('inf')
177
+                    losses_combined[a, x] = float('inf')
187 178
                 if b is not None:
188
-                    self.losses_combined[x, b] = float('inf')
179
+                    losses_combined[x, b] = float('inf')
189 180
 
190 181
             try:
191
-                del self.losses_combined[a, b]
182
+                del losses_combined[a, b]
192 183
             except KeyError:
193 184
                 pass
194 185