add scatter_or_line argument to Learner1D.plot
Joseph Weston authored on 14/10/2019 15:31:52 • GitHub committed on 14/10/2019 15:31:52... | ... |
@@ -580,23 +580,35 @@ class Learner1D(BaseLearner): |
580 | 580 |
loss = mapping[ival] |
581 | 581 |
return finite_loss(ival, loss, self._scale[0]) |
582 | 582 |
|
583 |
- def plot(self): |
|
583 |
+ def plot(self, *, scatter_or_line="scatter"): |
|
584 | 584 |
"""Returns a plot of the evaluated data. |
585 | 585 |
|
586 |
+ Parameters |
|
587 |
+ ---------- |
|
588 |
+ scatter_or_line : str, default: "scatter" |
|
589 |
+ Plot as a scatter plot ("scatter") or a line plot ("line"). |
|
590 |
+ |
|
586 | 591 |
Returns |
587 | 592 |
------- |
588 |
- plot : `holoviews.element.Scatter` (if vdim=1)\ |
|
589 |
- else `holoviews.element.Path` |
|
593 |
+ plot : `holoviews.Overlay` |
|
590 | 594 |
Plot of the evaluated data. |
591 | 595 |
""" |
596 |
+ if scatter_or_line not in ("scatter", "line"): |
|
597 |
+ raise ValueError("scatter_or_line must be 'scatter' or 'line'") |
|
592 | 598 |
hv = ensure_holoviews() |
593 | 599 |
|
594 | 600 |
xs, ys = zip(*sorted(self.data.items())) if self.data else ([], []) |
595 |
- if self.vdim == 1: |
|
596 |
- p = hv.Path([]) * hv.Scatter((xs, ys)) |
|
601 |
+ if scatter_or_line == "scatter": |
|
602 |
+ if self.vdim == 1: |
|
603 |
+ plots = [hv.Scatter((xs, ys))] |
|
604 |
+ else: |
|
605 |
+ plots = [hv.Scatter((xs, _ys)) for _ys in np.transpose(ys)] |
|
597 | 606 |
else: |
598 |
- p = hv.Path((xs, ys)) * hv.Scatter([]) |
|
607 |
+ plots = [hv.Path((xs, ys))] |
|
599 | 608 |
|
609 |
+ # Put all plots in an Overlay because a DynamicMap can't handle changing |
|
610 |
+ # datatypes, e.g. when `vdim` isn't yet known and the live_plot is running. |
|
611 |
+ p = hv.Overlay(plots) |
|
600 | 612 |
# Plot with 5% empty margins such that the boundary points are visible |
601 | 613 |
margin = 0.05 * (self.bounds[1] - self.bounds[0]) |
602 | 614 |
plot_bounds = (self.bounds[0] - margin, self.bounds[1] + margin) |