Browse code

1D: extrapolate the data and do local updates instead of recalculating all interpolated values

Bas Nijholt authored on 07/03/2018 23:09:28
Showing 1 changed files
... ...
@@ -6,7 +6,6 @@ import math
6 6
 
7 7
 import numpy as np
8 8
 import sortedcontainers
9
-import scipy.interpolate
10 9
 
11 10
 from ..notebook_integration import ensure_holoviews
12 11
 from .base_learner import BaseLearner
... ...
@@ -195,10 +194,28 @@ class Learner1D(BaseLearner):
195 194
 
196 195
             if self._vdim is None:
197 196
                 try:
198
-                    self._vdim = len(y)
197
+                    self._vdim = len(np.squeeze(y))
199 198
                 except TypeError:
200 199
                     self._vdim = 1
201 200
 
201
+            # Invalidate interpolated neighbors of new point
202
+            i = self.data.bisect_left(x)
203
+            if i == 0:
204
+                x_left = self.data.iloc[0]
205
+                for _x in self.data_interp:
206
+                    if _x < x_left:
207
+                        self.data_interp[_x] = None
208
+            elif i == len(self.data):
209
+                x_right = self.data.iloc[-1]
210
+                for _x in self.data_interp:
211
+                    if _x > x_right:
212
+                        self.data_interp[_x] = None
213
+            else:
214
+                x_left, x_right = self.data.iloc[i-1], self.data.iloc[i]
215
+                for _x in self.data_interp:
216
+                    if x_left < _x < x_right:
217
+                        self.data_interp[_x] = None
218
+
202 219
         else:
203 220
             # The keys of data_interp are the unknown points
204 221
             self.data_interp[x] = None
... ...
@@ -212,8 +229,21 @@ class Learner1D(BaseLearner):
212 229
         self.update_scale(x, y)
213 230
 
214 231
         # Interpolate
215
-        if not real:
216
-            self.data_interp = self.interpolate()
232
+        for _x, _y in self.data_interp.items():
233
+            if _y is None:
234
+                if len(self.data) >=2:
235
+                    i = self.data.bisect_left(_x)
236
+                    if i == 0:
237
+                        i_left, i_right = (0, 1)
238
+                    elif i == len(self.data):
239
+                        i_left, i_right = (-2, -1)
240
+                    else:
241
+                        i_left, i_right = (i - 1, i)
242
+                    x_left, x_right = self.data.iloc[i_left], self.data.iloc[i_right]
243
+                    y_left, y_right = self.data[x_left], self.data[x_right]
244
+                    dx = x_right - x_left
245
+                    dy = y_right - y_left
246
+                    self.data_interp[_x] = (dy / dx) * (_x - x_left) + y_left
217 247
 
218 248
         # Update the losses
219 249
         self.update_losses(x, self.data_combined, self.neighbors_combined,
... ...
@@ -285,41 +315,16 @@ class Learner1D(BaseLearner):
285 315
 
286 316
         return points, loss_improvements
287 317
 
288
-    def interpolate(self, extra_points=None):
289
-        xs = list(self.data.keys())
290
-        ys = list(self.data.values())
291
-        xs_unfinished = list(self.data_interp.keys())
292
-
293
-        if extra_points is not None:
294
-            xs_unfinished += extra_points
295
-
296
-        if len(xs) < 2:
297
-            interp_ys = np.zeros(len(xs_unfinished))
298
-        else:
299
-            if self.vdim > 1:
300
-                ip = scipy.interpolate.interp1d(xs, np.transpose(ys),
301
-                                                assume_sorted=True,
302
-                                                bounds_error=False,
303
-                                                fill_value=0)
304
-                interp_ys = ip(xs_unfinished).T
305
-            else:
306
-                ys = np.array(ys).flatten()  # ys could be a list of arrays with shape (1,)
307
-                interp_ys = np.interp(xs_unfinished, xs, ys)
308
-
309
-        data_interp = {x: y for x, y in zip(xs_unfinished, interp_ys)}
310
-
311
-        return data_interp
312
-
313 318
     def plot(self):
314 319
         hv = ensure_holoviews()
315 320
         if not self.data:
316 321
             return hv.Scatter([]) * hv.Path([])
317 322
 
323
+        xs = list(self.data.keys())
324
+        ys = np.array(list(self.data.values())).squeeze()
318 325
         if not self.vdim > 1:
319
-            return hv.Scatter(self.data) * hv.Path([])
326
+            return hv.Scatter((xs, ys)) * hv.Path([])
320 327
         else:
321
-            xs = list(self.data.keys())
322
-            ys = list(self.data.values())
323 328
             return hv.Path((xs, ys)) * hv.Scatter([])
324 329
 
325 330
     def remove_unfinished(self):