... | ... |
@@ -174,12 +174,17 @@ if plotly_available: |
174 | 174 |
|
175 | 175 |
|
176 | 176 |
def convert_cmap_list_mpl_plotly(mpl_cmap_name, N=255): |
177 |
- cmap_mpl = matplotlib.cm.get_cmap(mpl_cmap_name) |
|
178 |
- cmap_mpl_arr = matplotlib.colors.makeMappingArray(N, cmap_mpl) |
|
179 |
- level = np.linspace(0, 1, N) |
|
180 |
- cmap_plotly_linear = [(level, convert_colormap_mpl_plotly(cmap_mpl)) |
|
181 |
- for level, cmap_mpl in zip(level, |
|
182 |
- cmap_mpl_arr)] |
|
177 |
+ if isinstance(mpl_cmap_name, str): |
|
178 |
+ cmap_mpl = matplotlib.cm.get_cmap(mpl_cmap_name) |
|
179 |
+ cmap_mpl_arr = matplotlib.colors.makeMappingArray(N, cmap_mpl) |
|
180 |
+ level = np.linspace(0, 1, N) |
|
181 |
+ cmap_plotly_linear = [(level, convert_colormap_mpl_plotly(cmap_mpl)) |
|
182 |
+ for level, cmap_mpl in zip(level, |
|
183 |
+ cmap_mpl_arr)] |
|
184 |
+ else: |
|
185 |
+ assert(isinstance(mpl_cmap_name, list)) |
|
186 |
+ # Do not do any conversion if it's already a list |
|
187 |
+ cmap_plotly_linear = mpl_cmap_name |
|
183 | 188 |
return cmap_plotly_linear |
184 | 189 |
|
185 | 190 |
|
... | ... |
@@ -1612,6 +1612,7 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3): |
1612 | 1612 |
range(len(cmin))) |
1613 | 1613 |
grid = tuple(np.ogrid[dims]) |
1614 | 1614 |
img = interpolate.griddata(coords, values, grid, method) |
1615 |
+ img = img.astype(np.float_) |
|
1615 | 1616 |
mask = np.mgrid[dims].reshape(len(cmin), -1).T |
1616 | 1617 |
# The numerical values in the following line are optimized for the common |
1617 | 1618 |
# case of a square lattice: |
... | ... |
@@ -1793,7 +1794,8 @@ def _map_plotly(syst, img, colorbar, _max, _min, vmin, vmax, overflow_pct, |
1793 | 1794 |
contour_object.y = np.linspace(_min[1],_max[1],img.shape[1]) |
1794 | 1795 |
contour_object.zsmooth = False |
1795 | 1796 |
contour_object.connectgaps = False |
1796 |
- contour_object.colorscale = _p.convert_cmap_list_mpl_plotly(cmap) |
|
1797 |
+ cmap = _p.convert_cmap_list_mpl_plotly(cmap) |
|
1798 |
+ contour_object.colorscale = cmap |
|
1797 | 1799 |
contour_object.zmax = vmax |
1798 | 1800 |
contour_object.zmin = vmin |
1799 | 1801 |
contour_object.hoverinfo = 'none' |
... | ... |
@@ -116,18 +116,31 @@ def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0): |
116 | 116 |
return syst |
117 | 117 |
|
118 | 118 |
|
119 |
+def plotter_file_suffix(engine): |
|
120 |
+ # We need this function so that we can add a .html suffix to the output filename. |
|
121 |
+ # This is required because plotly will throw an error if filename is without the suffix. |
|
122 |
+ if engine == "plotly": |
|
123 |
+ return ".html" |
|
124 |
+ else: |
|
125 |
+ return None |
|
126 |
+ |
|
127 |
+ |
|
119 | 128 |
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") |
120 |
-def test_plot(): |
|
129 |
+def test_matplotlib_plot(): |
|
130 |
+ |
|
131 |
+ plotter.set_engine('matplotlib') |
|
121 | 132 |
plot = plotter.plot |
122 | 133 |
syst2d = syst_2d() |
123 | 134 |
syst3d = syst_3d() |
124 | 135 |
color_opts = ['k', (lambda site: site.tag[0]), |
125 | 136 |
lambda site: (abs(site.tag[0] / 100), |
126 | 137 |
abs(site.tag[1] / 100), 0)] |
127 |
- with tempfile.TemporaryFile('w+b') as out: |
|
138 |
+ engine = plotter.get_engine() |
|
139 |
+ with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out: |
|
140 |
+ out_filename = out.name |
|
128 | 141 |
for color in color_opts: |
129 | 142 |
for syst in (syst2d, syst3d): |
130 |
- fig = plot(syst, site_color=color, cmap='binary', file=out) |
|
143 |
+ fig = plot(syst, site_color=color, cmap='binary', file=out_filename) |
|
131 | 144 |
if (color != 'k' and |
132 | 145 |
isinstance(color(next(iter(syst2d.sites()))), float)): |
133 | 146 |
assert fig.axes[0].collections[0].get_array() is not None |
... | ... |
@@ -137,30 +150,66 @@ def test_plot(): |
137 | 150 |
abs(site.tag[1] / 100), 0)] |
138 | 151 |
for color in color_opts: |
139 | 152 |
for syst in (syst2d, syst3d): |
140 |
- fig = plot(syst2d, hop_color=color, cmap='binary', file=out, |
|
153 |
+ fig = plot(syst2d, hop_color=color, cmap='binary', file=out_filename, |
|
141 | 154 |
fig_size=(2, 10), dpi=30) |
142 | 155 |
if color != 'k' and isinstance(color(next(iter(syst2d.sites())), |
143 | 156 |
None), float): |
144 | 157 |
assert fig.axes[0].collections[1].get_array() is not None |
145 | 158 |
|
146 |
- assert isinstance(plot(syst3d, file=out).axes[0], mplot3d.axes3d.Axes3D) |
|
159 |
+ assert isinstance(plot(syst3d, file=out_filename).axes[0], mplot3d.axes3d.Axes3D) |
|
160 |
+ |
|
161 |
+ syst2d.leads = [] |
|
162 |
+ plot(syst2d, file=out_filename) |
|
163 |
+ del syst2d[list(syst2d.hoppings())] |
|
164 |
+ plot(syst2d, file=out_filename) |
|
165 |
+ |
|
166 |
+ plot(syst3d, file=out_filename) |
|
167 |
+ with warnings.catch_warnings(): |
|
168 |
+ warnings.simplefilter("ignore") |
|
169 |
+ plot(syst2d.finalized(), file=out_filename) |
|
170 |
+ |
|
171 |
+ # test 2D projections of 3D systems |
|
172 |
+ plot(syst3d, file=out_filename, pos_transform=lambda pos: pos[:2]) |
|
173 |
+ |
|
174 |
+ |
|
175 |
+@pytest.mark.skipif(not _plotter.plotly_available, reason="Plotly unavailable.") |
|
176 |
+def test_plotly_plot(): |
|
177 |
+ |
|
178 |
+ plotter.set_engine('plotly') |
|
179 |
+ plot = plotter.plot |
|
180 |
+ syst2d = syst_2d() |
|
181 |
+ syst3d = syst_3d() |
|
182 |
+ color_opts = ['black', (lambda site: site.tag[0]), |
|
183 |
+ lambda site: (abs(site.tag[0] / 100), |
|
184 |
+ abs(site.tag[1] / 100), 0)] |
|
185 |
+ engine = plotter.get_engine() |
|
186 |
+ with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out: |
|
187 |
+ out_filename = out.name |
|
188 |
+ for color in color_opts: |
|
189 |
+ for syst in (syst2d, syst3d): |
|
190 |
+ plot(syst, site_color=color, cmap='binary', file=out_filename, show=False) |
|
191 |
+ |
|
192 |
+ color_opts = ['black', (lambda site, site2: site.tag[0]), |
|
193 |
+ lambda site, site2: (abs(site.tag[0] / 100), |
|
194 |
+ abs(site.tag[1] / 100), 0)] |
|
147 | 195 |
|
148 | 196 |
syst2d.leads = [] |
149 |
- plot(syst2d, file=out) |
|
197 |
+ plot(syst2d, file=out_filename, show=False) |
|
150 | 198 |
del syst2d[list(syst2d.hoppings())] |
151 |
- plot(syst2d, file=out) |
|
199 |
+ plot(syst2d, file=out_filename, show=False) |
|
152 | 200 |
|
153 |
- plot(syst3d, file=out) |
|
201 |
+ plot(syst3d, file=out_filename, show=False) |
|
154 | 202 |
with warnings.catch_warnings(): |
155 | 203 |
warnings.simplefilter("ignore") |
156 |
- plot(syst2d.finalized(), file=out) |
|
204 |
+ plot(syst2d.finalized(), file=out_filename, show=False) |
|
157 | 205 |
|
158 | 206 |
# test 2D projections of 3D systems |
159 |
- plot(syst3d, file=out, pos_transform=lambda pos: pos[:2]) |
|
207 |
+ plot(syst3d, file=out_filename, pos_transform=lambda pos: pos[:2], show=False) |
|
160 | 208 |
|
161 | 209 |
|
162 | 210 |
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") |
163 |
-def test_plot_more_site_families_than_colors(): |
|
211 |
+@pytest.mark.parametrize("engine", ["plotly", "matplotlib"]) |
|
212 |
+def test_plot_more_site_families_than_colors(engine): |
|
164 | 213 |
# test against regression reported in |
165 | 214 |
# https://gitlab.kwant-project.org/kwant/kwant/issues/257 |
166 | 215 |
ncolors = len(pyplot.rcParams['axes.prop_cycle']) |
... | ... |
@@ -169,17 +218,23 @@ def test_plot_more_site_families_than_colors(): |
169 | 218 |
for i in range(ncolors + 1)] |
170 | 219 |
for i, lat in enumerate(lattices): |
171 | 220 |
syst[lat(i, 0)] = None |
172 |
- with tempfile.TemporaryFile('w+b') as out: |
|
173 |
- plotter.plot(syst, file=out) |
|
221 |
+ |
|
222 |
+ plotter.set_engine(engine) |
|
223 |
+ with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out: |
|
224 |
+ out_filename = out.name |
|
225 |
+ print(out) |
|
226 |
+ plotter.plot(syst, file=out_filename, show=False) |
|
174 | 227 |
|
175 | 228 |
|
176 | 229 |
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") |
177 |
-def test_plot_raises_on_bad_site_spec(): |
|
230 |
+@pytest.mark.parametrize("engine", ["plotly", "matplotlib"]) |
|
231 |
+def test_plot_raises_on_bad_site_spec(engine): |
|
178 | 232 |
syst = kwant.Builder() |
179 | 233 |
lat = kwant.lattice.square(norbs=1) |
180 | 234 |
syst[(lat(i, j) for i in range(5) for j in range(5))] = None |
181 | 235 |
|
182 | 236 |
# Cannot provide site_size as an array when syst is a Builder |
237 |
+ plotter.set_engine(engine) |
|
183 | 238 |
with pytest.raises(TypeError): |
184 | 239 |
plotter.plot(syst, site_size=[1] * 25) |
185 | 240 |
|
... | ... |
@@ -197,20 +252,24 @@ def bad_transform(pos): |
197 | 252 |
return x, y, 0 |
198 | 253 |
|
199 | 254 |
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") |
200 |
-def test_map(): |
|
255 |
+@pytest.mark.parametrize("engine", ["plotly", "matplotlib"]) |
|
256 |
+def test_map(engine): |
|
257 |
+ plotter.set_engine(engine) |
|
201 | 258 |
syst = syst_2d() |
202 |
- with tempfile.TemporaryFile('w+b') as out: |
|
259 |
+ |
|
260 |
+ with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out: |
|
261 |
+ out_filename = out.name |
|
203 | 262 |
plotter.map(syst, lambda site: site.tag[0], pos_transform=good_transform, |
204 |
- file=out, method='linear', a=4, oversampling=4, cmap='flag') |
|
263 |
+ file=out_filename, method='linear', a=4, oversampling=4, cmap='flag', show=False) |
|
205 | 264 |
pytest.raises(ValueError, plotter.map, syst, |
206 | 265 |
lambda site: site.tag[0], |
207 |
- pos_transform=bad_transform, file=out) |
|
266 |
+ pos_transform=bad_transform, file=out_filename) |
|
208 | 267 |
with warnings.catch_warnings(): |
209 | 268 |
warnings.simplefilter("ignore") |
210 | 269 |
plotter.map(syst.finalized(), range(len(syst.sites())), |
211 |
- file=out) |
|
270 |
+ file=out_filename, show=False) |
|
212 | 271 |
pytest.raises(ValueError, plotter.map, syst, |
213 |
- range(len(syst.sites())), file=out) |
|
272 |
+ range(len(syst.sites())), file=out_filename) |
|
214 | 273 |
|
215 | 274 |
|
216 | 275 |
def test_mask_interpolate(): |
... | ... |
@@ -233,22 +292,32 @@ def test_mask_interpolate(): |
233 | 292 |
|
234 | 293 |
|
235 | 294 |
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") |
236 |
-def test_bands(): |
|
295 |
+@pytest.mark.parametrize("engine", ["plotly", "matplotlib"]) |
|
296 |
+def test_bands(engine): |
|
297 |
+ |
|
298 |
+ plotter.set_engine(engine) |
|
237 | 299 |
|
238 | 300 |
syst = syst_2d().finalized().leads[0] |
239 | 301 |
|
240 |
- with tempfile.TemporaryFile('w+b') as out: |
|
241 |
- plotter.bands(syst, file=out) |
|
242 |
- plotter.bands(syst, fig_size=(10, 10), file=out) |
|
243 |
- plotter.bands(syst, momenta=np.linspace(0, 2 * np.pi), file=out) |
|
302 |
+ with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out: |
|
303 |
+ out_filename = out.name |
|
304 |
+ plotter.bands(syst, show=False, file=out_filename) |
|
305 |
+ plotter.bands(syst, show=False, momenta=np.linspace(0, 2 * np.pi), file=out_filename) |
|
306 |
+ |
|
307 |
+ if engine == 'matplotlib': |
|
308 |
+ plotter.bands(syst, show=False, fig_size=(10, 10), file=out_filename) |
|
309 |
+ |
|
310 |
+ fig = pyplot.Figure() |
|
311 |
+ ax = fig.add_subplot(1, 1, 1) |
|
312 |
+ plotter.bands(syst, show=False, ax=ax, file=out_filename) |
|
244 | 313 |
|
245 |
- fig = pyplot.Figure() |
|
246 |
- ax = fig.add_subplot(1, 1, 1) |
|
247 |
- plotter.bands(syst, ax=ax, file=out) |
|
248 | 314 |
|
249 | 315 |
|
250 | 316 |
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") |
251 |
-def test_spectrum(): |
|
317 |
+@pytest.mark.parametrize("engine", ["plotly", "matplotlib"]) |
|
318 |
+def test_spectrum(engine): |
|
319 |
+ |
|
320 |
+ plotter.set_engine(engine) |
|
252 | 321 |
|
253 | 322 |
def ham_1d(a, b, c): |
254 | 323 |
return a**2 + b**2 + c**2 |
... | ... |
@@ -264,38 +333,43 @@ def test_spectrum(): |
264 | 333 |
|
265 | 334 |
vals = np.linspace(0, 1, 3) |
266 | 335 |
|
267 |
- with tempfile.TemporaryFile('w+b') as out: |
|
336 |
+ with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out: |
|
337 |
+ out_filename = out.name |
|
268 | 338 |
|
269 | 339 |
for ham in (ham_1d, ham_2d, fsyst): |
270 |
- plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), file=out) |
|
271 |
- # test with explicit figsize |
|
272 |
- plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), |
|
273 |
- fig_size=(10, 10), file=out) |
|
340 |
+ plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), file=out_filename, show=False) |
|
341 |
+ if engine == 'matplotlib': |
|
342 |
+ # test with explicit figsize |
|
343 |
+ plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), |
|
344 |
+ fig_size=(10, 10), file=out_filename, show=False) |
|
274 | 345 |
|
275 | 346 |
for ham in (ham_1d, ham_2d, fsyst): |
276 | 347 |
plotter.spectrum(ham, ('a', vals), ('b', 2 * vals), |
277 |
- params=dict(c=1), file=out) |
|
278 |
- # test with explicit figsize |
|
279 |
- plotter.spectrum(ham, ('a', vals), ('b', 2 * vals), |
|
280 |
- params=dict(c=1), fig_size=(10, 10), file=out) |
|
281 |
- |
|
282 |
- # test 2D plot and explicitly passing axis |
|
283 |
- fig = pyplot.figure() |
|
284 |
- ax = fig.add_subplot(1, 1, 1, projection='3d') |
|
285 |
- plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals), |
|
286 |
- params=dict(c=1), ax=ax, file=out) |
|
287 |
- # explicitly pass axis without 3D support |
|
288 |
- ax = fig.add_subplot(1, 1, 1) |
|
289 |
- with pytest.raises(TypeError): |
|
348 |
+ params=dict(c=1), file=out_filename, show=False) |
|
349 |
+ if engine == 'matplotlib': |
|
350 |
+ # test with explicit figsize |
|
351 |
+ plotter.spectrum(ham, ('a', vals), ('b', 2 * vals), |
|
352 |
+ params=dict(c=1), fig_size=(10, 10), file=out_filename, show=False) |
|
353 |
+ |
|
354 |
+ if engine == 'matplotlib': |
|
355 |
+ # test 2D plot and explicitly passing axis |
|
356 |
+ fig = pyplot.figure() |
|
357 |
+ ax = fig.add_subplot(1, 1, 1, projection='3d') |
|
290 | 358 |
plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals), |
291 |
- params=dict(c=1), ax=ax, file=out) |
|
359 |
+ params=dict(c=1), ax=ax, file=out_filename, show=False) |
|
360 |
+ # explicitly pass axis without 3D support |
|
361 |
+ ax = fig.add_subplot(1, 1, 1) |
|
362 |
+ with pytest.raises(TypeError): |
|
363 |
+ plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals), |
|
364 |
+ params=dict(c=1), ax=ax, file=out_filename, show=False) |
|
292 | 365 |
|
293 | 366 |
def mask(a, b): |
294 | 367 |
return a > 0.5 |
295 | 368 |
|
296 |
- with tempfile.TemporaryFile('w+b') as out: |
|
369 |
+ with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out: |
|
370 |
+ out_filename = out.name |
|
297 | 371 |
plotter.spectrum(ham, ('a', vals), ('b', 2 * vals), params=dict(c=1), |
298 |
- mask=mask, file=out) |
|
372 |
+ mask=mask, file=out_filename, show=False) |
|
299 | 373 |
|
300 | 374 |
|
301 | 375 |
def syst_rect(lat, salt, W=3, L=50): |
... | ... |
@@ -555,7 +629,7 @@ def test_current(): |
555 | 629 |
current = J(kwant.wave_function(syst, energy=1)(1)[0]) |
556 | 630 |
|
557 | 631 |
# Test good codepath |
558 |
- with tempfile.TemporaryFile('w+b') as out: |
|
632 |
+ with tempfile.NamedTemporaryFile('w+b') as out: |
|
559 | 633 |
plotter.current(syst, current, file=out) |
560 | 634 |
|
561 | 635 |
fig = pyplot.Figure() |