Browse code

add spectrum() implementation for plotly backend

Kelvin Loh authored on 06/12/2018 10:22:25
Showing 1 changed files
... ...
@@ -1479,7 +1479,10 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None,
1479 1479
 
1480 1480
 def spectrum(syst, x, y=None, params=None, mask=None, file=None,
1481 1481
              show=True, dpi=None, fig_size=None, ax=None):
1482
-    """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters
1482
+    """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters.
1483
+
1484
+    This function requires either matplotlib or plotly to be installed.
1485
+    The default backend uses matplotlib for plotting.
1483 1486
 
1484 1487
     Parameters
1485 1488
     ----------
... ...
@@ -1500,32 +1503,67 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None,
1500 1503
         if the spectrum should not be calculated for the given parameter
1501 1504
         values.
1502 1505
     file : string or file object or `None`
1503
-        The output file.  If `None`, output will be shown instead.
1506
+        The output file.  If `None`, output will be shown instead. If plotly is
1507
+        selected as the backend, the filename has to end with a html extension.
1504 1508
     show : bool
1505
-        Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
1506
-        to be shown immediately.  Defaults to `True`.
1509
+        For matplotlib backend, whether ``matplotlib.pyplot.show()`` is to be
1510
+        called, and the output is to be shown immediately.
1511
+        For the plotly backend, a call to ``iplot(fig)`` is made if
1512
+        show is True.
1513
+        Defaults to `True` for both backends.
1507 1514
     dpi : float
1508 1515
         Number of pixels per inch.  If not set the ``matplotlib`` default is
1509 1516
         used.
1517
+        Only for matplotlib backend. If the plotly backend is selected and
1518
+        this argument is not None, then a RuntimeError will be triggered.
1510 1519
     fig_size : tuple
1511 1520
         Figure size `(width, height)` in inches.  If not set, the default
1512 1521
         ``matplotlib`` value is used.
1522
+        Only for matplotlib backend. If the plotly backend is selected and
1523
+        this argument is not None, then a RuntimeError will be triggered.
1513 1524
     ax : ``matplotlib.axes.Axes`` instance or `None`
1514 1525
         If `ax` is not `None`, no new figure is created, but the plot is done
1515 1526
         within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
1516 1527
         and `fig_size` are ignored.
1528
+        Only for matplotlib backend. If the plotly backend is selected and
1529
+        this argument is not None, then a RuntimeError will be triggered.
1517 1530
 
1518 1531
     Returns
1519 1532
     -------
1520
-    fig : matplotlib figure
1521
-        A figure with the output if `ax` is not set, else None.
1533
+    fig : matplotlib figure or plotly Figure object
1522 1534
     """
1523 1535
 
1524
-    if not _p.mpl_available:
1525
-        raise RuntimeError("matplotlib was not found, but is required "
1526
-                           "for plot_spectrum()")
1527
-    if y is not None and not _p.has3d:
1528
-        raise RuntimeError("Installed matplotlib does not support 3d plotting")
1536
+    params = params or dict()
1537
+
1538
+    if get_backend() == _p.Backends.matplotlib:
1539
+        return _spectrum_matplotlib(syst, x, y, params, mask, file,
1540
+                                    show, dpi, fig_size, ax)
1541
+    elif get_backend() == _p.Backends.plotly:
1542
+        if(dpi or fig_size or ax):
1543
+            raise RuntimeError('Incompatible arguments of dpi, fig_size, or '
1544
+                               'ax. Current plotting backend is plotly.')
1545
+        return _spectrum_plotly(syst, x, y, params, mask, file, show)
1546
+
1547
+
1548
+def _generate_spectrum(syst, params, mask, x, y):
1549
+    """Generates the spectrum dataset for the internal plotting
1550
+    functions of spectrum().
1551
+
1552
+    Parameters
1553
+    ----------
1554
+    See spectrum(...) documentation.
1555
+
1556
+    Returns
1557
+    -------
1558
+    spectrum : Numpy array
1559
+         The energies of the system calculated at each coordinate.
1560
+    planar : bool
1561
+         True if y is None
1562
+    array_values : tuple
1563
+         The coordinates of x, y values of the dataset for plotting.
1564
+    keys : tuple
1565
+         Labels for the x and y axes.
1566
+    """
1529 1567
 
1530 1568
     if system.is_finite(syst):
1531 1569
         def ham(**kwargs):
... ...
@@ -1536,9 +1574,9 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None,
1536 1574
         raise TypeError("Expected 'syst' to be a finite Kwant system "
1537 1575
                         "or a function.")
1538 1576
 
1539
-    params = params or dict()
1540
-    keys = (x[0],) if y is None else (x[0], y[0])
1541
-    array_values = (x[1],) if y is None else (x[1], y[1])
1577
+    planar = y is None
1578
+    keys = (x[0],) if planar else (x[0], y[0])
1579
+    array_values = (x[1],) if planar else (x[1], y[1])
1542 1580
 
1543 1581
     # calculate spectrum on the grid of points
1544 1582
     spectrum = []
... ...
@@ -1558,10 +1596,91 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None,
1558 1596
     new_shape = [len(v) for v in array_values] + [-1]
1559 1597
     spectrum = np.array(spectrum).reshape(new_shape)
1560 1598
 
1599
+    return spectrum, planar, array_values, keys
1600
+
1601
+
1602
+def _spectrum_plotly(syst, x, y=None, params=None, mask=None,
1603
+                     file=None, show=True):
1604
+    """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters
1605
+    using the plotly backend.
1606
+
1607
+    Parameters
1608
+    ----------
1609
+    See spectrum(...) documentation.
1610
+
1611
+    Returns
1612
+    -------
1613
+    fig : plotly Figure / dict
1614
+    """
1615
+
1616
+    if not _p.plotly_available:
1617
+        raise RuntimeError("plotly was not found, but is required for using"
1618
+                           "the spectrum() plotly backend")
1619
+
1620
+    spectrum, planar, array_values, keys = _generate_spectrum(syst, params,
1621
+                                                              mask, x, y)
1622
+
1623
+    if planar:
1624
+        fig = _p.plotly_graph_objs.Figure(data=[
1625
+          _p.plotly_graph_objs.Scatter(
1626
+                 x=array_values[0],
1627
+                 y=energies,
1628
+          ) for energies in spectrum.T
1629
+        ])
1630
+        fig.layout.xaxis.title = keys[0]
1631
+        fig.layout.yaxis.title = 'Energy'
1632
+        fig.layout.showlegend = False
1633
+    else:
1634
+        fig = _p.plotly_graph_objs.Figure(data=[
1635
+          _p.plotly_graph_objs.Surface(
1636
+                 x=array_values[0],
1637
+                 y=array_values[1],
1638
+                 z=energies,
1639
+                 cmax=np.max(spectrum),
1640
+                 cmin=np.min(spectrum),
1641
+          ) for energies in spectrum.T
1642
+        ])
1643
+        fig.layout.scene.xaxis.title = keys[0]
1644
+        fig.layout.scene.yaxis.title = keys[1]
1645
+        fig.layout.scene.zaxis.title = 'Energy'
1646
+
1647
+    fig.layout.title = (
1648
+        ', '.join('{} = {}'.format(*kv) for kv in params.items())
1649
+    )
1650
+
1651
+    _maybe_output_fig(fig, file=file, show=show)
1652
+
1653
+    return fig
1654
+
1655
+
1656
+def _spectrum_matplotlib(syst, x, y=None, params=None, mask=None, file=None,
1657
+                         show=True, dpi=None, fig_size=None, ax=None):
1658
+    """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters
1659
+    using the matplotlib backend.
1660
+
1661
+    Parameters
1662
+    ----------
1663
+    See spectrum(...) documentation.
1664
+
1665
+    Returns
1666
+    -------
1667
+    fig : matplotlib figure
1668
+        A figure with the output if `ax` is not set, else None.
1669
+    """
1670
+
1671
+    if not _p.mpl_available:
1672
+        raise RuntimeError("matplotlib was not found, but is required for"
1673
+                           "using the spectrum() matplotlib backend")
1674
+    if y is not None and not _p.has3d:
1675
+        raise RuntimeError("Installed matplotlib does not support 3d plotting")
1676
+
1677
+    spectrum, planar, array_values, keys = _generate_spectrum(syst, params,
1678
+                                                              mask, x, y)
1679
+
1561 1680
     # set up axes
1562 1681
     if ax is None:
1563 1682
         fig = _make_figure(dpi, fig_size, use_pyplot=(file is None))
1564
-        if y is None:
1683
+        if planar:
1565 1684
             ax = fig.add_subplot(1, 1, 1)
1566 1685
         else:
1567 1686
             warnings.filterwarnings('ignore',
... ...
@@ -1569,7 +1688,7 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None,
1569 1688
             ax = fig.add_subplot(1, 1, 1, projection='3d')
1570 1689
             warnings.resetwarnings()
1571 1690
         ax.set_xlabel(keys[0])
1572
-        if y is None:
1691
+        if planar:
1573 1692
             ax.set_ylabel('Energy')
1574 1693
         else:
1575 1694
             ax.set_ylabel(keys[1])
... ...
@@ -1585,7 +1704,7 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None,
1585 1704
         fig = None
1586 1705
 
1587 1706
     # actually do the plot
1588
-    if y is None:
1707
+    if planar:
1589 1708
         ax.plot(array_values[0], spectrum)
1590 1709
     else:
1591 1710
         if not hasattr(ax, 'plot_surface'):