Browse code

Adds the plotly testing script.

Kelvin Loh authored on 11/12/2019 15:49:54
Showing 3 changed files
... ...
@@ -174,12 +174,17 @@ if plotly_available:
174 174
 
175 175
 
176 176
     def convert_cmap_list_mpl_plotly(mpl_cmap_name, N=255):
177
-        cmap_mpl = matplotlib.cm.get_cmap(mpl_cmap_name)
178
-        cmap_mpl_arr = matplotlib.colors.makeMappingArray(N, cmap_mpl)
179
-        level = np.linspace(0, 1, N)
180
-        cmap_plotly_linear = [(level, convert_colormap_mpl_plotly(cmap_mpl))
181
-                                for level, cmap_mpl in zip(level,
182
-                                                            cmap_mpl_arr)]
177
+        if isinstance(mpl_cmap_name, str):
178
+            cmap_mpl = matplotlib.cm.get_cmap(mpl_cmap_name)
179
+            cmap_mpl_arr = matplotlib.colors.makeMappingArray(N, cmap_mpl)
180
+            level = np.linspace(0, 1, N)
181
+            cmap_plotly_linear = [(level, convert_colormap_mpl_plotly(cmap_mpl))
182
+                                    for level, cmap_mpl in zip(level,
183
+                                                                cmap_mpl_arr)]
184
+        else:
185
+            assert(isinstance(mpl_cmap_name, list))
186
+            # Do not do any conversion if it's already a list
187
+            cmap_plotly_linear = mpl_cmap_name
183 188
         return cmap_plotly_linear
184 189
 
185 190
 
... ...
@@ -1612,6 +1612,7 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3):
1612 1612
                  range(len(cmin)))
1613 1613
     grid = tuple(np.ogrid[dims])
1614 1614
     img = interpolate.griddata(coords, values, grid, method)
1615
+    img = img.astype(np.float_)
1615 1616
     mask = np.mgrid[dims].reshape(len(cmin), -1).T
1616 1617
     # The numerical values in the following line are optimized for the common
1617 1618
     # case of a square lattice:
... ...
@@ -1793,7 +1794,8 @@ def _map_plotly(syst, img, colorbar, _max, _min,  vmin, vmax, overflow_pct,
1793 1794
     contour_object.y = np.linspace(_min[1],_max[1],img.shape[1])
1794 1795
     contour_object.zsmooth = False
1795 1796
     contour_object.connectgaps = False
1796
-    contour_object.colorscale = _p.convert_cmap_list_mpl_plotly(cmap)
1797
+    cmap = _p.convert_cmap_list_mpl_plotly(cmap)
1798
+    contour_object.colorscale = cmap
1797 1799
     contour_object.zmax = vmax
1798 1800
     contour_object.zmin = vmin
1799 1801
     contour_object.hoverinfo = 'none'
... ...
@@ -116,18 +116,31 @@ def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0):
116 116
     return syst
117 117
 
118 118
 
119
+def plotter_file_suffix(engine):
120
+    # We need this function so that we can add a .html suffix to the output filename.
121
+    # This is required because plotly will throw an error if filename is without the suffix.
122
+    if engine == "plotly":
123
+        return ".html"
124
+    else:
125
+        return None
126
+
127
+
119 128
 @pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
120
-def test_plot():
129
+def test_matplotlib_plot():
130
+
131
+    plotter.set_engine('matplotlib')
121 132
     plot = plotter.plot
122 133
     syst2d = syst_2d()
123 134
     syst3d = syst_3d()
124 135
     color_opts = ['k', (lambda site: site.tag[0]),
125 136
                   lambda site: (abs(site.tag[0] / 100),
126 137
                                 abs(site.tag[1] / 100), 0)]
127
-    with tempfile.TemporaryFile('w+b') as out:
138
+    engine = plotter.get_engine()
139
+    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
140
+        out_filename = out.name
128 141
         for color in color_opts:
129 142
             for syst in (syst2d, syst3d):
130
-                fig = plot(syst, site_color=color, cmap='binary', file=out)
143
+                fig = plot(syst, site_color=color, cmap='binary', file=out_filename)
131 144
                 if (color != 'k' and
132 145
                     isinstance(color(next(iter(syst2d.sites()))), float)):
133 146
                     assert fig.axes[0].collections[0].get_array() is not None
... ...
@@ -137,30 +150,66 @@ def test_plot():
137 150
                                            abs(site.tag[1] / 100), 0)]
138 151
         for color in color_opts:
139 152
             for syst in (syst2d, syst3d):
140
-                fig = plot(syst2d, hop_color=color, cmap='binary', file=out,
153
+                fig = plot(syst2d, hop_color=color, cmap='binary', file=out_filename,
141 154
                            fig_size=(2, 10), dpi=30)
142 155
                 if color != 'k' and isinstance(color(next(iter(syst2d.sites())),
143 156
                                                           None), float):
144 157
                     assert fig.axes[0].collections[1].get_array() is not None
145 158
 
146
-        assert isinstance(plot(syst3d, file=out).axes[0], mplot3d.axes3d.Axes3D)
159
+        assert isinstance(plot(syst3d, file=out_filename).axes[0], mplot3d.axes3d.Axes3D)
160
+
161
+        syst2d.leads = []
162
+        plot(syst2d, file=out_filename)
163
+        del syst2d[list(syst2d.hoppings())]
164
+        plot(syst2d, file=out_filename)
165
+
166
+        plot(syst3d, file=out_filename)
167
+        with warnings.catch_warnings():
168
+            warnings.simplefilter("ignore")
169
+            plot(syst2d.finalized(), file=out_filename)
170
+
171
+        # test 2D projections of 3D systems
172
+        plot(syst3d, file=out_filename, pos_transform=lambda pos: pos[:2])
173
+
174
+
175
+@pytest.mark.skipif(not _plotter.plotly_available, reason="Plotly unavailable.")
176
+def test_plotly_plot():
177
+
178
+    plotter.set_engine('plotly')
179
+    plot = plotter.plot
180
+    syst2d = syst_2d()
181
+    syst3d = syst_3d()
182
+    color_opts = ['black', (lambda site: site.tag[0]),
183
+                  lambda site: (abs(site.tag[0] / 100),
184
+                                abs(site.tag[1] / 100), 0)]
185
+    engine = plotter.get_engine()
186
+    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
187
+        out_filename = out.name
188
+        for color in color_opts:
189
+            for syst in (syst2d, syst3d):
190
+                plot(syst, site_color=color, cmap='binary', file=out_filename, show=False)
191
+
192
+        color_opts = ['black', (lambda site, site2: site.tag[0]),
193
+                      lambda site, site2: (abs(site.tag[0] / 100),
194
+                                           abs(site.tag[1] / 100), 0)]
147 195
 
148 196
         syst2d.leads = []
149
-        plot(syst2d, file=out)
197
+        plot(syst2d, file=out_filename, show=False)
150 198
         del syst2d[list(syst2d.hoppings())]
151
-        plot(syst2d, file=out)
199
+        plot(syst2d, file=out_filename, show=False)
152 200
 
153
-        plot(syst3d, file=out)
201
+        plot(syst3d, file=out_filename, show=False)
154 202
         with warnings.catch_warnings():
155 203
             warnings.simplefilter("ignore")
156
-            plot(syst2d.finalized(), file=out)
204
+            plot(syst2d.finalized(), file=out_filename, show=False)
157 205
 
158 206
         # test 2D projections of 3D systems
159
-        plot(syst3d, file=out, pos_transform=lambda pos: pos[:2])
207
+        plot(syst3d, file=out_filename, pos_transform=lambda pos: pos[:2], show=False)
160 208
 
161 209
 
162 210
 @pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
163
-def test_plot_more_site_families_than_colors():
211
+@pytest.mark.parametrize("engine", ["plotly", "matplotlib"])
212
+def test_plot_more_site_families_than_colors(engine):
164 213
     # test against regression reported in
165 214
     # https://gitlab.kwant-project.org/kwant/kwant/issues/257
166 215
     ncolors = len(pyplot.rcParams['axes.prop_cycle'])
... ...
@@ -169,17 +218,23 @@ def test_plot_more_site_families_than_colors():
169 218
                 for i in range(ncolors + 1)]
170 219
     for i, lat in enumerate(lattices):
171 220
         syst[lat(i, 0)] = None
172
-    with tempfile.TemporaryFile('w+b') as out:
173
-        plotter.plot(syst, file=out)
221
+
222
+    plotter.set_engine(engine)
223
+    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
224
+        out_filename = out.name
225
+        print(out)
226
+        plotter.plot(syst, file=out_filename, show=False)
174 227
 
175 228
 
176 229
 @pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
177
-def test_plot_raises_on_bad_site_spec():
230
+@pytest.mark.parametrize("engine", ["plotly", "matplotlib"])
231
+def test_plot_raises_on_bad_site_spec(engine):
178 232
     syst = kwant.Builder()
179 233
     lat = kwant.lattice.square(norbs=1)
180 234
     syst[(lat(i, j) for i in range(5) for j in range(5))] = None
181 235
 
182 236
     # Cannot provide site_size as an array when syst is a Builder
237
+    plotter.set_engine(engine)
183 238
     with pytest.raises(TypeError):
184 239
         plotter.plot(syst, site_size=[1] * 25)
185 240
 
... ...
@@ -197,20 +252,24 @@ def bad_transform(pos):
197 252
     return x, y, 0
198 253
 
199 254
 @pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
200
-def test_map():
255
+@pytest.mark.parametrize("engine", ["plotly", "matplotlib"])
256
+def test_map(engine):
257
+    plotter.set_engine(engine)
201 258
     syst = syst_2d()
202
-    with tempfile.TemporaryFile('w+b') as out:
259
+
260
+    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
261
+        out_filename = out.name
203 262
         plotter.map(syst, lambda site: site.tag[0], pos_transform=good_transform,
204
-                    file=out, method='linear', a=4, oversampling=4, cmap='flag')
263
+                    file=out_filename, method='linear', a=4, oversampling=4, cmap='flag', show=False)
205 264
         pytest.raises(ValueError, plotter.map, syst,
206 265
                       lambda site: site.tag[0],
207
-                      pos_transform=bad_transform, file=out)
266
+                      pos_transform=bad_transform, file=out_filename)
208 267
         with warnings.catch_warnings():
209 268
             warnings.simplefilter("ignore")
210 269
             plotter.map(syst.finalized(), range(len(syst.sites())),
211
-                        file=out)
270
+                        file=out_filename, show=False)
212 271
         pytest.raises(ValueError, plotter.map, syst,
213
-                      range(len(syst.sites())), file=out)
272
+                      range(len(syst.sites())), file=out_filename)
214 273
 
215 274
 
216 275
 def test_mask_interpolate():
... ...
@@ -233,22 +292,32 @@ def test_mask_interpolate():
233 292
 
234 293
 
235 294
 @pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
236
-def test_bands():
295
+@pytest.mark.parametrize("engine", ["plotly", "matplotlib"])
296
+def test_bands(engine):
297
+
298
+    plotter.set_engine(engine)
237 299
 
238 300
     syst = syst_2d().finalized().leads[0]
239 301
 
240
-    with tempfile.TemporaryFile('w+b') as out:
241
-        plotter.bands(syst, file=out)
242
-        plotter.bands(syst, fig_size=(10, 10), file=out)
243
-        plotter.bands(syst, momenta=np.linspace(0, 2 * np.pi), file=out)
302
+    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
303
+        out_filename = out.name
304
+        plotter.bands(syst, show=False, file=out_filename)
305
+        plotter.bands(syst, show=False, momenta=np.linspace(0, 2 * np.pi), file=out_filename)
306
+
307
+        if engine == 'matplotlib':
308
+            plotter.bands(syst, show=False, fig_size=(10, 10), file=out_filename)
309
+
310
+            fig = pyplot.Figure()
311
+            ax = fig.add_subplot(1, 1, 1)
312
+            plotter.bands(syst, show=False, ax=ax, file=out_filename)
244 313
 
245
-        fig = pyplot.Figure()
246
-        ax = fig.add_subplot(1, 1, 1)
247
-        plotter.bands(syst, ax=ax, file=out)
248 314
 
249 315
 
250 316
 @pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
251
-def test_spectrum():
317
+@pytest.mark.parametrize("engine", ["plotly", "matplotlib"])
318
+def test_spectrum(engine):
319
+
320
+    plotter.set_engine(engine)
252 321
 
253 322
     def ham_1d(a, b, c):
254 323
         return a**2 + b**2 + c**2
... ...
@@ -264,38 +333,43 @@ def test_spectrum():
264 333
 
265 334
     vals = np.linspace(0, 1, 3)
266 335
 
267
-    with tempfile.TemporaryFile('w+b') as out:
336
+    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
337
+        out_filename = out.name
268 338
 
269 339
         for ham in (ham_1d, ham_2d, fsyst):
270
-            plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), file=out)
271
-            # test with explicit figsize
272
-            plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1),
273
-                             fig_size=(10, 10), file=out)
340
+            plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), file=out_filename, show=False)
341
+            if engine == 'matplotlib':
342
+                # test with explicit figsize
343
+                plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1),
344
+                                 fig_size=(10, 10), file=out_filename, show=False)
274 345
 
275 346
         for ham in (ham_1d, ham_2d, fsyst):
276 347
             plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
277
-                             params=dict(c=1), file=out)
278
-            # test with explicit figsize
279
-            plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
280
-                             params=dict(c=1), fig_size=(10, 10), file=out)
281
-
282
-        # test 2D plot and explicitly passing axis
283
-        fig = pyplot.figure()
284
-        ax = fig.add_subplot(1, 1, 1, projection='3d')
285
-        plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
286
-                         params=dict(c=1), ax=ax, file=out)
287
-        # explicitly pass axis without 3D support
288
-        ax = fig.add_subplot(1, 1, 1)
289
-        with pytest.raises(TypeError):
348
+                             params=dict(c=1), file=out_filename, show=False)
349
+            if engine == 'matplotlib':
350
+                # test with explicit figsize
351
+                plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
352
+                                 params=dict(c=1), fig_size=(10, 10), file=out_filename, show=False)
353
+
354
+        if engine == 'matplotlib':
355
+            # test 2D plot and explicitly passing axis
356
+            fig = pyplot.figure()
357
+            ax = fig.add_subplot(1, 1, 1, projection='3d')
290 358
             plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
291
-                             params=dict(c=1), ax=ax, file=out)
359
+                             params=dict(c=1), ax=ax, file=out_filename, show=False)
360
+            # explicitly pass axis without 3D support
361
+            ax = fig.add_subplot(1, 1, 1)
362
+            with pytest.raises(TypeError):
363
+                plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
364
+                                 params=dict(c=1), ax=ax, file=out_filename, show=False)
292 365
 
293 366
     def mask(a, b):
294 367
         return a > 0.5
295 368
 
296
-    with tempfile.TemporaryFile('w+b') as out:
369
+    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
370
+        out_filename = out.name
297 371
         plotter.spectrum(ham, ('a', vals), ('b', 2 * vals), params=dict(c=1),
298
-                         mask=mask, file=out)
372
+                         mask=mask, file=out_filename, show=False)
299 373
 
300 374
 
301 375
 def syst_rect(lat, salt, W=3, L=50):
... ...
@@ -555,7 +629,7 @@ def test_current():
555 629
     current = J(kwant.wave_function(syst, energy=1)(1)[0])
556 630
 
557 631
     # Test good codepath
558
-    with tempfile.TemporaryFile('w+b') as out:
632
+    with tempfile.NamedTemporaryFile('w+b') as out:
559 633
         plotter.current(syst, current, file=out)
560 634
 
561 635
         fig = pyplot.Figure()