Browse code

add scalarplot() implementation for plotly backend

Kelvin Loh authored on 06/12/2018 11:15:36
Showing 1 changed files
... ...
@@ -2294,18 +2294,66 @@ def scalarplot(field, box,
2294 2294
     fig : matplotlib figure
2295 2295
         A figure with the output if ``ax`` is not set, else None.
2296 2296
     """
2297
-    if not _p.mpl_available:
2298
-        raise RuntimeError("matplotlib was not found, but is required "
2299
-                           "for current()")
2300 2297
 
2301 2298
     # Matplotlib plots images like matrices: image[y, x].  We use the opposite
2302 2299
     # convention: image[x, y].  Hence, it is necessary to transpose.
2303 2300
     # Also squeeze out the last axis as it is just a scalar field
2301
+
2304 2302
     field = field.squeeze(axis=-1).transpose()
2305 2303
 
2306 2304
     if field.ndim != 2:
2307 2305
         raise ValueError("Only 2D field can be plotted.")
2308 2306
 
2307
+    if vmin is None:
2308
+        vmin = np.min(field)
2309
+    if vmax is None:
2310
+        vmax = np.max(field)
2311
+
2312
+    if get_backend() == _p.Backends.matplotlib:
2313
+        fig = _scalarplot_matplotlib(field, box, cmap, colorbar,
2314
+                                     file, show, dpi, fig_size, ax,
2315
+                                     vmin, vmax, background)
2316
+    elif get_backend() == _p.Backends.plotly:
2317
+        _check_incompatible_args_plotly(dpi, fig_size, ax)
2318
+        fig = _scalarplot_plotly(field, box, cmap, colorbar, file,
2319
+                                 show, vmin, vmax, background)
2320
+    _maybe_output_fig(fig, file=file, show=show)
2321
+
2322
+    return fig
2323
+
2324
+
2325
+def _scalarplot_plotly(field, box, cmap, colorbar, file,
2326
+                       show, vmin, vmax, background):
2327
+    if not _p.plotly_available:
2328
+        raise RuntimeError("plotly was not found, but is required "
2329
+                           "for scalarplot()")
2330
+
2331
+    if cmap is None:
2332
+        cmap = _p.kwant_red_plotly
2333
+
2334
+    contour_object = _p.plotly_graph_objs.Heatmap()
2335
+    contour_object.z = field
2336
+    contour_object.x = np.linspace(*box[0],field.shape[0])
2337
+    contour_object.y = np.linspace(*box[1],field.shape[1])
2338
+    contour_object.zsmooth = 'best'
2339
+    contour_object.colorscale = cmap
2340
+    contour_object.zmax = vmax
2341
+    contour_object.zmin = vmin
2342
+
2343
+    contour_object.showscale = colorbar
2344
+
2345
+    fig = _p.plotly_graph_objs.Figure(data=[contour_object])
2346
+    fig.layout.plot_bgcolor = background
2347
+
2348
+    return fig
2349
+
2350
+
2351
+def _scalarplot_matplotlib(field, box, cmap, colorbar, file, show, dpi,
2352
+                           fig_size, ax, vmin, vmax, background):
2353
+    if not _p.mpl_available:
2354
+        raise RuntimeError("matplotlib was not found, but is required "
2355
+                           "for scalarplot()")
2356
+
2309 2357
     if cmap is None:
2310 2358
         cmap = _p.kwant_red_matplotlib
2311 2359
     cmap = _p.matplotlib.cm.get_cmap(cmap)
... ...
@@ -2316,11 +2364,6 @@ def scalarplot(field, box,
2316 2364
     else:
2317 2365
         fig = None
2318 2366
 
2319
-    if vmin is None:
2320
-        vmin = np.min(field)
2321
-    if vmax is None:
2322
-        vmax = np.max(field)
2323
-
2324 2367
     image = ax.imshow(field, cmap=cmap,
2325 2368
                       interpolation='bicubic',
2326 2369
                       extent=[e for c in box for e in c],
... ...
@@ -2333,8 +2376,6 @@ def scalarplot(field, box,
2333 2376
     if colorbar and cmap and fig is not None:
2334 2377
         fig.colorbar(image)
2335 2378
 
2336
-    _maybe_output_fig(fig, file=file, show=show)
2337
-
2338 2379
     return fig
2339 2380
 
2340 2381