Browse code

Add more constructors and tests to the "state" module

Joseph Weston authored on 15/09/2021 17:23:38
Showing 2 changed files
... ...
@@ -10,7 +10,7 @@ by the associated classical bitstring.
10 10
 
11 11
 import numpy as np
12 12
 
13
-__all__ = ["from_classical"]  # type: ignore
13
+__all__ = ["from_classical", "is_normalized", "num_qubits", "zero"]  # type: ignore
14 14
 
15 15
 
16 16
 def from_classical(bitstring):
... ...
@@ -28,7 +28,6 @@ def from_classical(bitstring):
28 28
         The state vector in the computational basis.
29 29
         Has :math:`2^n` components.
30 30
     """
31
-
32 31
     bitstring = "".join(map(str, bitstring))
33 32
     n_qubits = len(bitstring)
34 33
     try:
... ...
@@ -39,3 +38,41 @@ def from_classical(bitstring):
39 38
     state = np.zeros(1 << n_qubits, dtype=complex)
40 39
     state[index] = 1
41 40
     return state
41
+
42
+
43
+def zero(n: int):
44
+    """Return the zero state on 'n' qubits."""
45
+    state = np.zeros(1 << n, dtype=complex)
46
+    state[0] = 1
47
+    return state
48
+
49
+
50
+def num_qubits(state):
51
+    """Return the number of qubits in the state.
52
+
53
+    Raises ValueError if 'state' does not have a shape that is
54
+    an integer power of 2.
55
+    """
56
+    _check_valid_state(state)
57
+    return state.shape[0].bit_length() - 1
58
+
59
+
60
+def is_normalized(state: np.ndarray) -> bool:
61
+    """Return True if and only if 'state' is normalized."""
62
+    return np.allclose(np.linalg.norm(state), 1)
63
+
64
+
65
+def _check_valid_state(state):
66
+    if not (
67
+        # is an array
68
+        isinstance(state, np.ndarray)
69
+        # is complex
70
+        and np.issubdtype(state.dtype, np.complex128)
71
+        # is square
72
+        and len(state.shape) == 1
73
+        # has size 2**n, n > 1
74
+        and np.log2(state.shape[0]).is_integer()
75
+        and state.shape[0].bit_length() > 1
76
+        and is_normalized(state)
77
+    ):
78
+        raise ValueError("State is not valid")
... ...
@@ -2,6 +2,7 @@ import string
2 2
 from functools import partial
3 3
 
4 4
 from hypothesis import given
5
+from hypothesis.extra import numpy as hnp
5 6
 import hypothesis.strategies as st
6 7
 import numpy as np
7 8
 import pytest
... ...
@@ -11,7 +12,14 @@ import qsim.state
11 12
 
12 13
 MAX_QUBITS = 5
13 14
 
15
+
16
+def is_unit(c: complex):
17
+    return np.isclose(abs(c), 1)
18
+
19
+
14 20
 n_qubits = st.integers(1, MAX_QUBITS)
21
+state_dimensions = n_qubits.map(lambda n: 2 ** n)
22
+state_shapes = state_dimensions.map(lambda x: (x,))
15 23
 classical_bitstrings = n_qubits.flatmap(
16 24
     lambda n: st.integers(0, 2 ** n).map(partial("{:0{n}b}".format, n=n))
17 25
 )
... ...
@@ -19,6 +27,37 @@ invalid_bitstrings = n_qubits.flatmap(
19 27
     lambda n: st.text(alphabet=string.digits, min_size=n, max_size=n)
20 28
 ).filter(lambda s: any(b not in "01" for b in s))
21 29
 
30
+invalid_state_dimensions = st.integers(0, 2 ** MAX_QUBITS).filter(
31
+    lambda n: not np.log2(n).is_integer()
32
+)
33
+
34
+
35
+valid_complex = st.complex_numbers(
36
+    max_magnitude=1e10, allow_infinity=False, allow_nan=False
37
+)
38
+nonunit_complex = valid_complex.filter(lambda c: not is_unit(c))
39
+
40
+
41
+def ket(n_qubits):
42
+    return normalized_array(1 << n_qubits)
43
+
44
+
45
+def normalized_array(size):
46
+    return (
47
+        hnp.arrays(complex, (size,), elements=valid_complex)
48
+        .filter(lambda v: np.linalg.norm(v) > 0)  # vectors must be normalizable
49
+        .map(lambda v: v / np.linalg.norm(v))
50
+    )
51
+
52
+
53
+valid_states = n_qubits.flatmap(ket)
54
+invalid_shape_states = invalid_state_dimensions.flatmap(normalized_array)
55
+invalid_norm_states = valid_states.flatmap(
56
+    lambda x: nonunit_complex.map(lambda c: c * x)
57
+)
58
+zero_states = state_shapes.map(lambda s: np.zeros(s, complex))
59
+invalid_states = st.one_of(invalid_shape_states, invalid_norm_states, zero_states)
60
+
22 61
 
23 62
 @given(classical_bitstrings)
24 63
 def test_from_classical(bitstring):
... ...
@@ -47,3 +86,31 @@ def test_from_classical_works_on_integer_lists(bitstring):
47 86
 def test_from_classical_raises_on_bad_input(bitstring):
48 87
     with pytest.raises(ValueError):
49 88
         qsim.state.from_classical(bitstring)
89
+
90
+
91
+@given(classical_bitstrings)
92
+def test_num_qubits(s):
93
+    state = qsim.state.from_classical(s)
94
+    assert qsim.state.num_qubits(state) == len(s)
95
+
96
+
97
+@given(invalid_states)
98
+def test_num_qubits_raises_exception(state):
99
+    with pytest.raises(ValueError):
100
+        qsim.state.num_qubits(state)
101
+
102
+
103
+@given(invalid_norm_states)
104
+def test_is_not_normalized(state):
105
+    assert not qsim.state.is_normalized(state)
106
+
107
+
108
+@given(valid_states)
109
+def test_is_normalized(state):
110
+    assert qsim.state.is_normalized(state)
111
+
112
+
113
+@given(n_qubits)
114
+def test_zero(n):
115
+    z = qsim.state.zero(n)
116
+    assert qsim.state.num_qubits(z) == n