Browse code

fix that plotting always works

Bas Nijholt authored on 14/11/2017 15:23:22
Showing 2 changed files
... ...
@@ -128,7 +128,7 @@ class Learner1D(BaseLearner):
128 128
                 pass
129 129
 
130 130
             if self.vector_output is None:
131
-                self.vector_output = hasattr(y, '__len__')
131
+                self.vector_output = hasattr(y, '__len__') and len(y) > 1
132 132
 
133 133
         else:
134 134
             # The keys of data_interp are the unknown points
... ...
@@ -234,6 +234,7 @@ class Learner1D(BaseLearner):
234 234
                                                 fill_value=0)
235 235
                 interp_ys = ip(xs_unfinished).T
236 236
             else:
237
+                ys = np.array(ys).flatten()  # ys could be a list of arrays with shape (1,)
237 238
                 interp_ys = np.interp(xs_unfinished, xs, ys)
238 239
 
239 240
         data_interp = {x: y for x, y in zip(xs_unfinished, interp_ys)}
... ...
@@ -241,18 +242,15 @@ class Learner1D(BaseLearner):
241 242
         return data_interp
242 243
 
243 244
     def plot(self):
245
+        if not self.data:
246
+            return hv.Scatter([]) * hv.Path([])
247
+
244 248
         if not self.vector_output:
245
-            if self.data:
246
-                return hv.Scatter(self.data)
247
-            else:
248
-                return hv.Scatter([])
249
+            return hv.Scatter(self.data) * hv.Path([])
249 250
         else:
250
-            if self.data:
251
-                xs = list(self.data.keys())
252
-                ys = list(self.data.values())
253
-                return hv.Path((xs, ys))
254
-            else:
255
-                return hv.Path([])
251
+            xs = list(self.data.keys())
252
+            ys = list(self.data.values())
253
+            return hv.Path((xs, ys)) * hv.Scatter([])
256 254
 
257 255
     def remove_unfinished(self):
258 256
         self.data_interp = {}
... ...
@@ -393,7 +393,7 @@
393 393
    "cell_type": "markdown",
394 394
    "metadata": {},
395 395
    "source": [
396
-    "This is again a function with sharp peaks at different x-values and with different constant backgrounds. To learn this function we can use a `Learner1D` with the argument `vector_output=True`."
396
+    "This is again a function with sharp peaks at different x-values and with different constant backgrounds. To learn this function we can use a `Learner1D` as well."
397 397
    ]
398 398
   },
399 399
   {
... ...
@@ -404,7 +404,7 @@
404 404
    "source": [
405 405
     "from adaptive.runner import SequentialExecutor\n",
406 406
     "\n",
407
-    "learner = adaptive.Learner1D(f_levels, bounds=(-1, 1), vector_output=True)\n",
407
+    "learner = adaptive.Learner1D(f_levels, bounds=(-1, 1))\n",
408 408
     "runner = adaptive.Runner(learner, SequentialExecutor(), goal=lambda l: l.loss() < 0.05)"
409 409
    ]
410 410
   },