Plot multi level Learner2D as HoloMap
See merge request qt/adaptive!62
... | ... |
@@ -259,7 +259,7 @@ class Learner2D(BaseLearner): |
259 | 259 |
self._vdim = len(value) |
260 | 260 |
except TypeError: |
261 | 261 |
self._vdim = 1 |
262 |
- return self._vdim if self._vdim is not None else 1 |
|
262 |
+ return self._vdim or 1 |
|
263 | 263 |
|
264 | 264 |
@property |
265 | 265 |
def bounds_are_done(self): |
... | ... |
@@ -385,10 +385,34 @@ class Learner2D(BaseLearner): |
385 | 385 |
self._stack[p] = np.inf |
386 | 386 |
|
387 | 387 |
def plot(self, n=None, tri_alpha=0): |
388 |
+ """Plot the Learner2D's current state. |
|
389 |
+ |
|
390 |
+ This plot function interpolates the data on a regular grid. |
|
391 |
+ The gridspacing is evaluated by checking the size of the smallest |
|
392 |
+ triangle. |
|
393 |
+ |
|
394 |
+ Parameters |
|
395 |
+ ---------- |
|
396 |
+ n : int |
|
397 |
+ Number of points in x and y. If None (default) this number is |
|
398 |
+ evaluated by looking at the size of the smallest triangle. |
|
399 |
+ tri_alpha : float |
|
400 |
+ The opacity (0 <= tri_alpha <= 1) of the triangles overlayed on |
|
401 |
+ top of the image. By default the triangulation is not visible. |
|
402 |
+ |
|
403 |
+ Returns |
|
404 |
+ ------- |
|
405 |
+ plot : holoviews.Overlay or holoviews.HoloMap |
|
406 |
+ A `holoviews.Overlay` of `holoviews.Image * holoviews.EdgePaths`. |
|
407 |
+ If the `learner.function` returns a vector output, a |
|
408 |
+ `holoviews.HoloMap` of the `holoviews.Overlay`s wil be returned. |
|
409 |
+ |
|
410 |
+ Notes |
|
411 |
+ ----- |
|
412 |
+ The plot object that is returned if `learner.function` returns a |
|
413 |
+ vector *cannot* be used with the live_plotting functionality. |
|
414 |
+ """ |
|
388 | 415 |
hv = ensure_holoviews() |
389 |
- if self.vdim > 1: |
|
390 |
- raise NotImplementedError('holoviews currently does not support', |
|
391 |
- '3D surface plots in bokeh.') |
|
392 | 416 |
x, y = self.bounds |
393 | 417 |
lbrt = x[0], y[0], x[1], y[1] |
394 | 418 |
|
... | ... |
@@ -404,7 +428,12 @@ class Learner2D(BaseLearner): |
404 | 428 |
x = y = np.linspace(-0.5, 0.5, n) |
405 | 429 |
z = ip(x[:, None], y[None, :] * self.aspect_ratio).squeeze() |
406 | 430 |
|
407 |
- im = hv.Image(np.rot90(z), bounds=lbrt) |
|
431 |
+ if self.vdim > 1: |
|
432 |
+ ims = {i: hv.Image(np.rot90(z[:, :, i]), bounds=lbrt) |
|
433 |
+ for i in range(z.shape[-1])} |
|
434 |
+ im = hv.HoloMap(ims) |
|
435 |
+ else: |
|
436 |
+ im = hv.Image(np.rot90(z), bounds=lbrt) |
|
408 | 437 |
|
409 | 438 |
if tri_alpha: |
410 | 439 |
points = self._unscale(ip.tri.points[ip.tri.vertices]) |