# Copyright 2011-2013 Kwant authors.
#
# This file is part of Kwant.  It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
# http://kwant-project.org/license.  A list of Kwant authors can be found in
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.

import tempfile
import warnings
import itertools
import numpy as np
import tinyarray as ta
from math import cos, sin
import scipy.integrate
import scipy.stats
import pytest
import sys

import kwant
from kwant._common import ensure_rng

try:
    from mpl_toolkits import mplot3d
    import matplotlib

    # This check is the same as the one performed inside matplotlib.use.
    matplotlib_backend_chosen = 'matplotlib.backends' in sys.modules
    # If the user did not already choose a backend, then choose
    # the one with the least dependencies.
    if not matplotlib_backend_chosen:
        matplotlib.use('Agg')

    from matplotlib import pyplot  # pragma: no flakes
except ImportError:
    matplotlib_backend_chosen = False

from kwant import plotter, _plotter


@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_matplotlib_backend_unset():
    """Simply importing Kwant should not set the matplotlib backend."""
    assert matplotlib_backend_chosen is False


def test_importable_without_backends():
    prefix, sep, suffix = _plotter.__file__.rpartition('.')
    if suffix in ['pyc', 'pyo']:
        suffix = 'py'
    assert suffix == 'py'
    fname = sep.join((prefix, suffix))
    with open(fname, 'rb') as f:
        code = f.read()
    code = code.replace(b'from . import', b'from kwant import')
    code = code.replace(b'matplotlib', b'totalblimp')
    code = code.replace(b'plotly', b'plylot')

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        exec(code)               # Trigger the warning.
        assert len(w) == 2
        assert issubclass(w[0].category, RuntimeWarning)
        assert issubclass(w[1].category, RuntimeWarning)
        assert "totalblimp is not available" in str(w[0].message)
        assert "plylot is not available" in str(w[1].message)


def syst_2d(W=3, r1=3, r2=8):
    a = 1
    t = 1.0
    lat = kwant.lattice.square(a, norbs=1)
    syst = kwant.Builder()

    def ring(pos):
        (x, y) = pos
        rsq = x ** 2 + y ** 2
        return r1 ** 2 < rsq < r2 ** 2

    syst[lat.shape(ring, (0, r1 + 1))] = 4 * t
    syst[lat.neighbors()] = -t
    sym_lead0 = kwant.TranslationalSymmetry(lat.vec((-1, 0)))
    lead0 = kwant.Builder(sym_lead0)

    lead_shape = lambda pos: (-W / 2 < pos[1] < W / 2)

    lead0[lat.shape(lead_shape, (0, 0))] = 4 * t
    lead0[lat.neighbors()] = - t
    lead1 = lead0.reversed()
    syst.attach_lead(lead0)
    syst.attach_lead(lead1)
    return syst


def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0):
    lat = kwant.lattice.general(((a, 0, 0), (0, a, 0), (0, 0, a)), norbs=1)
    syst = kwant.Builder()

    def ring(pos):
        (x, y, z) = pos
        rsq = x ** 2 + y ** 2
        return (r1 ** 2 < rsq < r2 ** 2) and abs(z) < 2
    syst[lat.shape(ring, (0, -r2 + 1, 0))] = 4 * t
    syst[lat.neighbors()] = - t
    sym_lead0 = kwant.TranslationalSymmetry(lat.vec((-1, 0, 0)))
    lead0 = kwant.Builder(sym_lead0)

    lead_shape = lambda pos: (-W / 2 < pos[1] < W / 2) and abs(pos[2]) < 2

    lead0[lat.shape(lead_shape, (0, 0, 0))] = 4 * t
    lead0[lat.neighbors()] = - t
    lead1 = lead0.reversed()
    syst.attach_lead(lead0)
    syst.attach_lead(lead1)
    return syst


def plotter_file_suffix(engine):
    # We need this function so that we can add a .html suffix to the output filename.
    # This is required because plotly will throw an error if filename is without the suffix.
    if engine == "plotly":
        return ".html"
    else:
        return None


@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_matplotlib_plot():
    plotter.set_engine('matplotlib')
    plot = plotter.plot
    syst2d = syst_2d()
    syst3d = syst_3d()
    color_opts = ['k', (lambda site: site.tag[0]),
                  lambda site: (abs(site.tag[0] / 100),
                                abs(site.tag[1] / 100), 0)]
    engine = plotter.get_engine()
    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
        out_filename = out.name
        for color in color_opts:
            for syst in (syst2d, syst3d):
                fig = plot(syst, site_color=color, cmap='binary', file=out_filename)
                if (color != 'k' and
                    isinstance(color(next(iter(syst2d.sites()))), float)):
                    assert fig.axes[0].collections[0].get_array() is not None
                assert len(fig.axes[0].collections) == 6
        color_opts = ['k', (lambda site, site2: site.tag[0]),
                      lambda site, site2: (abs(site.tag[0] / 100),
                                           abs(site.tag[1] / 100), 0)]
        for color in color_opts:
            for syst in (syst2d, syst3d):
                fig = plot(syst2d, hop_color=color, cmap='binary', file=out_filename,
                           fig_size=(2, 10), dpi=30)
                if color != 'k' and isinstance(color(next(iter(syst2d.sites())),
                                                          None), float):
                    assert fig.axes[0].collections[1].get_array() is not None

        assert isinstance(plot(syst3d, file=out_filename).axes[0], mplot3d.axes3d.Axes3D)

        syst2d.leads = []
        plot(syst2d, file=out_filename)
        del syst2d[list(syst2d.hoppings())]
        plot(syst2d, file=out_filename)

        plot(syst3d, file=out_filename)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plot(syst2d.finalized(), file=out_filename)

        # test 2D projections of 3D systems
        plot(syst3d, file=out_filename, pos_transform=lambda pos: pos[:2])


@pytest.mark.skipif(not _plotter.plotly_available, reason="Plotly unavailable.")
def test_plotly_plot():
    plotter.set_engine('plotly')
    plot = plotter.plot
    syst2d = syst_2d()
    syst3d = syst_3d()
    color_opts = ['black', (lambda site: site.tag[0]),
                  lambda site: (abs(site.tag[0] / 100),
                                abs(site.tag[1] / 100), 0)]
    engine = plotter.get_engine()
    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
        out_filename = out.name
        for color in color_opts:
            for syst in (syst2d, syst3d):
                plot(syst, site_color=color, cmap='binary', file=out_filename, show=False)

        color_opts = ['black', (lambda site, site2: site.tag[0]),
                      lambda site, site2: (abs(site.tag[0] / 100),
                                           abs(site.tag[1] / 100), 0)]

        syst2d.leads = []
        plot(syst2d, file=out_filename, show=False)
        del syst2d[list(syst2d.hoppings())]
        plot(syst2d, file=out_filename, show=False)

        plot(syst3d, file=out_filename, show=False)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plot(syst2d.finalized(), file=out_filename, show=False)

        # test 2D projections of 3D systems
        plot(syst3d, file=out_filename, pos_transform=lambda pos: pos[:2], show=False)


@pytest.mark.parametrize("engine", _plotter.engines)
def test_plot_more_site_families_than_colors(engine):
    # test against regression reported in
    # https://gitlab.kwant-project.org/kwant/kwant/issues/257
    ncolors = len(pyplot.rcParams['axes.prop_cycle'])
    syst = kwant.Builder()
    lattices = [kwant.lattice.square(name=i, norbs=1)
                for i in range(ncolors + 1)]
    for i, lat in enumerate(lattices):
        syst[lat(i, 0)] = None

    plotter.set_engine(engine)
    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
        out_filename = out.name
        print(out)
        plotter.plot(syst, file=out_filename, show=False)


@pytest.mark.parametrize("engine", _plotter.engines)
def test_plot_raises_on_bad_site_spec(engine):
    syst = kwant.Builder()
    lat = kwant.lattice.square(norbs=1)
    syst[(lat(i, j) for i in range(5) for j in range(5))] = None

    # Cannot provide site_size as an array when syst is a Builder
    plotter.set_engine(engine)
    with pytest.raises(TypeError):
        plotter.plot(syst, site_size=[1] * 25)

    # Cannot provide site_size as an array when syst is a Builder
    with pytest.raises(TypeError):
        plotter.plot(syst, site_symbol=['o'] * 25)


def good_transform(pos):
    x, y = pos
    return y, x

def bad_transform(pos):
    x, y = pos
    return x, y, 0

@pytest.mark.parametrize("engine", _plotter.engines)
def test_map(engine):
    plotter.set_engine(engine)
    syst = syst_2d()

    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
        out_filename = out.name
        plotter.map(syst, lambda site: site.tag[0], pos_transform=good_transform,
                    file=out_filename, method='linear', a=4, oversampling=4, cmap='flag', show=False)
        pytest.raises(ValueError, plotter.map, syst,
                      lambda site: site.tag[0],
                      pos_transform=bad_transform, file=out_filename)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plotter.map(syst.finalized(), range(len(syst.sites())),
                        file=out_filename, show=False)
        pytest.raises(ValueError, plotter.map, syst,
                      range(len(syst.sites())), file=out_filename)


def test_mask_interpolate():
    # A coordinate array with coordinates of two points almost coinciding.
    coords = np.array([[0, 0], [1e-7, 1e-7], [1, 1], [1, 0]])

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        plotter.mask_interpolate(coords, np.ones(len(coords)), a=1)
        assert len(w) == 1
        assert issubclass(w[-1].category, RuntimeWarning)
        assert "coinciding" in str(w[-1].message)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        pytest.raises(ValueError, plotter.mask_interpolate,
                      coords, np.ones(len(coords)))
        pytest.raises(ValueError, plotter.mask_interpolate,
                      coords, np.ones(2 * len(coords)))


@pytest.mark.parametrize("engine", _plotter.engines)
def test_bands(engine):

    plotter.set_engine(engine)

    syst = syst_2d().finalized().leads[0]

    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
        out_filename = out.name
        plotter.bands(syst, show=False, file=out_filename)
        plotter.bands(syst, show=False, momenta=np.linspace(0, 2 * np.pi), file=out_filename)

        if engine == 'matplotlib':
            plotter.bands(syst, show=False, fig_size=(10, 10), file=out_filename)

            fig = pyplot.Figure()
            ax = fig.add_subplot(1, 1, 1)
            plotter.bands(syst, show=False, ax=ax, file=out_filename)



@pytest.mark.parametrize("engine", _plotter.engines)
def test_spectrum(engine):

    plotter.set_engine(engine)

    def ham_1d(a, b, c):
        return a**2 + b**2 + c**2

    def ham_2d(a, b, c):
        return np.eye(2) * (a**2 + b**2 + c**2)

    lat = kwant.lattice.chain(norbs=1)
    syst = kwant.Builder()
    syst[(lat(i) for i in range(3))] = lambda site, a, b: a + b
    syst[lat.neighbors()] = lambda site1, site2, c: c
    fsyst = syst.finalized()

    vals = np.linspace(0, 1, 3)

    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
        out_filename = out.name

        for ham in (ham_1d, ham_2d, fsyst):
            plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), file=out_filename, show=False)
            if engine == 'matplotlib':
                # test with explicit figsize
                plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1),
                                 fig_size=(10, 10), file=out_filename, show=False)

        for ham in (ham_1d, ham_2d, fsyst):
            plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
                             params=dict(c=1), file=out_filename, show=False)
            if engine == 'matplotlib':
                # test with explicit figsize
                plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
                                 params=dict(c=1), fig_size=(10, 10), file=out_filename, show=False)

        if engine == 'matplotlib':
            # test 2D plot and explicitly passing axis
            fig = pyplot.figure()
            ax = fig.add_subplot(1, 1, 1, projection='3d')
            plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
                             params=dict(c=1), ax=ax, file=out_filename, show=False)
            # explicitly pass axis without 3D support
            ax = fig.add_subplot(1, 1, 1)
            with pytest.raises(TypeError):
                plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
                                 params=dict(c=1), ax=ax, file=out_filename, show=False)

    def mask(a, b):
        return a > 0.5

    with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
        out_filename = out.name
        plotter.spectrum(ham, ('a', vals), ('b', 2 * vals), params=dict(c=1),
                         mask=mask, file=out_filename, show=False)


def syst_rect(lat, salt, W=3, L=50):
    syst = kwant.Builder()

    ll = L//2
    ww = W//2

    def onsite(site):
        return 4 + 0.1 * kwant.digest.gauss(repr(site.tag), salt=salt)

    syst[(lat(i, j) for i in range(-ll, ll+1)
         for j in range(-ww, ww+1))] = onsite
    syst[lat.neighbors()] = -1

    sym = kwant.TranslationalSymmetry(lat.vec((-1, 0)))
    lead = kwant.Builder(sym)
    lead[(lat(0, j) for j in range(-ww, ww + 1))] = 4
    lead[lat.neighbors()] = -1

    syst.attach_lead(lead)
    syst.attach_lead(lead.reversed())

    return syst


def div(F, h):
    """Calculate the divergence of a vector field F over a grid of spacing h."""
    assert len(F.shape[:-1]) == F.shape[-1]
    assert len(h) == F.shape[-1]
    return sum(np.gradient(F[..., i], h[i])[i] for i in range(F.shape[-1]))


def rotational_currents(g):
    """Return a basis of divergence-free currents for a closed graph.

    Given the graph 'g' of a Kwant system, returns a sequence of arrays
    which are linearly independent, divergence-free currents on the graph.
    """
    #'A' represents the set of expressions that give the net current flow
    # into the system sites. 'perm' is a map from the edges of a graph
    # with only 1 edge per hopping to the proper Kwant graph (2 edges
    # per hopping).
    A = np.zeros((g.num_nodes, g.num_edges // 2))
    hoppings = dict()
    perm_data = np.zeros(g.num_edges, dtype=int)
    perm_ij = np.zeros((2, g.num_edges), dtype=int)
    i = 0
    for k, (a, b) in enumerate(g):
        hop = frozenset((a, b))
        if hop not in hoppings:
            A[a, i] = 1
            A[b, i] = -1
            hoppings[hop] = i
            perm_data[k] = 1
            perm_ij[:, k] = (k, i)
            i += 1
        else:
            perm_data[k] = -1
            perm_ij[:, k] = (k, hoppings[hop])

    perm = scipy.sparse.coo_matrix((perm_data, perm_ij))

    # Get the row vectors of V with singular value 0. These form
    # a basis for the right null space of 'A'.
    U, S, V = np.linalg.svd(A)
    tol = S.max() * max(A.shape) * np.finfo(S.dtype).eps
    rank = sum(S > tol)
    # Transform null space basis into vectors defined over the full
    # hopping space (both hopping directions).
    null_space_basis = V[-(len(V) - rank):].transpose()
    null_space_basis = perm.dot(null_space_basis).transpose()
    return null_space_basis


def _border_is_0(field):
    borders = [(0, slice(None)), (-1, slice(None)),
               (slice(None), 0), (slice(None), -1)]
    return all(np.allclose(field[a, b], 0) for a, b in borders)


def _test_border_0(interpolator):
    ## Test that current is always identically zero at box boundaries
    syst = kwant.Builder()
    lat = kwant.lattice.square(norbs=1)
    syst[[lat(0, 0), lat(1, 0)]] = None
    syst[(lat(0, 0), lat(1, 0))] = None
    syst = syst.finalized()
    values = [1, -1]

    ns = [3, 4, 5, 10, 100]
    abswidths = [0.01, 0.1, 1, 10, 100]
    relwidths = [0.01, 0.1, 1, 10, 100]
    for n, abswidth in itertools.product(ns, abswidths):
        field, _ = interpolator(syst, values, abswidth=abswidth, n=n)
        assert _border_is_0(field)
    for n, relwidth in itertools.product(ns, relwidths):
        field, _ = interpolator(syst, values, relwidth=relwidth, n=n)
        assert _border_is_0(field)


def test_density_interpolation():
    ## Passing a Builder will raise an error
    pytest.raises(TypeError, plotter.interpolate_density, syst_2d(), None)

    # Test that the density is always identically zero at the box boundaries
    # as the bump function has finite support and we add a padding
    _test_border_0(kwant.plotter.interpolate_density)

    def R(theta):
        return ta.array([[cos(theta), -sin(theta)], [sin(theta), cos(theta)]])

    # Make lattice with lattice vectors perturbed from x and y directions
    def make_lattice(a, salt='0'):
        theta_x = kwant.digest.uniform('x', salt=salt) * np.pi / 6
        theta_y = kwant.digest.uniform('y', salt=salt) * np.pi / 6
        x = ta.dot(R(theta_x), (a, 0))
        y = ta.dot(R(theta_y), (0, a))
        return kwant.lattice.general([x, y], norbs=1)

    # Check that integrating the interpolated density gives the same result
    # as summing the densities on all sites. We check this for several lattice
    # widths, lattice orientations and bump widths.
    for a, width in itertools.product((1, 2), (1, 0.5)):
        lat = make_lattice(a)
        syst = syst_rect(lat, salt='0').finalized()

        psi = kwant.wave_function(syst, energy=3)(0)[0]
        density = kwant.operator.Density(syst)(psi)
        exact_charge = sum(density)
        # We verify that the result is good by interpolating for
        # various numbers of points-per-bump and verifying that
        # the error falls of as 1/n.
        data = []
        for n in [4, 6, 8, 11, 16]:
            rho, box = plotter.interpolate_density(syst, density,
                                                   n=n, abswidth=width)
            (xmin, xmax), (ymin, ymax) = box
            area =  xmax - xmin * (ymax - ymin)
            N = rho.shape[0] * rho.shape[1]
            charge = np.sum(rho) * area / N
            data.append((n, abs(charge - exact_charge)))
        _, _, rvalue, *_ = scipy.stats.linregress(np.log(data))
        # Gradient of -1 on log-log plot means error falls off as 1/n
        # TODO: review this value once #280 has been dealt with.
        assert rvalue < -0.7

    # Test that the interpolation is linear in the input.
    rng = ensure_rng(1)
    lat = make_lattice(1, '1')
    syst = syst_rect(lat, salt='1').finalized()
    rho_0 = rng.rand(len(syst.sites))
    rho_1 = rng.rand(len(syst.sites))

    irho_0, _ = plotter.interpolate_density(syst, rho_0)
    irho_1, _ = plotter.interpolate_density(syst, rho_1)

    rho_tot, _ = plotter.interpolate_density(syst, rho_0 + 2 * rho_1)
    assert np.allclose(rho_tot, irho_0 + 2 * irho_1)


def test_current_interpolation():

    ## Passing a Builder will raise an error
    pytest.raises(TypeError, plotter.interpolate_current, syst_2d(), None)

    def R(theta):
        return ta.array([[cos(theta), -sin(theta)], [sin(theta), cos(theta)]])

    def make_lattice(a, theta):
        x = ta.dot(R(theta), (a, 0))
        y = ta.dot(R(theta), (0, a))
        return kwant.lattice.general([x, y], norbs=1)

    _test_border_0(plotter.interpolate_current)

    ## Check current through cross section is same for different lattice
    ## parameters and orientations of the system wrt. the discretization grid
    for a, theta, width in [(1, 0, 1),
                            (1, 0, 0.5),
                            (2, 0, 1),
                            (1, 0.2, 1),
                            (2, 0.4, 1)]:
        lat = make_lattice(a, theta)
        syst = syst_rect(lat, salt='0').finalized()
        psi = kwant.wave_function(syst, energy=3)(0)

        def cut(a, b):
            return b.tag[0] < 0 and a.tag[0] >= 0

        J = kwant.operator.Current(syst).bind()
        J_cut = kwant.operator.Current(syst, where=cut, sum=True).bind()
        J_exact = J_cut(psi[0])

        data = []
        for n in [4, 6, 8, 11, 16]:
            j0, box = plotter.interpolate_current(syst, J(psi[0]),
                                                  n=n, abswidth=width)
            x, y = (np.linspace(mn, mx, shape)
                    for (mn, mx), shape in zip(box, j0.shape))
            # slice field perpendicular to a cut along the y axis
            y_axis = (np.argmin(np.abs(x)), slice(None), 0)
            J_interp = scipy.integrate.simps(j0[y_axis], y)
            data.append((n, abs(J_interp - J_exact)))
        # 3rd value returned from 'linregress' is 'rvalue'
        # TODO: review this value once #280 has been dealt with.
        assert scipy.stats.linregress(np.log(data))[2] < -0.7


    ### Tests on a divergence-free current (closed system)

    lat = kwant.lattice.general([(1, 0), (0.5, np.sqrt(3) / 2)], norbs=1)
    syst = kwant.Builder()
    sites = [lat(0, 0), lat(1, 0), lat(0, 1), lat(2, 2)]
    syst[sites] = None
    syst[((s, t) for s, t in itertools.product(sites, sites) if s != t)] = None
    del syst[lat(0, 0), lat(2, 2)]
    syst = syst.finalized()

    # generate random divergence-free currents
    Js = rotational_currents(syst.graph)
    rng = ensure_rng(3)
    J0 = sum(rng.rand(len(Js))[:, None] * Js)
    J1 = sum(rng.rand(len(Js))[:, None] * Js)

    # Sanity check that diverence on the graph is 0
    divergence = np.zeros(len(syst.sites))
    for (a, _), current in zip(syst.graph, J0):
        divergence[a] += current
    assert np.allclose(divergence, 0)

    j0, _ = plotter.interpolate_current(syst, J0)
    j1, _ = plotter.interpolate_current(syst, J1)

    ## Test linearity of interpolation.
    j_tot, _ = plotter.interpolate_current(syst, J0 + 2 * J1)
    assert np.allclose(j_tot, j0 + 2 * j1)

    ## Test that divergence of interpolated current approaches zero as we make
    ## the interpolation finer.
    data = []
    for n in [4, 6, 8, 11, 16]:
        j, box = plotter.interpolate_current(syst, J0, n=n)
        dx = [(mx - mn) / (shape - 1) for (mn, mx), shape in zip(box, j.shape)]
        div_j = np.max(np.abs(div(j, dx)))
        data.append((n, div_j))

    # 3rd value returned from 'linregress' is 'rvalue'
    # TODO: review this value once #280 has been dealt with.
    assert scipy.stats.linregress(np.log(data))[2] < -0.7


@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_current():
    plotter.set_engine('matplotlib')
    syst = syst_2d().finalized()
    J = kwant.operator.Current(syst)
    current = J(kwant.wave_function(syst, energy=1)(1)[0])

    # Test good codepath
    with tempfile.NamedTemporaryFile('w+b') as out:
        plotter.current(syst, current, file=out)

        fig = pyplot.Figure()
        ax = fig.add_subplot(1, 1, 1)
        plotter.current(syst, current, ax=ax, file=out)