Browse code

set 'scatter_or_line' by default to 'scatter'

Bas Nijholt authored on 14/10/2019 13:12:27
Showing 1 changed files
... ...
@@ -580,33 +580,32 @@ class Learner1D(BaseLearner):
580 580
         loss = mapping[ival]
581 581
         return finite_loss(ival, loss, self._scale[0])
582 582
 
583
-    def plot(self, *, scatter_or_line=None):
583
+    def plot(self, *, scatter_or_line="scatter"):
584 584
         """Returns a plot of the evaluated data.
585 585
 
586
+        Parameters
587
+        ----------
588
+        scatter_or_line : str, default: "scatter"
589
+            Plot as a scatter plot ("scatter") or a line plot ("line").
590
+
586 591
         Returns
587 592
         -------
588
-        plot : `holoviews.element.Scatter` (if vdim=1)\
589
-               else `holoviews.element.Path`
593
+        plot : `holoviews.Scatter`, `holoviews.Path`, or `holoviews.Overlay`
590 594
             Plot of the evaluated data.
591
-        scatter_or_line : str, optional
592
-            Plot as a scatter plot ("scatter") or a line plot ("line").
593
-            By default a line plot will be chosen if the data consists of
594
-            vectors, otherwise it is a scatter plot.
595 595
         """
596
+        if scatter_or_line not in ("scatter", "line"):
597
+            raise ValueError("scatter_or_line must be 'scatter' or 'line'")
596 598
         hv = ensure_holoviews()
597 599
 
598 600
         xs, ys = zip(*sorted(self.data.items())) if self.data else ([], [])
599
-        if self.vdim == 1:
600
-            if scatter_or_line is None or scatter_or_line == "scatter":
601
-                p = hv.Path([]) * hv.Scatter((xs, ys))
602
-            else:
603
-                p = hv.Path((xs, ys)) * hv.Scatter()
604
-        else:
605
-            if scatter_or_line is None or scatter_or_line == "line":
606
-                p = hv.Path((xs, ys)) * hv.Scatter([])
601
+        if scatter_or_line == "scatter":
602
+            if self.vdim == 1:
603
+                p = hv.Scatter((xs, ys))
607 604
             else:
608 605
                 scatters = [hv.Scatter((xs, _ys)) for _ys in np.transpose(ys)]
609
-                p = hv.Path([]) * hv.Overlay(scatters)
606
+                p = hv.Overlay(scatters)
607
+        else:
608
+            p = hv.Path((xs, ys))
610 609
 
611 610
         # Plot with 5% empty margins such that the boundary points are visible
612 611
         margin = 0.05 * (self.bounds[1] - self.bounds[0])