Browse code

modify BalancingLearner.plot() to return a DynamicMap with sliders

Bas Nijholt authored on 07/03/2018 07:00:09 • Joseph Weston committed on 19/03/2018 16:44:56
Showing 1 changed files
... ...
@@ -1,8 +1,10 @@
1 1
 # -*- coding: utf-8 -*-
2
+from collections import defaultdict
2 3
 import functools
3 4
 from operator import itemgetter
4 5
 
5 6
 from .base_learner import BaseLearner
7
+from ..notebook_integration import ensure_holoviews
6 8
 from .utils import restore
7 9
 
8 10
 
... ...
@@ -87,8 +89,56 @@ class BalancingLearner(BaseLearner):
87 89
             losses.append(loss)
88 90
         return max(losses)
89 91
 
90
-    def plot(self, index):
91
-        return self.learners[index].plot()
92
+    def plot(self, cdims=None, plotter=None):
93
+        """Returns a DynamicMap with sliders.
94
+
95
+        Parameters
96
+        ----------
97
+        cdims : sequence of dicts, or (keys, iterable of values), optional
98
+            Constant dimensions; the parameters that label the learners.
99
+            Example inputs that all give identical results:
100
+            - sequence of dicts:
101
+                >>> cdims = [{'A': True, 'B': 0},
102
+                             {'A': True, 'B': 1},
103
+                             {'A': False, 'B': 0},
104
+                             {'A': False, 'B': 1}]`
105
+            - tuple with (keys, iterable of values):
106
+                >>> cdims = (['A', 'B'], itertools.product([True, False], [0, 1]))
107
+                >>> cdims = (['A', 'B'], [(True, 0), (True, 1),
108
+                                          (False, 0), (False, 1)])
109
+        plotter : callable, optional
110
+            A function that takes the learner as a argument and returns a
111
+            holoviews object. By default learner.plot() will be called.
112
+        Returns
113
+        -------
114
+        dm : holoviews.DynamicMap object
115
+            A DynamicMap with sliders that are defined by 'cdims'.
116
+        """
117
+        hv = ensure_holoviews()
118
+
119
+        if cdims is None:
120
+            cdims = [{'i': i} for i in range(len(self.learners))]
121
+        elif not isinstance(cdims[0], dict):
122
+            # Normalize the format
123
+            keys, values_list = cdims
124
+            cdims = [dict(zip(keys, values)) for values in values_list]
125
+
126
+        mapping = {tuple(_cdims.values()): l for l, _cdims in zip(self.learners, cdims)}
127
+
128
+        d = defaultdict(list)
129
+        for _cdims in cdims:
130
+            for k, v in _cdims.items():
131
+                d[k].append(v)
132
+
133
+        def plot_function(*args):
134
+            try:
135
+                learner = mapping[tuple(args)]
136
+                return learner.plot() if plotter is None else plotter(learner)
137
+            except KeyError:
138
+                pass
139
+
140
+        dm = hv.DynamicMap(plot_function, kdims=list(d.keys()))
141
+        return dm.redim.values(**d)
92 142
 
93 143
     def remove_unfinished(self):
94 144
         """Remove uncomputed data from the learners."""