Browse code

Add 3d implementation for plotly backend of plot().

Kelvin Loh authored on 20/12/2018 04:51:13
Showing 2 changed files
... ...
@@ -109,6 +109,14 @@ if plotly_available:
109 109
         "h": 14
110 110
     }
111 111
 
112
+    converter_map_3d = {
113
+        "o": "circle",
114
+        "s": "square",
115
+        "+": "cross",
116
+        "x": "x",
117
+        "d": "diamond",
118
+    }
119
+
112 120
 
113 121
     def convert_symbol_mpl_plotly(mpl_symbol):
114 122
         if isarray(mpl_symbol):
... ...
@@ -122,6 +130,23 @@ if plotly_available:
122 130
                                 mpl_symbol, converter_map.keys()))
123 131
         return converted_symbol
124 132
 
133
+    def convert_symbol_mpl_plotly_3d(mpl_symbol):
134
+        if isarray(mpl_symbol) or isinstance(mpl_symbol, tuple):
135
+            converted_symbol = [converter_map_3d.get(i) for i in mpl_symbol]
136
+            if None in converted_symbol:
137
+                raise RuntimeError('Input tuple \'{}\' not supported. '
138
+                                'Only the following characters '
139
+                                'are supported: {}'.format(
140
+                                    mpl_symbol, converter_map_3d.keys()))
141
+        else:
142
+            converted_symbol = converter_map_3d.get(mpl_symbol)
143
+            if converted_symbol == None:
144
+                raise RuntimeError('Input symbol \'{}\' not supported. '
145
+                                'Only the following are supported: {}'.format(
146
+                                    mpl_symbol, converter_map_3d.keys()))
147
+
148
+        return converted_symbol
149
+
125 150
 
126 151
     def convert_site_size_mpl_plotly(mpl_site_size, plotly_ref_px):
127 152
         return np.sqrt(mpl_site_size)*(96.0/72.0)*plotly_ref_px
... ...
@@ -136,7 +161,7 @@ if plotly_available:
136 161
     def convert_cmap_list_mpl_plotly(mpl_cmap_name, N=255):
137 162
         cmap_mpl = matplotlib.cm.get_cmap(mpl_cmap_name)
138 163
         cmap_mpl_arr = matplotlib.colors.makeMappingArray(N, cmap_mpl)
139
-        level = np.linspace(1, 0, N)
164
+        level = np.linspace(0, 1, N)
140 165
         cmap_plotly_linear = [(level, convert_colormap_mpl_plotly(cmap_mpl))
141 166
                                 for level, cmap_mpl in zip(level,
142 167
                                                             cmap_mpl_arr)]
... ...
@@ -153,13 +178,12 @@ if plotly_available:
153 178
                                mpl_lead_cmap_end[2], N) * 255
154 179
         a_levels = np.linspace(mpl_lead_cmap_init[3],
155 180
                                mpl_lead_cmap_end[3], N)
156
-        level = np.linspace(1, 0, N)
181
+        level = np.linspace(0, 1, N)
157 182
         cmap_plotly_linear = [(level, 'rgba({},{},{},{})'.format(*rgba))
158
-                                    for level, rgba in zip(level, zip(r_levels,
159
-                                                                      g_levels,
160
-                                                                      b_levels,
161
-                                                                      a_levels
162
-                                                                      ))]
183
+                                for level, rgba in zip(level,
184
+                                                        zip(r_levels, g_levels,
185
+                                                            b_levels, a_levels
186
+                                                            ))]
163 187
         return cmap_plotly_linear
164 188
 
165 189
 
... ...
@@ -987,7 +987,6 @@ def _plot_plotly(sys, num_lead_cells, unit,
987 987
         raise RuntimeError("plotly was not found, but is required "
988 988
                            "for plot()")
989 989
 
990
-    print('In _plot_plotly')
991 990
     syst = sys  # for naming consistency inside function bodies
992 991
     # Generate data.
993 992
     sites, lead_sites_slcs = sys_leads_sites(syst, num_lead_cells)
... ...
@@ -1020,7 +1019,6 @@ def _plot_plotly(sys, num_lead_cells, unit,
1020 1019
         raise RuntimeError('Plotly backend currently only supports '
1021 1020
                          'the pt symbol size unit')
1022 1021
 
1023
-
1024 1022
     site_symbol = _make_proper_site_spec('site_symbol', site_symbol, sites)
1025 1023
     if site_symbol is None: site_symbol = defaults['site_symbol'][dim]
1026 1024
     # separate different symbols (not done in 3D, the separation
... ...
@@ -1073,15 +1071,10 @@ def _plot_plotly(sys, num_lead_cells, unit,
1073 1071
     if hop_color is None: hop_color = defaults['hop_color'][dim]
1074 1072
     if hop_lw is None: hop_lw = defaults['hop_lw'][dim]
1075 1073
 
1076
-    # if symbols are split up into different collections,
1077
-    # the colormapping will fail without normalization
1078
-    norm = None
1079 1074
     if len(symbol_slcs) > 1:
1080 1075
         try:
1081 1076
             if site_color.ndim == 1 and len(site_color) == n_syst_sites:
1082 1077
                 site_color = np.asarray(site_color, dtype=float)
1083
-                norm = _p.matplotlib.colors.Normalize(site_color.min(),
1084
-                                                      site_color.max())
1085 1078
         except:
1086 1079
             pass
1087 1080
 
... ...
@@ -1112,7 +1105,10 @@ def _plot_plotly(sys, num_lead_cells, unit,
1112 1105
             cmap, hop_cmap = cmap
1113 1106
         except TypeError:
1114 1107
             pass
1115
-    # plot system sites and hoppings
1108
+    # Plot system sites and hoppings
1109
+
1110
+    # First plot the nodes (sites) of the graph
1111
+    assert dim == 2 or dim == 3
1116 1112
     site_node_trace, site_edge_trace = [], []
1117 1113
     for symbol, slc in symbol_slcs:
1118 1114
         size = site_size[slc] if _p.isarray(site_size) else site_size
... ...
@@ -1121,109 +1117,144 @@ def _plot_plotly(sys, num_lead_cells, unit,
1121 1117
                    site_edgecolor)
1122 1118
         lw = site_lw[slc] if _p.isarray(site_lw) else site_lw
1123 1119
 
1124
-        site_symbol_plotly = _p.convert_symbol_mpl_plotly(symbol)
1125
-        site_node_trace_elem = _p.plotly_graph_objs.Scatter(
1126
-                        x=[],
1127
-                        y=[],
1128
-                        text=[],
1129
-                        mode='markers',
1130
-                        hoverinfo='text',
1131
-                        marker=dict(
1132
-                            showscale=False,
1133
-                            colorscale=_p.convert_cmap_list_mpl_plotly(cmap),
1134
-                            reversescale=True,
1135
-                            color=col,
1136
-                            size=_p.convert_site_size_mpl_plotly(size,
1137
-                                       defaults['plotly_site_size_reference']),
1138
-                            symbol=site_symbol_plotly,
1139
-                            line=dict(width=lw,
1140
-                                      color=edgecol)
1141
-                            ))
1142
-
1143
-
1144
-        for i in range(len(sites_pos[slc])):
1145
-            x, y = sites_pos[slc][i]
1146
-            site_node_trace_elem['x'] += tuple([x])
1147
-            site_node_trace_elem['y'] += tuple([y])
1120
+        if dim == 3:
1121
+            site_node_trace_elem = _p.plotly_graph_objs.Scatter3d(x=[], y=[],
1122
+                                                                  z=[])
1123
+            for i in range(len(sites_pos[slc])):
1124
+                x, y, z = sites_pos[slc][i]
1125
+                site_node_trace_elem.x += tuple([x])
1126
+                site_node_trace_elem.y += tuple([y])
1127
+                site_node_trace_elem.z += tuple([z])
1128
+            site_node_trace_elem.marker.symbol = _p.convert_symbol_mpl_plotly_3d(
1129
+                                                                        symbol)
1130
+        else:
1131
+            site_node_trace_elem = _p.plotly_graph_objs.Scatter(x=[], y=[])
1132
+            for i in range(len(sites_pos[slc])):
1133
+                x, y = sites_pos[slc][i]
1134
+                site_node_trace_elem.x += tuple([x])
1135
+                site_node_trace_elem.y += tuple([y])
1136
+            site_node_trace_elem.marker.symbol = _p.convert_symbol_mpl_plotly(
1137
+                                                                        symbol)
1138
+
1139
+        site_node_trace_elem.mode = 'markers'
1140
+        site_node_trace_elem.hoverinfo = 'text'
1141
+        site_node_trace_elem.marker.showscale = False
1142
+        site_node_trace_elem.marker.colorscale = \
1143
+                                          _p.convert_cmap_list_mpl_plotly(cmap)
1144
+        site_node_trace_elem.marker.reversescale = False
1145
+        site_node_trace_elem.marker.color = col
1146
+        site_node_trace_elem.marker.size = \
1147
+                                _p.convert_site_size_mpl_plotly(size,
1148
+                                        defaults['plotly_site_size_reference'])
1149
+
1150
+        site_node_trace_elem.line.width = lw
1151
+        site_node_trace_elem.line.color = edgecol
1148 1152
 
1149 1153
         site_node_trace.append(site_node_trace_elem)
1150 1154
 
1155
+    # Now plot the edges (hops) of the graph
1151 1156
     end, start = end_pos[: n_syst_hops], start_pos[: n_syst_hops]
1152
-    dim = end.shape[1]
1153
-    assert dim == 2 or dim == 3
1154
-    if dim == 2:
1155
-        site_edge_trace_elem = _p.plotly_graph_objs.Scatter(
1156
-                                x=[],
1157
-                                y=[],
1158
-                                line=dict(width=hop_lw,color=hop_color),
1159
-                                hoverinfo='none',
1160
-                                mode='lines')
1157
+
1158
+    if dim == 3:
1159
+        site_edge_trace_elem = _p.plotly_graph_objs.Scatter3d(x=[], y=[], z=[])
1160
+        for i in range(len(end)):
1161
+            x0, y0, z0 = end[i]
1162
+            x1, y1, z1 = start[i]
1163
+            site_edge_trace_elem.x += tuple([x0, x1, None])
1164
+            site_edge_trace_elem.y += tuple([y0, y1, None])
1165
+            site_edge_trace_elem.z += tuple([z0, z1, None])
1166
+    else:
1167
+        site_edge_trace_elem = _p.plotly_graph_objs.Scatter(x=[], y=[])
1161 1168
         for i in range(len(end)):
1162 1169
             x0, y0 = end[i]
1163 1170
             x1, y1 = start[i]
1164
-            site_edge_trace_elem['x'] += tuple([x0, x1, None])
1165
-            site_edge_trace_elem['y'] += tuple([y0, y1, None])
1166
-        site_edge_trace.append(site_edge_trace_elem)
1167
-    else:
1168
-        raise RuntimeError('dim=3 is unsupported yet in plotly backend')
1171
+            site_edge_trace_elem.x += tuple([x0, x1, None])
1172
+            site_edge_trace_elem.y += tuple([y0, y1, None])
1169 1173
 
1170
-    # Make conversion of colormap
1171 1174
 
1172
-    lead_site_symbol_plotly = _p.convert_symbol_mpl_plotly(lead_site_symbol)
1175
+    site_edge_trace_elem.line.width = hop_lw
1176
+    site_edge_trace_elem.line.color = hop_color
1177
+    site_edge_trace_elem.hoverinfo = 'none'
1178
+    site_edge_trace_elem.mode = 'lines'
1179
+    site_edge_trace.append(site_edge_trace_elem)
1180
+
1181
+    # Plot lead sites and edges
1173 1182
 
1174 1183
     lead_node_trace, lead_edge_trace = [], []
1175 1184
     for sites_slc, hops_slc in zip(lead_sites_slcs, lead_hops_slcs):
1176 1185
         lead_site_colors = np.array([i[2] for i in sites[sites_slc]],
1177 1186
                                     dtype=float)
1178
-        lead_node_trace_elem = _p.plotly_graph_objs.Scatter(
1179
-                        x=[],
1180
-                        y=[],
1181
-                        text=[],
1182
-                        mode='markers',
1183
-                        hoverinfo='text',
1184
-                        marker=dict(
1185
-                            showscale=False,
1186
-                            reversescale=True,
1187
-                            color=lead_site_colors,
1188
-                            colorscale=_p.convert_lead_cmap_mpl_plotly(
1189
-                                            lead_color, [1,1,1,lead_color[3]]),
1190
-                            size=_p.convert_site_size_mpl_plotly(
1191
-                                   lead_site_size,
1192
-                                   defaults['plotly_site_size_reference']),
1193
-                            symbol=lead_site_symbol_plotly,
1194
-                            line=dict(width=lead_site_lw,
1195
-                                      color=lead_site_edgecolor)
1196
-                            ))
1197
-        for i in range(len(sites_pos[sites_slc])):
1198
-            x, y = sites_pos[sites_slc][i]
1199
-            lead_node_trace_elem['x'] += tuple([x])
1200
-            lead_node_trace_elem['y'] += tuple([y])
1187
+        if dim == 3:
1188
+            lead_node_trace_elem = _p.plotly_graph_objs.Scatter3d(x=[], y=[],
1189
+                                                                  z=[])
1190
+            for i in range(len(sites_pos[sites_slc])):
1191
+                x, y, z = sites_pos[sites_slc][i]
1192
+                lead_node_trace_elem.x += tuple([x])
1193
+                lead_node_trace_elem.y += tuple([y])
1194
+                lead_node_trace_elem.z += tuple([z])
1195
+            lead_node_trace_elem.marker.symbol = \
1196
+                              _p.convert_symbol_mpl_plotly_3d(lead_site_symbol)
1197
+        else:
1198
+            lead_node_trace_elem = _p.plotly_graph_objs.Scatter(x=[], y=[])
1199
+            for i in range(len(sites_pos[sites_slc])):
1200
+                x, y = sites_pos[sites_slc][i]
1201
+                lead_node_trace_elem.x += tuple([x])
1202
+                lead_node_trace_elem.y += tuple([y])
1203
+            lead_node_trace_elem.marker.symbol = \
1204
+                                 _p.convert_symbol_mpl_plotly(lead_site_symbol)
1205
+
1206
+        lead_node_trace_elem.mode = 'markers'
1207
+        lead_node_trace_elem.hoverinfo = 'text'
1208
+        lead_node_trace_elem.marker.showscale = False
1209
+        lead_node_trace_elem.marker.reversescale = False
1210
+        lead_node_trace_elem.marker.color = lead_site_colors
1211
+        lead_node_trace_elem.marker.colorscale = \
1212
+                                    _p.convert_lead_cmap_mpl_plotly(lead_color,
1213
+                                                      [1, 1, 1, lead_color[3]])
1214
+        lead_node_trace_elem.marker.size = _p.convert_site_size_mpl_plotly(
1215
+                                        lead_site_size,
1216
+                                        defaults['plotly_site_size_reference'])
1217
+
1218
+        if _p.isarray(lead_site_lw) or _p.isarray(lead_site_edgecolor):
1219
+            raise RuntimeError("Plotly backend not currently support an array "
1220
+                               "of linecolors or linewidths. Please restrict "
1221
+                               "to only a constant (i.e. no function or array) "
1222
+                               "lead_site_lw and lead_site_edgecolor property "
1223
+                               "for the entire plot.")
1224
+        lead_node_trace_elem.line.width = lead_site_lw
1225
+        lead_node_trace_elem.line.color = lead_site_edgecolor
1226
+
1201 1227
         lead_node_trace.append(lead_node_trace_elem)
1228
+
1202 1229
         lead_hop_colors = np.array([i[2] for i in hops[hops_slc]], dtype=float)
1203
-        # Note: the previous version of the code had in addition this
1204
-        # line in the 3D case:
1205
-        # lead_hop_colors = 1 / np.sqrt(1. + lead_hop_colors)
1206
-        # Uses lead_cmap for the colormap
1207
-        # 1) Make each line a scatter object. Takes a lot of memory but should work
1208
-        # 2) Get the color from the previous object
1230
+
1209 1231
         end, start = end_pos[hops_slc], start_pos[hops_slc]
1210
-        if dim == 2:
1211
-            lead_edge_trace_elem = _p.plotly_graph_objs.Scatter(
1212
-                                    x=[],
1213
-                                    y=[],
1214
-                                    line=dict(width=lead_hop_lw,
1215
-                                              color='red'),
1216
-                                    hoverinfo='none',
1217
-                                    mode='lines')
1232
+
1233
+        if dim == 3:
1234
+            lead_edge_trace_elem = _p.plotly_graph_objs.Scatter3d(x=[], y=[],
1235
+                                                                  z=[])
1236
+            for i in range(len(end)):
1237
+                x0, y0, z0 = end[i]
1238
+                x1, y1, z1 = start[i]
1239
+                lead_edge_trace_elem.x += tuple([x0, x1, None])
1240
+                lead_edge_trace_elem.y += tuple([y0, y1, None])
1241
+                lead_edge_trace_elem.z += tuple([z0, z1, None])
1242
+
1243
+        else:
1244
+            lead_edge_trace_elem = _p.plotly_graph_objs.Scatter(x=[], y=[])
1218 1245
             for i in range(len(end)):
1219 1246
                 x0, y0 = end[i]
1220 1247
                 x1, y1 = start[i]
1221
-                lead_edge_trace_elem['x'] += tuple([x0, x1, None])
1222
-                lead_edge_trace_elem['y'] += tuple([y0, y1, None])
1248
+                lead_edge_trace_elem.x += tuple([x0, x1, None])
1249
+                lead_edge_trace_elem.y += tuple([y0, y1, None])
1223 1250
 
1224
-            lead_edge_trace.append(lead_edge_trace_elem)
1225
-        else:
1226
-            raise RuntimeError('dim=3 is unsupported yet in plotly backend')
1251
+        lead_edge_trace_elem.line.width = lead_hop_lw
1252
+        lead_edge_trace_elem.line.color = _p.convert_colormap_mpl_plotly(
1253
+                                                        lead_color)
1254
+        lead_edge_trace_elem.hoverinfo = 'none'
1255
+        lead_edge_trace_elem.mode = 'lines'
1256
+
1257
+        lead_edge_trace.append(lead_edge_trace_elem)
1227 1258
 
1228 1259
     layout = _p.plotly_graph_objs.Layout(
1229 1260
                                 showlegend=False,
... ...
@@ -1255,7 +1286,6 @@ def _plot_matplotlib(sys, num_lead_cells, unit,
1255 1286
         raise RuntimeError("matplotlib was not found, but is required "
1256 1287
                            "for plot()")
1257 1288
 
1258
-    print('In _plot_matplotlib')
1259 1289
     syst = sys  # for naming consistency inside function bodies
1260 1290
     # Generate data.
1261 1291
     sites, lead_sites_slcs = sys_leads_sites(syst, num_lead_cells)