Browse code

Merge pull request #215 from python-adaptive/learner1D_plot

add scatter_or_line argument to Learner1D.plot

Joseph Weston authored on 14/10/2019 15:31:52 • GitHub committed on 14/10/2019 15:31:52
Showing 1 changed files
... ...
@@ -580,23 +580,35 @@ class Learner1D(BaseLearner):
580 580
         loss = mapping[ival]
581 581
         return finite_loss(ival, loss, self._scale[0])
582 582
 
583
-    def plot(self):
583
+    def plot(self, *, scatter_or_line="scatter"):
584 584
         """Returns a plot of the evaluated data.
585 585
 
586
+        Parameters
587
+        ----------
588
+        scatter_or_line : str, default: "scatter"
589
+            Plot as a scatter plot ("scatter") or a line plot ("line").
590
+
586 591
         Returns
587 592
         -------
588
-        plot : `holoviews.element.Scatter` (if vdim=1)\
589
-               else `holoviews.element.Path`
593
+        plot : `holoviews.Overlay`
590 594
             Plot of the evaluated data.
591 595
         """
596
+        if scatter_or_line not in ("scatter", "line"):
597
+            raise ValueError("scatter_or_line must be 'scatter' or 'line'")
592 598
         hv = ensure_holoviews()
593 599
 
594 600
         xs, ys = zip(*sorted(self.data.items())) if self.data else ([], [])
595
-        if self.vdim == 1:
596
-            p = hv.Path([]) * hv.Scatter((xs, ys))
601
+        if scatter_or_line == "scatter":
602
+            if self.vdim == 1:
603
+                plots = [hv.Scatter((xs, ys))]
604
+            else:
605
+                plots = [hv.Scatter((xs, _ys)) for _ys in np.transpose(ys)]
597 606
         else:
598
-            p = hv.Path((xs, ys)) * hv.Scatter([])
607
+            plots = [hv.Path((xs, ys))]
599 608
 
609
+        # Put all plots in an Overlay because a DynamicMap can't handle changing
610
+        # datatypes, e.g. when `vdim` isn't yet known and the live_plot is running.
611
+        p = hv.Overlay(plots)
600 612
         # Plot with 5% empty margins such that the boundary points are visible
601 613
         margin = 0.05 * (self.bounds[1] - self.bounds[0])
602 614
         plot_bounds = (self.bounds[0] - margin, self.bounds[1] + margin)