Browse code

Merge branch 'learner2D_plotting' into 'master'

Plot multi level Learner2D as HoloMap

See merge request qt/adaptive!62

Joseph Weston authored on 18/07/2018 19:38:30
Showing 1 changed files
... ...
@@ -259,7 +259,7 @@ class Learner2D(BaseLearner):
259 259
                 self._vdim = len(value)
260 260
             except TypeError:
261 261
                 self._vdim = 1
262
-        return self._vdim if self._vdim is not None else 1
262
+        return self._vdim or 1
263 263
 
264 264
     @property
265 265
     def bounds_are_done(self):
... ...
@@ -385,10 +385,34 @@ class Learner2D(BaseLearner):
385 385
                 self._stack[p] = np.inf
386 386
 
387 387
     def plot(self, n=None, tri_alpha=0):
388
+        """Plot the Learner2D's current state.
389
+
390
+        This plot function interpolates the data on a regular grid.
391
+        The gridspacing is evaluated by checking the size of the smallest
392
+        triangle.
393
+
394
+        Parameters
395
+        ----------
396
+        n : int
397
+            Number of points in x and y. If None (default) this number is
398
+            evaluated by looking at the size of the smallest triangle.
399
+        tri_alpha : float
400
+            The opacity (0 <= tri_alpha <= 1) of the triangles overlayed on
401
+            top of the image. By default the triangulation is not visible.
402
+
403
+        Returns
404
+        -------
405
+        plot : holoviews.Overlay or holoviews.HoloMap
406
+            A `holoviews.Overlay` of `holoviews.Image * holoviews.EdgePaths`.
407
+            If the `learner.function` returns a vector output, a
408
+            `holoviews.HoloMap` of the `holoviews.Overlay`s wil be returned.
409
+
410
+        Notes
411
+        -----
412
+        The plot object that is returned if `learner.function` returns a
413
+        vector *cannot* be used with the live_plotting functionality.
414
+        """
388 415
         hv = ensure_holoviews()
389
-        if self.vdim > 1:
390
-            raise NotImplementedError('holoviews currently does not support',
391
-                                      '3D surface plots in bokeh.')
392 416
         x, y = self.bounds
393 417
         lbrt = x[0], y[0], x[1], y[1]
394 418
 
... ...
@@ -404,7 +428,12 @@ class Learner2D(BaseLearner):
404 428
             x = y = np.linspace(-0.5, 0.5, n)
405 429
             z = ip(x[:, None], y[None, :] * self.aspect_ratio).squeeze()
406 430
 
407
-            im = hv.Image(np.rot90(z), bounds=lbrt)
431
+            if self.vdim > 1:
432
+                ims = {i: hv.Image(np.rot90(z[:, :, i]), bounds=lbrt)
433
+                       for i in range(z.shape[-1])}
434
+                im = hv.HoloMap(ims)
435
+            else:
436
+                im = hv.Image(np.rot90(z), bounds=lbrt)
408 437
 
409 438
             if tri_alpha:
410 439
                 points = self._unscale(ip.tri.points[ip.tri.vertices])