Browse code

implement plotting backend selection

Kelvin Loh authored on 06/12/2018 10:21:46
Showing 3 changed files
... ...
@@ -18,6 +18,7 @@
18 18
 import warnings
19 19
 from math import sqrt, pi
20 20
 import numpy as np
21
+from enum import Enum
21 22
 
22 23
 try:
23 24
     import matplotlib
... ...
@@ -34,11 +35,34 @@ try:
34 35
         warnings.warn("3D plotting not available.", RuntimeWarning)
35 36
         has3d = False
36 37
 except ImportError:
37
-    warnings.warn("matplotlib is not available, only iterator-providing "
38
-                  "functions will work.", RuntimeWarning)
38
+    warnings.warn("matplotlib is not available, if other backends are "
39
+                  "unavailable, only iterator-providing functions will work",
40
+                  RuntimeWarning)
39 41
     mpl_available = False
40 42
 
41 43
 
44
+try:
45
+    import plotly.offline as plotly
46
+    import plotly.graph_objs as plotly_graph_objs
47
+    plotly.init_notebook_mode(connected=True)
48
+    plotly_available = True
49
+except ImportError:
50
+    warnings.warn("plotly is not available, if other backends are unavailable,"
51
+                  " only iterator-providing functions will work",
52
+                  RuntimeWarning)
53
+    plotly_available = False
54
+
55
+
56
+class Backends(Enum):
57
+    matplotlib = 0
58
+    plotly = 1
59
+
60
+backend = Backends.matplotlib
61
+
62
+if not ((mpl_available) or (plotly_available)):
63
+    backend = None
64
+
65
+
42 66
 # Collections that allow for symbols and linewiths to be given in data space
43 67
 # (not for general use, only implement what's needed for plotter)
44 68
 def isarray(var):
... ...
@@ -30,7 +30,8 @@ from . import system, builder, _common
30 30
 from ._common import deprecate_args
31 31
 
32 32
 
33
-__all__ = ['plot', 'map', 'bands', 'spectrum', 'current', 'density',
33
+__all__ = ['set_backend',
34
+           'plot', 'map', 'bands', 'spectrum', 'current', 'density',
34 35
            'interpolate_current', 'interpolate_density',
35 36
            'streamplot', 'scalarplot',
36 37
            'sys_leads_sites', 'sys_leads_hoppings', 'sys_leads_pos',
... ...
@@ -41,6 +42,43 @@ __all__ = ['plot', 'map', 'bands', 'spectrum', 'current', 'density',
41 42
 _p = _common.lazy_import('_plotter')
42 43
 
43 44
 
45
+def set_backend(backend):
46
+    """Set the plotting backend to use.
47
+
48
+    Parameters
49
+    ----------
50
+    backend : str
51
+        Options are: 'matplotlib', 'plotly'.
52
+    """
53
+
54
+    if ((_p.mpl_available) or (_p.plotly_available)):
55
+        try:
56
+            _p.backend = _p.Backends[backend]
57
+        except KeyError:
58
+            list_available_backends = []
59
+            if (_p.mpl_available):
60
+                list_available_backends.append('matplotlib')
61
+            if (_p.plotly_available):
62
+                list_available_backends.append('plotly')
63
+            error_message = "Tried to set an unknown backend \'{}\'.".format(
64
+                                                                       backend)
65
+            error_message += " Supported backends are {}".format(
66
+                                                       list_available_backends)
67
+            raise RuntimeError(error_message)
68
+    else:
69
+        warnings.warn("Tried to set \'{}\' but is not "
70
+                      "available.".format(backend), RuntimeWarning)
71
+
72
+    if (_p.is_ipython_kernel):
73
+        if ((_p.backend == _p.Backends.plotly) and
74
+            (not _p.init_notebook_mode_set)):
75
+            _p.init_notebook_mode_set = True
76
+            _p.plotly_module.init_notebook_mode(connected=True)
77
+
78
+
79
+def get_backend():
80
+    return _p.backend
81
+
44 82
 def _sample_array(array, n_samples, rng=None):
45 83
     rng = _common.ensure_rng(rng)
46 84
     la = len(array)
... ...
@@ -106,13 +144,19 @@ def _maybe_output_fig(fig, file=None, show=True):
106 144
     if fig is None:
107 145
         return
108 146
 
109
-    if file is not None:
110
-        fig.canvas.print_figure(file, dpi=fig.dpi)
111
-    elif show:
112
-        # If there was no file provided, pyplot should already be available and
113
-        # we can import it safely without additional warnings.
114
-        from matplotlib import pyplot
115
-        pyplot.show()
147
+    if get_backend() == _p.Backends.matplotlib:
148
+        if file is not None:
149
+            fig.canvas.print_figure(file, dpi=fig.dpi)
150
+        elif show:
151
+            # If there was no file provided, pyplot should already be available
152
+            # and we can import it safely without additional warnings.
153
+            from matplotlib import pyplot
154
+            pyplot.show()
155
+    elif get_backend() == _p.Backends.plotly:
156
+        if file is not None:
157
+            _p.plotly.plot(fig, show_link=False, filename=file, auto_open=False)
158
+        if show:
159
+            _p.plotly.iplot(fig)
116 160
 
117 161
 
118 162
 def set_colors(color, collection, cmap, norm=None):
... ...
@@ -585,7 +585,8 @@ def main():
585 585
                             'tinyarray >= 1.2'],
586 586
           extras_require={
587 587
               # The oldest versions between: Debian stable, Ubuntu LTS
588
-              'plotting': 'matplotlib >= 2.1.1',
588
+              'plotting': ['matplotlib >= 2.1.1',
589
+                           'plotly >= 2.2.2'],
589 590
               'continuum': 'sympy >= 1.1.1',
590 591
               # qsymm is only packaged on PyPI
591 592
               'qsymm': 'qsymm >= 1.2.6',