... | ... |
@@ -10,7 +10,7 @@ by the associated classical bitstring. |
10 | 10 |
|
11 | 11 |
import numpy as np |
12 | 12 |
|
13 |
-__all__ = ["from_classical"] # type: ignore |
|
13 |
+__all__ = ["from_classical", "is_normalized", "num_qubits", "zero"] # type: ignore |
|
14 | 14 |
|
15 | 15 |
|
16 | 16 |
def from_classical(bitstring): |
... | ... |
@@ -28,7 +28,6 @@ def from_classical(bitstring): |
28 | 28 |
The state vector in the computational basis. |
29 | 29 |
Has :math:`2^n` components. |
30 | 30 |
""" |
31 |
- |
|
32 | 31 |
bitstring = "".join(map(str, bitstring)) |
33 | 32 |
n_qubits = len(bitstring) |
34 | 33 |
try: |
... | ... |
@@ -39,3 +38,41 @@ def from_classical(bitstring): |
39 | 38 |
state = np.zeros(1 << n_qubits, dtype=complex) |
40 | 39 |
state[index] = 1 |
41 | 40 |
return state |
41 |
+ |
|
42 |
+ |
|
43 |
+def zero(n: int): |
|
44 |
+ """Return the zero state on 'n' qubits.""" |
|
45 |
+ state = np.zeros(1 << n, dtype=complex) |
|
46 |
+ state[0] = 1 |
|
47 |
+ return state |
|
48 |
+ |
|
49 |
+ |
|
50 |
+def num_qubits(state): |
|
51 |
+ """Return the number of qubits in the state. |
|
52 |
+ |
|
53 |
+ Raises ValueError if 'state' does not have a shape that is |
|
54 |
+ an integer power of 2. |
|
55 |
+ """ |
|
56 |
+ _check_valid_state(state) |
|
57 |
+ return state.shape[0].bit_length() - 1 |
|
58 |
+ |
|
59 |
+ |
|
60 |
+def is_normalized(state: np.ndarray) -> bool: |
|
61 |
+ """Return True if and only if 'state' is normalized.""" |
|
62 |
+ return np.allclose(np.linalg.norm(state), 1) |
|
63 |
+ |
|
64 |
+ |
|
65 |
+def _check_valid_state(state): |
|
66 |
+ if not ( |
|
67 |
+ # is an array |
|
68 |
+ isinstance(state, np.ndarray) |
|
69 |
+ # is complex |
|
70 |
+ and np.issubdtype(state.dtype, np.complex128) |
|
71 |
+ # is square |
|
72 |
+ and len(state.shape) == 1 |
|
73 |
+ # has size 2**n, n > 1 |
|
74 |
+ and np.log2(state.shape[0]).is_integer() |
|
75 |
+ and state.shape[0].bit_length() > 1 |
|
76 |
+ and is_normalized(state) |
|
77 |
+ ): |
|
78 |
+ raise ValueError("State is not valid") |
... | ... |
@@ -2,6 +2,7 @@ import string |
2 | 2 |
from functools import partial |
3 | 3 |
|
4 | 4 |
from hypothesis import given |
5 |
+from hypothesis.extra import numpy as hnp |
|
5 | 6 |
import hypothesis.strategies as st |
6 | 7 |
import numpy as np |
7 | 8 |
import pytest |
... | ... |
@@ -11,7 +12,14 @@ import qsim.state |
11 | 12 |
|
12 | 13 |
MAX_QUBITS = 5 |
13 | 14 |
|
15 |
+ |
|
16 |
+def is_unit(c: complex): |
|
17 |
+ return np.isclose(abs(c), 1) |
|
18 |
+ |
|
19 |
+ |
|
14 | 20 |
n_qubits = st.integers(1, MAX_QUBITS) |
21 |
+state_dimensions = n_qubits.map(lambda n: 2 ** n) |
|
22 |
+state_shapes = state_dimensions.map(lambda x: (x,)) |
|
15 | 23 |
classical_bitstrings = n_qubits.flatmap( |
16 | 24 |
lambda n: st.integers(0, 2 ** n).map(partial("{:0{n}b}".format, n=n)) |
17 | 25 |
) |
... | ... |
@@ -19,6 +27,37 @@ invalid_bitstrings = n_qubits.flatmap( |
19 | 27 |
lambda n: st.text(alphabet=string.digits, min_size=n, max_size=n) |
20 | 28 |
).filter(lambda s: any(b not in "01" for b in s)) |
21 | 29 |
|
30 |
+invalid_state_dimensions = st.integers(0, 2 ** MAX_QUBITS).filter( |
|
31 |
+ lambda n: not np.log2(n).is_integer() |
|
32 |
+) |
|
33 |
+ |
|
34 |
+ |
|
35 |
+valid_complex = st.complex_numbers( |
|
36 |
+ max_magnitude=1e10, allow_infinity=False, allow_nan=False |
|
37 |
+) |
|
38 |
+nonunit_complex = valid_complex.filter(lambda c: not is_unit(c)) |
|
39 |
+ |
|
40 |
+ |
|
41 |
+def ket(n_qubits): |
|
42 |
+ return normalized_array(1 << n_qubits) |
|
43 |
+ |
|
44 |
+ |
|
45 |
+def normalized_array(size): |
|
46 |
+ return ( |
|
47 |
+ hnp.arrays(complex, (size,), elements=valid_complex) |
|
48 |
+ .filter(lambda v: np.linalg.norm(v) > 0) # vectors must be normalizable |
|
49 |
+ .map(lambda v: v / np.linalg.norm(v)) |
|
50 |
+ ) |
|
51 |
+ |
|
52 |
+ |
|
53 |
+valid_states = n_qubits.flatmap(ket) |
|
54 |
+invalid_shape_states = invalid_state_dimensions.flatmap(normalized_array) |
|
55 |
+invalid_norm_states = valid_states.flatmap( |
|
56 |
+ lambda x: nonunit_complex.map(lambda c: c * x) |
|
57 |
+) |
|
58 |
+zero_states = state_shapes.map(lambda s: np.zeros(s, complex)) |
|
59 |
+invalid_states = st.one_of(invalid_shape_states, invalid_norm_states, zero_states) |
|
60 |
+ |
|
22 | 61 |
|
23 | 62 |
@given(classical_bitstrings) |
24 | 63 |
def test_from_classical(bitstring): |
... | ... |
@@ -47,3 +86,31 @@ def test_from_classical_works_on_integer_lists(bitstring): |
47 | 86 |
def test_from_classical_raises_on_bad_input(bitstring): |
48 | 87 |
with pytest.raises(ValueError): |
49 | 88 |
qsim.state.from_classical(bitstring) |
89 |
+ |
|
90 |
+ |
|
91 |
+@given(classical_bitstrings) |
|
92 |
+def test_num_qubits(s): |
|
93 |
+ state = qsim.state.from_classical(s) |
|
94 |
+ assert qsim.state.num_qubits(state) == len(s) |
|
95 |
+ |
|
96 |
+ |
|
97 |
+@given(invalid_states) |
|
98 |
+def test_num_qubits_raises_exception(state): |
|
99 |
+ with pytest.raises(ValueError): |
|
100 |
+ qsim.state.num_qubits(state) |
|
101 |
+ |
|
102 |
+ |
|
103 |
+@given(invalid_norm_states) |
|
104 |
+def test_is_not_normalized(state): |
|
105 |
+ assert not qsim.state.is_normalized(state) |
|
106 |
+ |
|
107 |
+ |
|
108 |
+@given(valid_states) |
|
109 |
+def test_is_normalized(state): |
|
110 |
+ assert qsim.state.is_normalized(state) |
|
111 |
+ |
|
112 |
+ |
|
113 |
+@given(n_qubits) |
|
114 |
+def test_zero(n): |
|
115 |
+ z = qsim.state.zero(n) |
|
116 |
+ assert qsim.state.num_qubits(z) == n |