... | ... |
@@ -1,8 +1,10 @@ |
1 | 1 |
# -*- coding: utf-8 -*- |
2 |
+from collections import defaultdict |
|
2 | 3 |
import functools |
3 | 4 |
from operator import itemgetter |
4 | 5 |
|
5 | 6 |
from .base_learner import BaseLearner |
7 |
+from ..notebook_integration import ensure_holoviews |
|
6 | 8 |
from .utils import restore |
7 | 9 |
|
8 | 10 |
|
... | ... |
@@ -87,8 +89,56 @@ class BalancingLearner(BaseLearner): |
87 | 89 |
losses.append(loss) |
88 | 90 |
return max(losses) |
89 | 91 |
|
90 |
- def plot(self, index): |
|
91 |
- return self.learners[index].plot() |
|
92 |
+ def plot(self, cdims=None, plotter=None): |
|
93 |
+ """Returns a DynamicMap with sliders. |
|
94 |
+ |
|
95 |
+ Parameters |
|
96 |
+ ---------- |
|
97 |
+ cdims : sequence of dicts, or (keys, iterable of values), optional |
|
98 |
+ Constant dimensions; the parameters that label the learners. |
|
99 |
+ Example inputs that all give identical results: |
|
100 |
+ - sequence of dicts: |
|
101 |
+ >>> cdims = [{'A': True, 'B': 0}, |
|
102 |
+ {'A': True, 'B': 1}, |
|
103 |
+ {'A': False, 'B': 0}, |
|
104 |
+ {'A': False, 'B': 1}]` |
|
105 |
+ - tuple with (keys, iterable of values): |
|
106 |
+ >>> cdims = (['A', 'B'], itertools.product([True, False], [0, 1])) |
|
107 |
+ >>> cdims = (['A', 'B'], [(True, 0), (True, 1), |
|
108 |
+ (False, 0), (False, 1)]) |
|
109 |
+ plotter : callable, optional |
|
110 |
+ A function that takes the learner as a argument and returns a |
|
111 |
+ holoviews object. By default learner.plot() will be called. |
|
112 |
+ Returns |
|
113 |
+ ------- |
|
114 |
+ dm : holoviews.DynamicMap object |
|
115 |
+ A DynamicMap with sliders that are defined by 'cdims'. |
|
116 |
+ """ |
|
117 |
+ hv = ensure_holoviews() |
|
118 |
+ |
|
119 |
+ if cdims is None: |
|
120 |
+ cdims = [{'i': i} for i in range(len(self.learners))] |
|
121 |
+ elif not isinstance(cdims[0], dict): |
|
122 |
+ # Normalize the format |
|
123 |
+ keys, values_list = cdims |
|
124 |
+ cdims = [dict(zip(keys, values)) for values in values_list] |
|
125 |
+ |
|
126 |
+ mapping = {tuple(_cdims.values()): l for l, _cdims in zip(self.learners, cdims)} |
|
127 |
+ |
|
128 |
+ d = defaultdict(list) |
|
129 |
+ for _cdims in cdims: |
|
130 |
+ for k, v in _cdims.items(): |
|
131 |
+ d[k].append(v) |
|
132 |
+ |
|
133 |
+ def plot_function(*args): |
|
134 |
+ try: |
|
135 |
+ learner = mapping[tuple(args)] |
|
136 |
+ return learner.plot() if plotter is None else plotter(learner) |
|
137 |
+ except KeyError: |
|
138 |
+ pass |
|
139 |
+ |
|
140 |
+ dm = hv.DynamicMap(plot_function, kdims=list(d.keys())) |
|
141 |
+ return dm.redim.values(**d) |
|
92 | 142 |
|
93 | 143 |
def remove_unfinished(self): |
94 | 144 |
"""Remove uncomputed data from the learners.""" |