from collections import Counter
from collections.abc import Iterable, Sized
from itertools import chain, combinations
import math

import numpy as np
import scipy.spatial


def fast_norm(v):
    """Return the Euclidean length of the vector ``v``.

    Two-component vectors take a pure-Python shortcut that avoids the NumPy
    call overhead; any other length goes through ``np.dot``.
    """
    if len(v) != 2:
        return math.sqrt(np.dot(v, v))
    x, y = v[0], v[1]
    return math.sqrt(x * x + y * y)


def fast_2d_point_in_simplex(point, simplex, eps=1e-8):
    """Check whether the 2D ``point`` lies inside the triangle ``simplex``.

    Two barycentric-style coordinates ``s`` and ``t`` of the point are
    computed with respect to the triangle's first vertex. The point is
    accepted when ``-eps <= s <= 1 + eps``, ``t >= -eps`` and
    ``s + t <= 1 + eps``. A zero-area triangle is not guarded against.
    """
    (ax, ay), (bx, by), (cx, cy) = simplex
    qx, qy = point

    area = 0.5 * (- by * cx + ay * (cx - bx)
                  + bx * cy + ax * (by - cy))
    inv = 1 / (2 * area)

    s = inv * (+ ay * cx + (cy - ay) * qx
               - ax * cy + (ax - cx) * qy)
    # Bail out early: ``t`` is only needed when ``s`` is within range.
    if s < -eps or 1 + eps < s:
        return False

    t = inv * (+ ax * by + (ay - by) * qx
               - ay * bx + (bx - ax) * qy)
    return t >= -eps and s + t <= 1 + eps


def point_in_simplex(point, simplex, eps=1e-8):
    """Check whether ``point`` lies inside ``simplex`` (tolerance ``eps``).

    2D points are dispatched to ``fast_2d_point_in_simplex``. In higher
    dimensions the point is expressed in the basis formed by the edges that
    leave the simplex's first vertex; it is inside when every coefficient is
    greater than ``-eps`` and the coefficients sum to less than ``1 + eps``.
    """
    if len(point) == 2:
        return fast_2d_point_in_simplex(point, simplex, eps)

    origin = np.array(simplex[0], dtype=float)
    edges = np.array(simplex[1:], dtype=float) - origin
    coefficients = np.linalg.solve(edges.T, point - origin)

    if not all(coefficients > -eps):
        return False
    return sum(coefficients) < 1 + eps


def fast_2d_circumcircle(points):
    """Compute the center and radius of the circumscribed circle of a triangle.

    Parameters
    ----------
    points : 2D array-like, shape (3, 2)
        The vertices of the triangle.

    Returns
    -------
    tuple
        ``((center_x, center_y), radius)``; the values are floats.
    """
    vertices = np.array(points)
    anchor = vertices[0]
    # Work relative to the first vertex so that it sits at the origin.
    (ux, uy), (vx, vy) = vertices[1:] - anchor

    u_sq = ux * ux + uy * uy
    v_sq = vx * vx + vy * vy

    # Cramer's rule on the perpendicular-bisector equations.
    denom = 2 * (+ ux * vy - vx * uy)
    cx = (+ u_sq * vy - v_sq * uy) / denom
    cy = (- u_sq * vx + v_sq * ux) / denom

    radius = math.sqrt(cx * cx + cy * cy)
    return (cx + anchor[0], cy + anchor[1]), radius


def fast_3d_circumcircle(points):
    """Compute the center and radius of the circumscribed sphere of a tetrahedron.

    Parameters
    ----------
    points : 2D array-like, shape (4, 3)
        The vertices of the tetrahedron.

    Returns
    -------
    tuple
        ``(center, radius)`` where ``center`` is a 3-tuple of floats and
        ``radius`` is a float.
    """
    vertices = np.array(points)
    anchor = vertices[0]
    rel = vertices[1:] - anchor
    (x1, y1, z1), (x2, y2, z2), (x3, y3, z3) = rel
    l1, l2, l3 = (np.dot(row, row) for row in rel)  # squared lengths

    # 2x2 minors over the (y, z) columns, shared by dx and the denominator.
    m_yz1 = y2 * z3 - z2 * y3
    m_yz2 = y1 * z3 - z1 * y3
    m_yz3 = y1 * z2 - z1 * y2

    denom = 2 * (+ x1 * m_yz1 - x2 * m_yz2 + x3 * m_yz3)
    dx = + l1 * m_yz1 - l2 * m_yz2 + l3 * m_yz3
    dy = (+ l1 * (x2 * z3 - z2 * x3)
          - l2 * (x1 * z3 - z1 * x3)
          + l3 * (x1 * z2 - z1 * x2))
    dz = (+ l1 * (x2 * y3 - y2 * x3)
          - l2 * (x1 * y3 - y1 * x3)
          + l3 * (x1 * y2 - y1 * x2))

    offset = [dx / denom, -dy / denom, dz / denom]
    radius = math.sqrt(np.dot(offset, offset))
    return tuple(np.add(offset, anchor)), radius


def circumsphere(pts):
    """Compute the circumscribed hypersphere of a simplex.

    Parameters
    ----------
    pts : 2D array-like
        The ``dim + 1`` vertices of a simplex in ``dim`` dimensions.

    Returns
    -------
    tuple
        ``(center, radius)`` with ``center`` a tuple of coordinates.

    Triangles and tetrahedra use the closed-form helpers. Other dimensions
    use a cofactor expansion of the sphere-through-points determinant,
    modified from http://mathworld.wolfram.com/Circumsphere.html
    """
    dim = len(pts) - 1
    if dim == 2:
        return fast_2d_circumcircle(pts)
    if dim == 3:
        return fast_3d_circumcircle(pts)

    # Each row is [|p|^2, p_1, ..., p_dim, 1].
    mat = [[np.sum(np.square(pt)), *pt, 1] for pt in pts]
    two_a = 2 * np.linalg.det(np.delete(mat, 0, 1))

    # Coordinate i of the center is the signed minor that drops column i,
    # divided by 2a.
    center = [
        (-1) ** (col + 1) * np.linalg.det(np.delete(mat, col, 1)) / two_a
        for col in range(1, len(pts))
    ]

    offset = np.subtract(center, pts[0])
    radius = math.sqrt(np.dot(offset, offset))
    return tuple(center), radius


def orientation(face, origin):
    """Compute the orientation of ``face`` as seen from the point ``origin``.

    Parameters
    ----------
    face : array-like, of shape (N-dim, N-dim)
        The points spanning the hyperplane. Their order matters: swapping two
        points flips the result.
    origin : array-like, point of shape (N-dim)
        The point the orientation is measured from.

    Returns
    -------
    The sign of ``det(face - origin)``. It is 0 when the log of the absolute
    determinant is below -50 (origin treated as lying in the face's
    hyperplane), and -1 or 1 otherwise. Points on the same side of the face
    get equal orientations; points on opposite sides get negated ones.
    """
    relative = np.array(face) - origin
    sign, log_abs_det = np.linalg.slogdet(relative)
    # A vanishing determinant is treated as "coplanar" rather than trusting
    # the sign of round-off noise.
    return 0 if log_abs_det < -50 else sign


def is_iterable_and_sized(obj):
    """Return True when ``obj`` is both an ``Iterable`` and a ``Sized``."""
    return all(isinstance(obj, base) for base in (Iterable, Sized))


class Triangulation:
    """A triangulation object.

    Parameters
    ----------
    coords : 2d array-like of floats
        Coordinates of vertices.

    Attributes
    ----------
    vertices : list of float tuples
        Coordinates of the triangulation vertices.
    simplices : set of integer tuples
        List with indices of vertices forming individual simplices
    vertex_to_simplices : list of sets
        Set of simplices connected to a vertex, the index of the vertex is the
        index of the list.
    hull : set of int
        Exterior vertices

    Raises
    ------
    TypeError
        if ``coords`` is not a sized iterable of sized iterables.
    ValueError
        if the list of coordinates is incorrect or the points do not form one
        or more simplices of non-zero volume.
    """

    def __init__(self, coords):
        if not is_iterable_and_sized(coords):
            raise TypeError("Please provide a 2-dimensional list of points")
        coords = list(coords)
        if not all(is_iterable_and_sized(coord) for coord in coords):
            raise TypeError("Please provide a 2-dimensional list of points")
        if len(coords) == 0:
            # Raise now, otherwise ``coords[0]`` below would raise a less
            # informative IndexError.
            raise ValueError("Please provide at least one simplex")

        dim = len(coords[0])
        if any(len(coord) != dim for coord in coords):
            raise ValueError("Coordinates dimension mismatch")

        if dim < 2:
            # Also catches 0-dimensional points, which would otherwise fail
            # further down with an obscure NumPy error.
            raise ValueError("Triangulation class only supports dim >= 2")

        if len(coords) < dim + 1:
            raise ValueError("Please provide at least one simplex")

        coords = list(map(tuple, coords))
        vectors = np.subtract(coords[1:], coords[0])
        if np.linalg.matrix_rank(vectors) < dim:
            raise ValueError("Initial simplex has zero volumes "
                             "(the points are linearly dependent)")

        self.vertices = list(coords)
        self.simplices = set()
        # initialise empty set for each vertex
        self.vertex_to_simplices = [set() for _ in coords]

        # Seed our own data structures from a scipy Delaunay triangulation;
        # the scipy object is discarded and later points are inserted by
        # this class's own algorithm.
        initial_tri = scipy.spatial.Delaunay(coords)
        for simplex in initial_tri.simplices:
            self.add_simplex(simplex)

    def delete_simplex(self, simplex):
        """Remove ``simplex`` from ``simplices`` and from its vertices' sets."""
        simplex = tuple(sorted(simplex))
        self.simplices.remove(simplex)
        for vertex in simplex:
            self.vertex_to_simplices[vertex].remove(simplex)

    def add_simplex(self, simplex):
        """Register ``simplex`` (stored as a sorted tuple of vertex indices)."""
        simplex = tuple(sorted(simplex))
        self.simplices.add(simplex)
        for vertex in simplex:
            self.vertex_to_simplices[vertex].add(simplex)

    def get_vertices(self, indices):
        """Return the coordinates of the vertices with the given indices."""
        return [self.vertices[i] for i in indices]

    def get_reduced_simplex(self, point, simplex, eps=1e-8) -> list:
        """Check whether vertex lies within a simplex.

        ``simplex`` may also be a lower-dimensional face, in which case one
        full simplex containing that face is used.

        Returns
        -------
        vertices : list of ints
            Indices of vertices of the simplex to which the vertex belongs.
            An empty list indicates that the vertex is outside the simplex.
        """
        # XXX: in the end we want to lose this method
        if len(simplex) != self.dim + 1:
            # We are checking whether point belongs to a face.
            simplex = self.containing(simplex).pop()
        x0 = np.array(self.vertices[simplex[0]])
        vectors = np.array(self.get_vertices(simplex[1:])) - x0
        alpha = np.linalg.solve(vectors.T, point - x0)
        if any(alpha < -eps) or sum(alpha) > 1 + eps:
            return []

        # Keep only the vertices with a non-negligible barycentric weight;
        # the weight of vertex 0 is ``1 - sum(alpha)``.
        result = [i for i, a in enumerate(alpha, 1) if a > eps]
        if sum(alpha) < 1 - eps:
            result.insert(0, 0)

        return [simplex[i] for i in result]

    def point_in_simplex(self, point, simplex, eps=1e-8):
        """Check whether ``point`` lies in the simplex given by vertex indices."""
        vertices = self.get_vertices(simplex)
        return point_in_simplex(point, vertices, eps)

    def locate_point(self, point):
        """Find to which simplex the point belongs.

        Return indices of the simplex containing the point.
        Empty tuple means the point is outside the triangulation.
        This is a linear scan over all simplices.
        """
        for simplex in self.simplices:
            if self.point_in_simplex(point, simplex):
                return simplex
        return ()

    @property
    def dim(self):
        """Dimension of the space, taken from the first vertex."""
        return len(self.vertices[0])

    def faces(self, dim=None, simplices=None, vertices=None):
        """Iterator over faces of a simplex or vertex sequence.

        Parameters
        ----------
        dim : int, optional
            Number of vertices per face; defaults to ``self.dim``.
        simplices : iterable of tuples, optional
            Simplices to take faces from; defaults to all simplices.
        vertices : iterable of ints, optional
            If given, only faces made entirely of these vertices (taken from
            simplices touching them) are yielded. Mutually exclusive with
            ``simplices``.
        """
        if dim is None:
            dim = self.dim

        if simplices is not None and vertices is not None:
            raise ValueError("Only one of simplices and vertices is allowed.")
        if vertices is not None:
            vertices = set(vertices)
            simplices = chain(*(self.vertex_to_simplices[i] for i in vertices))
            simplices = set(simplices)
        elif simplices is None:
            simplices = self.simplices

        faces = (face for tri in simplices
                 for face in combinations(tri, dim))

        if vertices is not None:
            return (face for face in faces if all(i in vertices for i in face))
        else:
            return faces

    def containing(self, face):
        """Simplices containing a face."""
        return set.intersection(*(self.vertex_to_simplices[i] for i in face))

    def _extend_hull(self, new_vertex, eps=1e-8):
        """Append ``new_vertex`` and connect it to every hull face it sees.

        Returns the set of simplices that were added. If no hull face is
        visible from ``new_vertex`` the vertex is removed again and a
        ValueError is raised. Expects the caller to have already appended an
        empty set to ``vertex_to_simplices`` for the new vertex.
        """
        # count multiplicities in order to get all hull faces
        multiplicities = Counter(self.faces())
        hull_faces = [face for face, count in multiplicities.items() if count == 1]

        # compute the center of the convex hull, this center lies in the hull
        # we do not really need the center, we only need a point that is
        # guaranteed to lie strictly within the hull
        hull_points = self.get_vertices(self.hull)
        pt_center = np.average(hull_points, axis=0)

        pt_index = len(self.vertices)
        self.vertices.append(new_vertex)

        new_simplices = set()
        for face in hull_faces:
            # do orientation check, if orientation is the same, it lies on
            # the same side of the face, otherwise, it lies on the other
            # side of the face
            pts_face = tuple(self.get_vertices(face))
            orientation_inside = orientation(pts_face, pt_center)
            orientation_new_point = orientation(pts_face, new_vertex)
            if orientation_inside == -orientation_new_point:
                # if the orientation of the new vertex is zero or directed
                # towards the center, do not add the simplex
                self.add_simplex((*face, pt_index))
                new_simplices.add((*face, pt_index))

        if not new_simplices:
            # We tried to add an internal point, revert and raise.
            for tri in self.vertex_to_simplices[pt_index]:
                self.simplices.remove(tri)
            del self.vertex_to_simplices[pt_index]
            del self.vertices[pt_index]
            raise ValueError("Candidate vertex is inside the hull.")

        return new_simplices

    def circumscribed_circle(self, simplex, transform):
        """Compute the center and radius of the circumscribed circle of a simplex.

        Parameters
        ----------
        simplex : tuple of ints
            the simplex to investigate
        transform : N*N matrix of floats
            Matrix the vertex coordinates are multiplied with (``np.dot``)
            before the circumsphere is computed.

        Returns
        -------
        tuple (center point, radius)
            The center and radius of the circumscribed circle, expressed in
            the transformed coordinates.
        """
        pts = np.dot(self.get_vertices(simplex), transform)
        return circumsphere(pts)

    def point_in_cicumcircle(self, pt_index, simplex, transform):
        """Check whether vertex ``pt_index`` lies inside the circumsphere.

        Both the vertex and ``simplex`` are transformed with ``transform``
        first; the radius is inflated by a relative tolerance of 1e-8.
        """
        eps = 1e-8

        center, radius = self.circumscribed_circle(simplex, transform)
        pt = np.dot(self.get_vertices([pt_index]), transform)[0]

        return np.linalg.norm(center - pt) < (radius * (1 + eps))

    @property
    def default_transform(self):
        """Identity matrix of size ``dim``."""
        return np.eye(self.dim)

    def bowyer_watson(self, pt_index, containing_simplex=None, transform=None):
        """Modified Bowyer-Watson point adding algorithm.

        Create a hole in the triangulation around the new point,
        then retriangulate this hole.

        Parameters
        ----------
        pt_index : int
            The index of the (already appended) point to insert.
        containing_simplex : tuple of ints, optional
            Full simplex to start the search from. If omitted, the search
            starts from the simplices already attached to ``pt_index``.
        transform : N*N matrix of floats, optional
            Matrix applied before the circumsphere test; defaults to the
            identity.

        Returns
        -------
        deleted_simplices : set of tuples
            Simplices that have been deleted
        new_simplices : set of tuples
            Simplices that have been added
        """
        queue = set()
        done_simplices = set()

        transform = self.default_transform if transform is None else transform

        if containing_simplex is None:
            queue.update(self.vertex_to_simplices[pt_index])
        else:
            queue.add(containing_simplex)

        done_points = {pt_index}

        bad_triangles = set()

        while queue:
            simplex = queue.pop()
            done_simplices.add(simplex)

            if self.point_in_cicumcircle(pt_index, simplex, transform):
                self.delete_simplex(simplex)
                todo_points = set(simplex) - done_points
                done_points.update(simplex)

                # Only neighbours of a deleted simplex can also be "bad".
                if todo_points:
                    neighbours = set.union(*[self.vertex_to_simplices[p]
                                             for p in todo_points])
                    queue.update(neighbours - done_simplices)

                bad_triangles.add(simplex)

        faces = list(self.faces(simplices=bad_triangles))

        # Faces shared by two deleted simplices are interior to the hole;
        # the ones seen once form its boundary.
        multiplicities = Counter(faces)
        hole_faces = [face for face in faces if multiplicities[face] < 2]

        for face in hole_faces:
            if pt_index not in face:
                # Skip (near-)degenerate simplices.
                if self.volume((*face, pt_index)) < 1e-8:
                    continue
                self.add_simplex((*face, pt_index))

        new_triangles = self.vertex_to_simplices[pt_index]
        return bad_triangles - new_triangles, new_triangles - bad_triangles

    def add_point(self, point, simplex=None, transform=None):
        """Add a new vertex and create simplices as appropriate.

        Parameters
        ----------
        point : float vector
            Coordinates of the point to be added.
        transform : N*N matrix of floats
            Multiplication matrix to apply to the point (and neighbouring
            simplices) when running the Bowyer Watson method.
        simplex : tuple of ints, optional
            Simplex (or a face of one) containing the point. Empty tuple
            indicates points outside the hull. If not provided, the
            algorithm costs O(N), so this should be used whenever possible.

        Returns
        -------
        (deleted_simplices, added_simplices) : tuple of sets

        Raises
        ------
        ValueError
            If the point lies outside the given simplex, is already a vertex,
            or lies inside the hull while no containing simplex was found.
        """
        point = tuple(point)
        if simplex is None:
            simplex = self.locate_point(point)

        actual_simplex = simplex
        self.vertex_to_simplices.append(set())

        if not simplex:
            temporary_simplices = self._extend_hull(point)

            pt_index = len(self.vertices) - 1
            deleted_simplices, added_simplices = \
                self.bowyer_watson(pt_index, transform=transform)

            deleted = deleted_simplices - temporary_simplices
            added = added_simplices | (temporary_simplices - deleted_simplices)
            return deleted, added

        reduced_simplex = self.get_reduced_simplex(point, simplex)
        if not reduced_simplex:
            self.vertex_to_simplices.pop()  # revert adding vertex
            raise ValueError('Point lies outside of the specified simplex.')
        simplex = reduced_simplex

        if len(simplex) == 1:
            self.vertex_to_simplices.pop()  # revert adding vertex
            raise ValueError("Point already in triangulation.")

        if len(actual_simplex) != self.dim + 1:
            # A lower-dimensional face was passed. Bowyer-Watson must start
            # from a real simplex, which it can circumscribe and delete, so
            # use one that contains the face.
            actual_simplex = next(iter(self.containing(actual_simplex)))

        pt_index = len(self.vertices)
        self.vertices.append(point)
        return self.bowyer_watson(pt_index, actual_simplex, transform)

    def volume(self, simplex):
        """Return the volume of ``simplex``: |det(edge vectors)| / dim!."""
        # ``math.factorial``: the ``np.math`` alias is removed in NumPy 2.0.
        prefactor = math.factorial(self.dim)
        vertices = np.array(self.get_vertices(simplex))
        vectors = vertices[1:] - vertices[0]
        return abs(np.linalg.det(vectors)) / prefactor

    def volumes(self):
        """Return the volume of every simplex in the triangulation."""
        return [self.volume(sim) for sim in self.simplices]

    def reference_invariant(self):
        """vertex_to_simplices and simplices are compatible."""
        for vertex in range(len(self.vertices)):
            if any(vertex not in tri
                   for tri in self.vertex_to_simplices[vertex]):
                return False
        for simplex in self.simplices:
            if any(simplex not in self.vertex_to_simplices[pt]
                   for pt in simplex):
                return False
        return True

    def vertex_invariant(self, vertex):
        """Simplices originating from a vertex don't overlap."""
        raise NotImplementedError

    @property
    def hull(self):
        """Compute hull from triangulation.

        A vertex is on the hull when it belongs to a face that appears in
        exactly one simplex.

        Returns
        -------
        hull : set of int
            Vertices in the hull.

        Raises
        ------
        RuntimeError
            If some face is shared by more than two simplices.
        """
        counts = Counter(self.faces())
        if any(i > 2 for i in counts.values()):
            raise RuntimeError("Broken triangulation, a (N-1)-dimensional"
                               " face appears in more than 2 simplices.")

        hull = set(point for face, count in counts.items() if count == 1
                   for point in face)
        return hull

    def convex_invariant(self, vertex):
        """Hull is convex."""
        raise NotImplementedError