Browse code

Working heatmap implementation.

Kelvin Loh authored on 21/12/2018 04:36:56
Showing 2 changed files
... ...
@@ -106,7 +106,8 @@ if plotly_available:
106 106
         "x": 4,
107 107
         "*": 17,
108 108
         "d": 2,
109
-        "h": 14
109
+        "h": 14,
110
+        "no symbol": -1
110 111
     }
111 112
 
112 113
     converter_map_3d = {
... ...
@@ -981,7 +981,7 @@ def _plot_plotly(sys, num_lead_cells, unit,
981 981
          lead_site_edgecolor, lead_site_lw,
982 982
          lead_hop_lw, pos_transform,
983 983
          cmap, colorbar, file,
984
-         show):
984
+         show, fig=None):
985 985
 
986 986
     if not _p.plotly_available:
987 987
         raise RuntimeError("plotly was not found, but is required "
... ...
@@ -1111,6 +1111,9 @@ def _plot_plotly(sys, num_lead_cells, unit,
1111 1111
     assert dim == 2 or dim == 3
1112 1112
     site_node_trace, site_edge_trace = [], []
1113 1113
     for symbol, slc in symbol_slcs:
1114
+        site_symbol_plotly = _p.convert_symbol_mpl_plotly(symbol)
1115
+        if site_symbol_plotly == -1:
1116
+            continue
1114 1117
         size = site_size[slc] if _p.isarray(site_size) else site_size
1115 1118
         col = site_color[slc] if _p.isarray(site_color) else site_color
1116 1119
         if _p.isarray(site_edgecolor) or _p.isarray(site_lw):
... ...
@@ -1144,7 +1147,7 @@ def _plot_plotly(sys, num_lead_cells, unit,
1144 1147
                                                                         symbol)
1145 1148
 
1146 1149
         site_node_trace_elem.mode = 'markers'
1147
-        site_node_trace_elem.hoverinfo = 'text'
1150
+        site_node_trace_elem.hoverinfo = 'none'
1148 1151
         site_node_trace_elem.marker.showscale = False
1149 1152
         site_node_trace_elem.marker.colorscale = \
1150 1153
                                           _p.convert_cmap_list_mpl_plotly(cmap)
... ...
@@ -1158,6 +1161,7 @@ def _plot_plotly(sys, num_lead_cells, unit,
1158 1161
 
1159 1162
         site_node_trace_elem.line.width = lw
1160 1163
         site_node_trace_elem.line.color = edgecol
1164
+        site_node_trace_elem.showlegend = False
1161 1165
 
1162 1166
         site_node_trace.append(site_node_trace_elem)
1163 1167
 
... ...
@@ -1189,6 +1193,7 @@ def _plot_plotly(sys, num_lead_cells, unit,
1189 1193
     site_edge_trace_elem.line.width = hop_lw
1190 1194
     site_edge_trace_elem.line.color = hop_color
1191 1195
     site_edge_trace_elem.hoverinfo = 'none'
1196
+    site_edge_trace_elem.showlegend = False
1192 1197
     site_edge_trace_elem.mode = 'lines'
1193 1198
     site_edge_trace.append(site_edge_trace_elem)
1194 1199
 
... ...
@@ -1218,7 +1223,8 @@ def _plot_plotly(sys, num_lead_cells, unit,
1218 1223
                                  _p.convert_symbol_mpl_plotly(lead_site_symbol)
1219 1224
 
1220 1225
         lead_node_trace_elem.mode = 'markers'
1221
-        lead_node_trace_elem.hoverinfo = 'text'
1226
+        lead_node_trace_elem.hoverinfo = 'none'
1227
+        lead_node_trace_elem.showlegend = False
1222 1228
         lead_node_trace_elem.marker.showscale = False
1223 1229
         lead_node_trace_elem.marker.reversescale = False
1224 1230
         lead_node_trace_elem.marker.color = lead_site_colors
... ...
@@ -1238,7 +1244,8 @@ def _plot_plotly(sys, num_lead_cells, unit,
1238 1244
         lead_node_trace_elem.line.width = lead_site_lw
1239 1245
         lead_node_trace_elem.line.color = lead_site_edgecolor
1240 1246
 
1241
-        lead_node_trace.append(lead_node_trace_elem)
1247
+        if lead_node_trace_elem:
1248
+            lead_node_trace.append(lead_node_trace_elem)
1242 1249
 
1243 1250
         lead_hop_colors = np.array([i[2] for i in hops[hops_slc]], dtype=float)
1244 1251
 
... ...
@@ -1267,6 +1274,7 @@ def _plot_plotly(sys, num_lead_cells, unit,
1267 1274
                                                         lead_color)
1268 1275
         lead_edge_trace_elem.hoverinfo = 'none'
1269 1276
         lead_edge_trace_elem.mode = 'lines'
1277
+        lead_edge_trace_elem.showlegend = False
1270 1278
 
1271 1279
         lead_edge_trace.append(lead_edge_trace_elem)
1272 1280
 
... ...
@@ -1277,11 +1285,17 @@ def _plot_plotly(sys, num_lead_cells, unit,
1277 1285
                                            showticklabels=True),
1278 1286
                                 yaxis=dict(showgrid=False, zeroline=False,
1279 1287
                                            showticklabels=True))
1280
-    full_trace = list(itertools.chain.from_iterable([site_edge_trace,
1281
-                                        site_node_trace, lead_edge_trace,
1282
-                                        lead_node_trace]))
1283
-    fig = _p.plotly_graph_objs.Figure(data=full_trace,
1284
-                                      layout=layout)
1288
+    if fig == None:
1289
+        full_trace = list(itertools.chain.from_iterable([site_edge_trace,
1290
+                                            site_node_trace, lead_edge_trace,
1291
+                                            lead_node_trace]))
1292
+        fig = _p.plotly_graph_objs.Figure(data=full_trace,
1293
+                                          layout=layout)
1294
+    else:
1295
+        full_trace = list(itertools.chain.from_iterable([lead_edge_trace,
1296
+                                            lead_node_trace]))
1297
+        for trace in full_trace:
1298
+            fig.add_trace(trace)
1285 1299
 
1286 1300
     return fig
1287 1301
 
... ...
@@ -1610,7 +1624,14 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3):
1610 1624
     # * 0.4 (which is just below sqrt(2) - 1) makes tree.query() exact.
1611 1625
     mask = tree.query(mask, eps=0.4)[0] > 0.99 * a
1612 1626
 
1613
-    return np.ma.masked_array(img, mask), cmin, cmax
1627
+    masked_result_array = np.ma.masked_array(img, mask)
1628
+
1629
+    if get_backend() != _p.Backends.matplotlib:
1630
+        result_array = masked_result_array.filled(np.NaN)
1631
+    else:
1632
+        result_array = masked_result_array
1633
+
1634
+    return result_array, img, cmin, cmax
1614 1635
 
1615 1636
 
1616 1637
 def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
... ...
@@ -1689,7 +1710,7 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
1689 1710
     kwant.plotter.density
1690 1711
     """
1691 1712
 
1692
-    if not _p.mpl_available:
1713
+    if not (_p.mpl_available or _p.plotly_available):
1693 1714
         raise RuntimeError("matplotlib was not found, but is required "
1694 1715
                            "for map()")
1695 1716
 
... ...
@@ -1711,18 +1732,8 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
1711 1732
                              'for finalized systems.')
1712 1733
     value = np.array(value)
1713 1734
     with _common.reraise_warnings():
1714
-        img, min, max = mask_interpolate(coords, value, a, method, oversampling)
1715
-    border = 0.5 * (max - min) / (np.asarray(img.shape) - 1)
1716
-    min -= border
1717
-    max += border
1718
-    if ax is None:
1719
-        fig = _make_figure(dpi, fig_size, use_pyplot=(file is None))
1720
-        ax = fig.add_subplot(1, 1, 1, aspect='equal')
1721
-    else:
1722
-        fig = None
1723
-
1724
-    if cmap is None:
1725
-        cmap = _p.kwant_red_matplotlib
1735
+        img, unmasked_data, _min, _max = mask_interpolate(coords, value,
1736
+                                                       a, method, oversampling)
1726 1737
 
1727 1738
     # Calculate the min/max bounds for the colormap.
1728 1739
     # User-provided values take precedence.
... ...
@@ -1745,9 +1756,96 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
1745 1756
         warnings.warn(''.join(msg), RuntimeWarning, stacklevel=2)
1746 1757
     vmin, vmax = new_vmin, new_vmax
1747 1758
 
1759
+    if get_backend() == _p.Backends.matplotlib:
1760
+        fig = _map_matplotlib(syst, img, colorbar, _max, _min,  vmin, vmax,
1761
+                        overflow_pct, underflow_pct, cmap, num_lead_cells,
1762
+                        background, dpi, fig_size, ax, file)
1763
+    elif get_backend() == _p.Backends.plotly:
1764
+        fig = _map_plotly(syst, img, colorbar, _max, _min,  vmin, vmax,
1765
+                          overflow_pct, underflow_pct, cmap, num_lead_cells,
1766
+                          background)
1767
+    else:
1768
+        raise RuntimeError('Backend not supported by map().')
1769
+
1770
+    _maybe_output_fig(fig, file=file, show=show)
1771
+
1772
+    return fig
1773
+
1774
+
1775
+def _map_plotly(syst, img, colorbar, _max, _min,  vmin, vmax, overflow_pct,
1776
+                underflow_pct, cmap, num_lead_cells, background):
1777
+
1778
+    border = 0.5 * (_max - _min) / (np.asarray(img.shape) - 1)
1779
+    _min -= border
1780
+    _max += border
1781
+
1782
+    if cmap is None:
1783
+        cmap = _p._colormaps.kwant_red
1784
+
1785
+    # Note that we tell imshow to show the array created by mask_interpolate
1786
+    # faithfully and not to interpolate by itself another time.
1787
+    # image = ax.imshow(img.T, extent=(_min[0], _max[0], _min[1], _max[1]),
1788
+    #                   origin='lower', interpolation='none', cmap=cmap,
1789
+    #                   vmin=vmin, vmax=vmax)
1790
+    if not _p.plotly_available:
1791
+        raise RuntimeError("plotly was not found, but is required "
1792
+                           "for _map_plotly()")
1793
+
1794
+    if cmap is None:
1795
+        cmap = _p.kwant_red_plotly
1796
+
1797
+    img = img.T
1798
+    contour_object = _p.plotly_graph_objs.Heatmap()
1799
+    contour_object.z = img
1800
+    contour_object.x = np.linspace(_min[0],_max[0],img.shape[0])
1801
+    contour_object.y = np.linspace(_min[1],_max[1],img.shape[1])
1802
+    contour_object.zsmooth = False
1803
+    contour_object.connectgaps = False
1804
+    contour_object.colorscale = _p.convert_cmap_list_mpl_plotly(cmap)
1805
+    contour_object.zmax = vmax
1806
+    contour_object.zmin = vmin
1807
+    contour_object.showlegend = False
1808
+    contour_object.hoverinfo = 'none'
1809
+
1810
+    contour_object.showscale = colorbar
1811
+
1812
+    fig = _p.plotly_graph_objs.Figure(data=[contour_object])
1813
+    fig.layout.plot_bgcolor = background
1814
+    # fig.layout.width = 720
1815
+    # fig.layout.height = fig.layout.width
1816
+
1817
+    if num_lead_cells:
1818
+        fig = _plot_plotly(syst, num_lead_cells, site_symbol='no symbol',
1819
+                           hop_lw=0, lead_site_symbol='s',
1820
+                           lead_site_size=0.501, lead_site_lw=0,lead_hop_lw=0,
1821
+                           lead_color='black', colorbar=False, show=False,
1822
+                           fig=fig, unit='pt', site_size=None, site_color=None,
1823
+                           site_edgecolor=None, site_lw=0, hop_color=None,
1824
+                           lead_site_edgecolor=None,pos_transform=None,
1825
+                           cmap=None, file=None)
1826
+
1827
+    return fig
1828
+
1829
+
1830
+def _map_matplotlib(syst, img, colorbar, _max, _min,  vmin, vmax,
1831
+                    overflow_pct, underflow_pct, cmap, num_lead_cells,
1832
+                    background, dpi, fig_size, ax, file):
1833
+
1834
+    border = 0.5 * (_max - _min) / (np.asarray(img.shape) - 1)
1835
+    _min -= border
1836
+    _max += border
1837
+    if ax is None:
1838
+        fig = _make_figure(dpi, fig_size, use_pyplot=(file is None))
1839
+        ax = fig.add_subplot(1, 1, 1, aspect='equal')
1840
+    else:
1841
+        fig = None
1842
+
1843
+    if cmap is None:
1844
+        cmap = _p._colormaps.kwant_red
1845
+
1748 1846
     # Note that we tell imshow to show the array created by mask_interpolate
1749 1847
     # faithfully and not to interpolate by itself another time.
1750
-    image = ax.imshow(img.T, extent=(min[0], max[0], min[1], max[1]),
1848
+    image = ax.imshow(img.T, extent=(_min[0], _max[0], _min[1], _max[1]),
1751 1849
                       origin='lower', interpolation='none', cmap=cmap,
1752 1850
                       vmin=vmin, vmax=vmax)
1753 1851
     if num_lead_cells:
... ...
@@ -1768,8 +1866,6 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
1768 1866
             extend = 'max'
1769 1867
         fig.colorbar(image, extend=extend)
1770 1868
 
1771
-    _maybe_output_fig(fig, file=file, show=show)
1772
-
1773 1869
     return fig
1774 1870
 
1775 1871