... | ... |
@@ -6,7 +6,6 @@ import math |
6 | 6 |
|
7 | 7 |
import numpy as np |
8 | 8 |
import sortedcontainers |
9 |
-import scipy.interpolate |
|
10 | 9 |
|
11 | 10 |
from ..notebook_integration import ensure_holoviews |
12 | 11 |
from .base_learner import BaseLearner |
... | ... |
@@ -195,10 +194,28 @@ class Learner1D(BaseLearner): |
195 | 194 |
|
196 | 195 |
if self._vdim is None: |
197 | 196 |
try: |
198 |
- self._vdim = len(y) |
|
197 |
+ self._vdim = len(np.squeeze(y)) |
|
199 | 198 |
except TypeError: |
200 | 199 |
self._vdim = 1 |
201 | 200 |
|
201 |
+ # Invalidate interpolated neighbors of new point |
|
202 |
+ i = self.data.bisect_left(x) |
|
203 |
+ if i == 0: |
|
204 |
+ x_left = self.data.iloc[0] |
|
205 |
+ for _x in self.data_interp: |
|
206 |
+ if _x < x_left: |
|
207 |
+ self.data_interp[_x] = None |
|
208 |
+ elif i == len(self.data): |
|
209 |
+ x_right = self.data.iloc[-1] |
|
210 |
+ for _x in self.data_interp: |
|
211 |
+ if _x > x_right: |
|
212 |
+ self.data_interp[_x] = None |
|
213 |
+ else: |
|
214 |
+ x_left, x_right = self.data.iloc[i-1], self.data.iloc[i] |
|
215 |
+ for _x in self.data_interp: |
|
216 |
+ if x_left < _x < x_right: |
|
217 |
+ self.data_interp[_x] = None |
|
218 |
+ |
|
202 | 219 |
else: |
203 | 220 |
# The keys of data_interp are the unknown points |
204 | 221 |
self.data_interp[x] = None |
... | ... |
@@ -212,8 +229,21 @@ class Learner1D(BaseLearner): |
212 | 229 |
self.update_scale(x, y) |
213 | 230 |
|
214 | 231 |
# Interpolate |
215 |
- if not real: |
|
216 |
- self.data_interp = self.interpolate() |
|
232 |
+ for _x, _y in self.data_interp.items(): |
|
233 |
+ if _y is None: |
|
234 |
+ if len(self.data) >=2: |
|
235 |
+ i = self.data.bisect_left(_x) |
|
236 |
+ if i == 0: |
|
237 |
+ i_left, i_right = (0, 1) |
|
238 |
+ elif i == len(self.data): |
|
239 |
+ i_left, i_right = (-2, -1) |
|
240 |
+ else: |
|
241 |
+ i_left, i_right = (i - 1, i) |
|
242 |
+ x_left, x_right = self.data.iloc[i_left], self.data.iloc[i_right] |
|
243 |
+ y_left, y_right = self.data[x_left], self.data[x_right] |
|
244 |
+ dx = x_right - x_left |
|
245 |
+ dy = y_right - y_left |
|
246 |
+ self.data_interp[_x] = (dy / dx) * (_x - x_left) + y_left |
|
217 | 247 |
|
218 | 248 |
# Update the losses |
219 | 249 |
self.update_losses(x, self.data_combined, self.neighbors_combined, |
... | ... |
@@ -285,41 +315,16 @@ class Learner1D(BaseLearner): |
285 | 315 |
|
286 | 316 |
return points, loss_improvements |
287 | 317 |
|
288 |
- def interpolate(self, extra_points=None): |
|
289 |
- xs = list(self.data.keys()) |
|
290 |
- ys = list(self.data.values()) |
|
291 |
- xs_unfinished = list(self.data_interp.keys()) |
|
292 |
- |
|
293 |
- if extra_points is not None: |
|
294 |
- xs_unfinished += extra_points |
|
295 |
- |
|
296 |
- if len(xs) < 2: |
|
297 |
- interp_ys = np.zeros(len(xs_unfinished)) |
|
298 |
- else: |
|
299 |
- if self.vdim > 1: |
|
300 |
- ip = scipy.interpolate.interp1d(xs, np.transpose(ys), |
|
301 |
- assume_sorted=True, |
|
302 |
- bounds_error=False, |
|
303 |
- fill_value=0) |
|
304 |
- interp_ys = ip(xs_unfinished).T |
|
305 |
- else: |
|
306 |
- ys = np.array(ys).flatten() # ys could be a list of arrays with shape (1,) |
|
307 |
- interp_ys = np.interp(xs_unfinished, xs, ys) |
|
308 |
- |
|
309 |
- data_interp = {x: y for x, y in zip(xs_unfinished, interp_ys)} |
|
310 |
- |
|
311 |
- return data_interp |
|
312 |
- |
|
313 | 318 |
def plot(self): |
314 | 319 |
hv = ensure_holoviews() |
315 | 320 |
if not self.data: |
316 | 321 |
return hv.Scatter([]) * hv.Path([]) |
317 | 322 |
|
323 |
+ xs = list(self.data.keys()) |
|
324 |
+ ys = np.array(list(self.data.values())).squeeze() |
|
318 | 325 |
if not self.vdim > 1: |
319 |
- return hv.Scatter(self.data) * hv.Path([]) |
|
326 |
+ return hv.Scatter((xs, ys)) * hv.Path([]) |
|
320 | 327 |
else: |
321 |
- xs = list(self.data.keys()) |
|
322 |
- ys = list(self.data.values()) |
|
323 | 328 |
return hv.Path((xs, ys)) * hv.Scatter([]) |
324 | 329 |
|
325 | 330 |
def remove_unfinished(self): |