import string
from functools import partial

from hypothesis import given
from hypothesis.extra import numpy as hnp
import hypothesis.strategies as st
import numpy as np
import pytest

import qsim.state


MAX_QUBITS = 5


def is_unit(c: complex):
    return np.isclose(abs(c), 1)


n_qubits = st.integers(1, MAX_QUBITS)
state_dimensions = n_qubits.map(lambda n: 2 ** n)
state_shapes = state_dimensions.map(lambda x: (x,))
classical_bitstrings = n_qubits.flatmap(
    lambda n: st.integers(0, 2 ** n).map(partial("{:0{n}b}".format, n=n))
)
invalid_bitstrings = n_qubits.flatmap(
    lambda n: st.text(alphabet=string.digits, min_size=n, max_size=n)
).filter(lambda s: any(b not in "01" for b in s))

invalid_state_dimensions = st.integers(0, 2 ** MAX_QUBITS).filter(
    lambda n: not np.log2(n).is_integer()
)


valid_complex = st.complex_numbers(
    max_magnitude=1e10, allow_infinity=False, allow_nan=False
)
nonunit_complex = valid_complex.filter(lambda c: not is_unit(c))


def ket(n_qubits):
    return normalized_array(1 << n_qubits)


def normalized_array(size):
    return (
        hnp.arrays(complex, (size,), elements=valid_complex)
        .filter(lambda v: np.linalg.norm(v) > 0)  # vectors must be normalizable
        .map(lambda v: v / np.linalg.norm(v))
    )


valid_states = n_qubits.flatmap(ket)
invalid_shape_states = invalid_state_dimensions.flatmap(normalized_array)
invalid_norm_states = valid_states.flatmap(
    lambda x: nonunit_complex.map(lambda c: c * x)
)
zero_states = state_shapes.map(lambda s: np.zeros(s, complex))
invalid_states = st.one_of(invalid_shape_states, invalid_norm_states, zero_states)


@given(classical_bitstrings)
def test_from_classical(bitstring):

    state = qsim.state.from_classical(bitstring)

    i = int(bitstring, base=2)

    assert np.issubdtype(state.dtype, np.dtype(complex))
    assert state.shape == (2 ** len(bitstring),)
    assert np.linalg.norm(state) == 1
    assert abs(state[i]) == 1


@given(classical_bitstrings)
def test_from_classical_works_on_integer_lists(bitstring):

    int_list = [int(b) for b in bitstring]

    assert np.all(
        qsim.state.from_classical(bitstring) == qsim.state.from_classical(int_list)
    )


@given(invalid_bitstrings)
def test_from_classical_raises_on_bad_input(bitstring):
    with pytest.raises(ValueError):
        qsim.state.from_classical(bitstring)


@given(classical_bitstrings)
def test_num_qubits(s):
    state = qsim.state.from_classical(s)
    assert qsim.state.num_qubits(state) == len(s)


@given(invalid_states)
def test_num_qubits_raises_exception(state):
    with pytest.raises(ValueError):
        qsim.state.num_qubits(state)


@given(invalid_norm_states)
def test_is_not_normalized(state):
    assert not qsim.state.is_normalized(state)


@given(valid_states)
def test_is_normalized(state):
    assert qsim.state.is_normalized(state)


@given(n_qubits)
def test_zero(n):
    z = qsim.state.zero(n)
    assert qsim.state.num_qubits(z) == n