# Copyright 2011-2017 Kwant authors.
# This file is part of Kwant.  It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
# http://kwant-project.org/license.  A list of Kwant authors can be found in
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.

import keyword
from collections import defaultdict

import numpy as np

import sympy
import sympy.abc
import sympy.physics.quantum
from sympy.core.function import AppliedUndef
from sympy.core.sympify import converter
from sympy.core.core import all_classes as sympy_classes
from sympy.physics.matrices import msigma as _msigma

import warnings

from .._common import reraise_warnings

# TODO: remove when sympy correctly includes MutableDenseMatrix (lol).
sympy_classes = {*sympy_classes, sympy.MutableDenseMatrix}

# Momentum and position operators are declared as non-commutative symbols.
momentum_operators = sympy.symbols('k_x k_y k_z', commutative=False)
position_operators = sympy.symbols('x y z', commutative=False)

# Identity followed by the three Pauli matrices.
pauli = [sympy.eye(2), *(_msigma(i) for i in (1, 2, 3))]

# Namespace that is offered to ``sympy.sympify`` on top of the user's locals.
extra_ns = dict(sympy.abc._clash)
extra_ns.update(
    (op.name, op) for op in (*momentum_operators, *position_operators))
extra_ns['kron'] = sympy.physics.quantum.TensorProduct
extra_ns['eye'] = extra_ns['identity'] = sympy.eye
extra_ns.update(
    ('sigma_{}'.format(c), p) for c, p in zip('0xyz', pauli))

# workaround for https://github.com/sympy/sympy/issues/12060
del extra_ns['I'], extra_ns['pi']

################  Helpers to handle sympy

def lambdify(expr, locals=None):
    """Return a callable object for computing a continuum Hamiltonian.

    .. warning::
        This function uses ``eval`` (because it calls ``sympy.sympify``), and
        thus should not be used on unsanitized input.

    If necessary, the given expression is sympified using
    `kwant.continuum.sympify`.  It is then converted into a callable object.

    Parameters
    ----------
    expr : str or SymPy expression
        Expression to be converted into a callable object
    locals : dict or ``None`` (default)
        Additional definitions for `~kwant.continuum.sympify`.

    Returns
    -------
    callable
        Function whose positional parameters are the names of all symbols and
        functions appearing in the expression, in sorted order.

    Examples
    --------
    >>> f = lambdify('a + b', locals={'b': 'b + c'})
    >>> f(1, 3, 5)
    9

    >>> ns = {'sigma_plus': [[0, 2], [0, 0]]}
    >>> f = lambdify('k_x**2 * sigma_plus', ns)
    >>> f(0.25)
    array([[ 0.   ,  0.125],
           [ 0.   ,  0.   ]])
    """
    with reraise_warnings(level=4):
        expr = sympify(expr, locals)

    # Collect parameter names in a set: the same function may be applied to
    # different arguments (e.g. ``A(x)`` and ``A(y)``), which would otherwise
    # produce repeated parameter names in the generated lambda.
    names = {s.name for s in expr.atoms(sympy.Symbol)}
    names.update(str(f.func)
                 for f in expr.atoms(AppliedUndef, sympy.Function))

    return sympy.lambdify(sorted(names), expr)

def sympify(expr, locals=None):
    """Sympify object using special rules for Hamiltonians.

    If ``expr`` is already a type that SymPy understands, it is returned
    after every applied function in it has been replaced by an undefined
    function of the same name.  Note that ``locals`` will not be used in this
    situation.

    Otherwise, it is sympified by ``sympy.sympify`` with a modified namespace
    such that

    * the position operators "x", "y" or "z" and momentum operators "k_x",
      "k_y", and "k_z" do not commute,
    * the identifiers listed in ``sympy.abc._clash`` (e.g. "gamma") are
      treated as symbols, except "I" and "pi", which are removed from the
      namespace,
    * "kron" corresponds to ``sympy.physics.quantum.TensorProduct``, and
      "eye" and "identity" to ``sympy.eye``,
    * "sigma_0", "sigma_x", "sigma_y", "sigma_z" are the Pauli matrices.

    In addition, Python list literals are interpreted as SymPy matrices.

    .. warning::
        This function uses ``eval`` (because it calls ``sympy.sympify``), and
        thus should not be used on unsanitized input.

    Parameters
    ----------
    expr : str or SymPy expression
        Expression to be converted to a SymPy object.
    locals : dict or ``None`` (default)
        Additional entries for the namespace under which `expr` is sympified.
        The keys must be valid Python variable names.  The values may be
        strings, since they are all sent through `continuum.sympify`
        themselves before use.  (Note that this is a difference to how
        ``sympy.sympify`` behaves.)

        .. note::
            When a value of `locals` is already a SymPy object, it is used
            as-is, and the caller is responsible to set the commutativity of
            its symbols appropriately.  This possible source of errors is
            demonstrated in the last example below.

    Returns
    -------
    result : SymPy object

    Raises
    ------
    ValueError
        If a key of ``locals`` is not a string, not a valid identifier, or is
        a Python keyword.

    Examples
    --------
        >>> sympify('k_x * A(x) * k_x + V(x)')
        k_x*A(x)*k_x + V(x)     # as valid sympy object

        >>> sympify('k_x**2 + V', locals={'V': 'V_0 + V(x)'})
        k_x**2 + V(x) + V_0

        >>> ns = {'sigma_plus': [[0, 2], [0, 0]]}
        >>> sympify('k_x**2 * sigma_plus', ns)
        Matrix([
        [0, 2*k_x**2],
        [0,        0]])

        >>> sympify('k_x * A(c) * k_x', locals={'c': 'x'})
        k_x*A(x)*k_x
        >>> sympify('k_x * A(c) * k_x', locals={'c': sympy.Symbol('x')})
        A(x)*k_x**2
    """
    # if ``expr`` is already a ``sympy`` object we may terminate a code path
    if isinstance(expr, tuple(sympy_classes)):
        if locals:
            warnings.warn('Input expression is already SymPy object: '
                          '"locals" will not be used.',
                          RuntimeWarning,
                          stacklevel=2)

        # We assume that all present functions, like "sin", "cos", will be
        # provided by user during the final evaluation through "params".
        # Therefore we make sure they are defined as AppliedUndef, not built-in
        # sympy types.
        subs = {r: sympy.Function(str(r.func))(*r.args)
                for r in expr.atoms(sympy.Function)}

        return expr.subs(subs)

    # if ``expr`` is not a "sympy" then we proceed with sympifying process
    if locals is None:
        locals = {}

    for k in locals:
        if (not isinstance(k, str)
            or not k.isidentifier() or keyword.iskeyword(k)):
            raise ValueError(
                "Invalid key in 'locals': {}\nKeys must be "
                "identifiers and may not be keywords".format(repr(k)))

    # sympify values of locals before updating it with extra_ns
    # Cast numpy array values in locals to sympy matrices to make sure they have
    # correct format
    locals = {k: (sympy.Matrix(v) if isinstance(v, np.ndarray) else sympify(v))
              for k, v in locals.items()}

    # User-supplied entries take precedence over the default namespace.
    for k, v in extra_ns.items():
        locals.setdefault(k, v)

    # Temporarily make SymPy convert Python lists into matrices.  The previous
    # converter (if any) is put back even when sympification fails.
    stored_value = converter.pop(list, None)
    converter[list] = lambda x: sympy.Matrix(x)
    try:
        hamiltonian = sympy.sympify(expr, locals=locals)
        # if input is for example ``[[k_x * A(x) * k_x]]`` after the first
        # sympify we are getting list of sympy objects, so we call sympify
        # second time to obtain ``sympy`` matrices.
        hamiltonian = sympy.sympify(hamiltonian)
    finally:
        if stored_value is not None:
            converter[list] = stored_value
        else:
            del converter[list]

    return hamiltonian

def make_commutative(expr, *symbols):
    """Make sure that specified symbols are defined as commutative.

    Parameters
    ----------
    expr: sympy.Expr or sympy.Matrix
        Expression in which the symbols are replaced.
    symbols: sequence of symbols
        Set of symbols that are required to be commutative. It doesn't matter
        if a symbol is provided as commutative or not.

    Returns
    -------
    input expression with all specified symbols changed to commutative.
    """
    # Look up the non-commutative variant of every given name, whatever the
    # commutativity of the symbol that was passed in ...
    symbols = [sympy.Symbol(s.name, commutative=False) for s in symbols]
    # ... and swap it for a commutative symbol of the same name.
    expr = expr.subs({s: sympy.Symbol(s.name) for s in symbols})
    return expr

def monomials(expr, gens=None):
    """Parse ``expr`` into monomials in the symbols in ``gens``.

    Parameters
    ----------
    expr: sympy.Expr or sympy.Matrix
        Sympy expression to be parsed into monomials.
    gens: sequence of sympy.Symbol objects or strings (optional)
        Generators of monomials. If unset it will default to all
        symbols used in ``expr``.

    Returns
    -------
    dictionary (generator: monomial)
        For a matrix ``expr`` every value is a matrix of the same shape that
        holds the coefficients entry by entry.

    Examples
    --------
        >>> expr = kwant.continuum.sympify("A * (x**2 + y) + B * x + C")
        >>> monomials(expr, gens=('x', 'y'))
        {1: C, x: B, x**2: A, y: A}
    """
    if gens is None:
        gens = expr.atoms(sympy.Symbol)
    else:
        # Generators may be given as strings; convert them with the same
        # rules that were used to build the expression.
        gens = [sympify(g) for g in gens]

    if not isinstance(expr, sympy.MatrixBase):
        return _expression_monomials(expr, gens)

    # Matrix input: decompose element-wise and collect the coefficients of
    # every monomial in a matrix shaped like ``expr``.
    output = defaultdict(lambda: sympy.zeros(*expr.shape))
    for (i, j), e in np.ndenumerate(expr):
        mons = _expression_monomials(e, gens)
        for key, val in mons.items():
            output[key][i, j] += val
    return dict(output)

def _expression_monomials(expr, gens):
    """Parse ``expr`` into monomials in the symbols in ``gens``.

    Parameters
    ----------
    expr: sympy.Expr
        Sympy expr to be parsed.
    gens: sequence of sympy.Symbol
        Generators of monomials.

    Returns
    -------
    dictionary (generator: monomial)
    """
    expr = sympy.expand(expr)
    output = defaultdict(lambda: sympy.Integer(0))
    for summand in expr.as_ordered_terms():
        key = []
        val = []
        for factor in summand.as_ordered_factors():
            # A factor belongs to the monomial when its base is a generator;
            # everything else is part of the coefficient.
            base, _ = factor.as_base_exp()
            if base in gens:
                key.append(factor)
            else:
                val.append(factor)
        # Summands that share a monomial are accumulated.
        output[sympy.Mul(*key)] += sympy.Mul(*val)

    return dict(output)

################ general help functions

def gcd(*args):
    """Return the greatest common divisor of all given integers.

    Parameters
    ----------
    *args : int
        Numbers to reduce.  Zero and negative values are accepted.

    Returns
    -------
    int
        Non-negative greatest common divisor of all arguments.  A single
        argument is returned unchanged (including its sign), and a call
        without arguments returns 0.
    """
    if not args:
        return 0
    if len(args) == 1:
        return args[0]

    result = 0
    for value in args:
        # Euclid's algorithm on the running result and the next value.  The
        # initial 0 makes the first pass yield ``value`` itself.
        a, b = result, value
        while a:
            a, b = b % a, a
        result = b

    return abs(result)