Browse code

Refactor common testing components

Joseph Weston authored on 01/11/2023 04:35:17
Showing 3 changed files
1 1
new file mode 100644
... ...
@@ -0,0 +1,111 @@
1
+from functools import partial
2
+import string
3
+
4
+import hypothesis.strategies as st
5
+from hypothesis.extra import numpy as hnp
6
+import numpy as np
7
+
8
+# Numbers
9
+
10
+
11
+def is_unit(c: complex) -> bool:
12
+    return np.isclose(abs(c), 1)
13
+
14
+
15
+MAX_QUBITS = 6
16
+
17
+phases = st.floats(
18
+    min_value=0, max_value=2 * np.pi, allow_nan=False, allow_infinity=False
19
+)
20
+valid_complex = st.complex_numbers(
21
+    max_magnitude=1e10, allow_infinity=False, allow_nan=False
22
+)
23
+unit_complex = phases.map(lambda p: np.exp(1j * p))
24
+nonunit_complex = valid_complex.filter(lambda c: not is_unit(c))
25
+
26
+n_qubits = st.shared(st.integers(min_value=1, max_value=MAX_QUBITS))
27
+
28
+rng = st.integers(0, 2**32 - 1).map(lambda n: np.random.default_rng(n))
29
+
30
+
31
+# Gates
32
+
33
+
34
+# Choose which qubits from 'n_qubits' to operate on with a gate that
35
+# operates on 'gate_size' qubits
36
+def select_n_qubits(gate_size: int):
37
+    """Return a function that, given 'n_qubits', select 'gate_size' qubits from 0-'n_qubits'."""
38
+
39
+    def _strat(n_qubits):
40
+        assert n_qubits >= gate_size
41
+        possible_qubits = st.integers(0, n_qubits - 1)
42
+        return st.lists(
43
+            possible_qubits, min_size=gate_size, max_size=gate_size, unique=True
44
+        ).map(tuple)
45
+
46
+    return _strat
47
+
48
+
49
+def unitary(n_qubits):
50
+    """Return a strategy for generating unitary matrices of size 2**n_qubits."""
51
+    size = 1 << n_qubits
52
+    return (
53
+        hnp.arrays(complex, (size, size), elements=valid_complex)
54
+        .map(lambda a: np.linalg.qr(a)[0])
55
+        .filter(lambda u: np.all(np.isfinite(u)))
56
+    )
57
+
58
+
59
+single_qubit_gates = unitary(1)
60
+two_qubit_gates = unitary(2)
61
+n_qubit_gates = n_qubits.flatmap(unitary)
62
+
63
+
64
+# States
65
+
66
+
67
+def ket(n_qubits):
68
+    return normalized_array(1 << n_qubits)
69
+
70
+
71
+def is_normalizable_to_full_precision(v):
72
+    # If the squared norm is > 0 but sub-normal
73
+    # (https://en.wikipedia.org/wiki/Subnormal_number)
74
+    # then normalizing the vector with v/|v| will not yield a vector
75
+    # that is normalized to machine precision.
76
+    return np.linalg.norm(v) ** 2 > np.finfo(float).smallest_normal
77
+
78
+
79
+def normalized_array(size):
80
+    return (
81
+        hnp.arrays(complex, (size,), elements=valid_complex)
82
+        .filter(is_normalizable_to_full_precision)  # vectors must be normalizable
83
+        .map(lambda v: v / np.linalg.norm(v))
84
+    )
85
+
86
+
87
+state_dimensions = n_qubits.map(lambda n: 2**n)
88
+state_shapes = state_dimensions.map(lambda x: (x,))
89
+
90
+valid_states = n_qubits.flatmap(ket)
91
+zero_states = state_shapes.map(lambda s: np.zeros(s, complex))
92
+
93
+invalid_state_dimensions = st.integers(0, 2**MAX_QUBITS).filter(
94
+    lambda n: not np.log2(n).is_integer()
95
+)
96
+invalid_shape_states = invalid_state_dimensions.flatmap(normalized_array)
97
+invalid_norm_states = valid_states.flatmap(
98
+    lambda x: nonunit_complex.map(lambda c: c * x)
99
+)
100
+
101
+invalid_states = st.one_of(invalid_shape_states, invalid_norm_states, zero_states)
102
+
103
+
104
+# Bitstrings
105
+
106
+classical_bitstrings = n_qubits.flatmap(
107
+    lambda n: st.integers(0, 2**n).map(partial("{:0{n}b}".format, n=n))
108
+)
109
+invalid_bitstrings = n_qubits.flatmap(
110
+    lambda n: st.text(alphabet=string.digits, min_size=n, max_size=n)
111
+).filter(lambda s: any(b not in "01" for b in s))
... ...
@@ -2,74 +2,30 @@ from functools import reduce
2 2
 
3 3
 from hypothesis import given
4 4
 import hypothesis.strategies as st
5
-import hypothesis.extra.numpy as hnp
6 5
 import numpy as np
7 6
 import pytest
8 7
 
9 8
 import qsim.gate
10 9
 
11
-
12
-# -- Strategies for generating values --
13
-
14
-
15
-n_qubits = st.shared(st.integers(min_value=1, max_value=6))
16
-
17
-
18
-# Choose which qubits from 'n_qubits' to operate on with a gate that
19
-# operates on 'gate_size' qubits
20
-def select_n_qubits(gate_size):
21
-    def _strat(n_qubits):
22
-        assert n_qubits >= gate_size
23
-        possible_qubits = st.integers(0, n_qubits - 1)
24
-        return st.lists(
25
-            possible_qubits, min_size=gate_size, max_size=gate_size, unique=True
26
-        ).map(tuple)
27
-
28
-    return _strat
29
-
30
-
31
-valid_complex = st.complex_numbers(
32
-    max_magnitude=1e10, allow_infinity=False, allow_nan=False
33
-)
34
-phases = st.floats(
35
-    min_value=0, max_value=2 * np.pi, allow_nan=False, allow_infinity=False
10
+from .common import (
11
+    ket,
12
+    single_qubit_gates,
13
+    n_qubit_gates,
14
+    valid_states,
15
+    n_qubits,
16
+    phases,
17
+    select_n_qubits,
36 18
 )
37 19
 
38 20
 
39
-def unitary(n_qubits):
40
-    size = 1 << n_qubits
41
-    return (
42
-        hnp.arrays(complex, (size, size), elements=valid_complex)
43
-        .map(lambda a: np.linalg.qr(a)[0])
44
-        .filter(lambda u: np.all(np.isfinite(u)))
45
-    )
46
-
47
-
48
-def ket(n_qubits):
49
-    size = 1 << n_qubits
50
-    return (
51
-        hnp.arrays(complex, (size,), elements=valid_complex)
52
-        .filter(lambda v: np.linalg.norm(v) > 0)  # vectors must be normalizable
53
-        .map(lambda v: v / np.linalg.norm(v))
54
-    )
55
-
56
-
57
-single_qubit_gates = unitary(1)
58
-two_qubit_gates = unitary(2)
59
-n_qubit_gates = n_qubits.flatmap(unitary)
60
-
61
-# Projectors on the single qubit computational basis
62
-project_zero = np.array([[1, 0], [0, 0]])
63
-project_one = np.array([[0, 0], [0, 1]])
64
-
65
-
66 21
 def product_gate(single_qubit_gates):
67 22
     # We reverse so that 'single_qubit_gates' can be indexed by the qubit
68 23
     # identifier; e.g. qubit #0 is actually the least-significant qubit
69 24
     return reduce(np.kron, reversed(single_qubit_gates))
70 25
 
71 26
 
72
-# -- Tests --
27
+project_zero = np.array([[1, 0], [0, 0]])
28
+project_one = np.array([[0, 0], [0, 1]])
73 29
 
74 30
 
75 31
 @given(n_qubits, n_qubit_gates)
... ...
@@ -145,7 +101,7 @@ def test_swap():
145 101
     assert np.all(qsim.gate.swap @ qsim.gate.swap == np.identity(4))
146 102
 
147 103
 
148
-@given(single_qubit_gates, n_qubits.flatmap(ket), n_qubits.flatmap(select_n_qubits(1)))
104
+@given(single_qubit_gates, valid_states, n_qubits.flatmap(select_n_qubits(1)))
149 105
 def test_applying_single_gates(gate, state, selected):
150 106
     (qubit,) = selected
151 107
     n_qubits = state.shape[0].bit_length() - 1
... ...
@@ -167,11 +123,13 @@ def test_applying_single_gates(gate, state, selected):
167 123
 def test_applying_controlled_single_qubit_gates(gate, state, selected):
168 124
     control, qubit = selected
169 125
     n_qubits = state.shape[0].bit_length() - 1
170
-    # When control qubit is |0⟩ the controlled gate acts like the identity on the other qubit
126
+    # When control qubit is |0⟩ the controlled gate acts
127
+    # like the identity on the other qubit
171 128
     parts_zero = [np.identity(2)] * n_qubits
172 129
     parts_zero[control] = project_zero
173 130
     parts_zero[qubit] = np.identity(2)
174
-    # When control qubit is |1⟩ the controlled gate acts like the original gate on the other qubit
131
+    # When control qubit is |1⟩ the controlled gate acts
132
+    # like the original gate on the other qubit
175 133
     parts_one = [np.identity(2)] * n_qubits
176 134
     parts_one[control] = project_one
177 135
     parts_one[qubit] = gate
... ...
@@ -1,67 +1,21 @@
1
-import string
2
-from functools import partial
3
-
4 1
 from hypothesis import given
5
-from hypothesis.extra import numpy as hnp
6
-import hypothesis.strategies as st
7 2
 import numpy as np
8 3
 import pytest
9 4
 
10 5
 import qsim.state
11 6
 
12
-
13
-MAX_QUBITS = 5
14
-
15
-
16
-def is_unit(c: complex):
17
-    return np.isclose(abs(c), 1)
18
-
19
-
20
-n_qubits = st.integers(1, MAX_QUBITS)
21
-state_dimensions = n_qubits.map(lambda n: 2 ** n)
22
-state_shapes = state_dimensions.map(lambda x: (x,))
23
-classical_bitstrings = n_qubits.flatmap(
24
-    lambda n: st.integers(0, 2 ** n).map(partial("{:0{n}b}".format, n=n))
7
+from .common import (
8
+    classical_bitstrings,
9
+    invalid_bitstrings,
10
+    invalid_states,
11
+    invalid_norm_states,
12
+    n_qubits,
13
+    valid_states,
25 14
 )
26
-invalid_bitstrings = n_qubits.flatmap(
27
-    lambda n: st.text(alphabet=string.digits, min_size=n, max_size=n)
28
-).filter(lambda s: any(b not in "01" for b in s))
29
-
30
-invalid_state_dimensions = st.integers(0, 2 ** MAX_QUBITS).filter(
31
-    lambda n: not np.log2(n).is_integer()
32
-)
33
-
34
-
35
-valid_complex = st.complex_numbers(
36
-    max_magnitude=1e10, allow_infinity=False, allow_nan=False
37
-)
38
-nonunit_complex = valid_complex.filter(lambda c: not is_unit(c))
39
-
40
-
41
-def ket(n_qubits):
42
-    return normalized_array(1 << n_qubits)
43
-
44
-
45
-def normalized_array(size):
46
-    return (
47
-        hnp.arrays(complex, (size,), elements=valid_complex)
48
-        .filter(lambda v: np.linalg.norm(v) > 0)  # vectors must be normalizable
49
-        .map(lambda v: v / np.linalg.norm(v))
50
-    )
51
-
52
-
53
-valid_states = n_qubits.flatmap(ket)
54
-invalid_shape_states = invalid_state_dimensions.flatmap(normalized_array)
55
-invalid_norm_states = valid_states.flatmap(
56
-    lambda x: nonunit_complex.map(lambda c: c * x)
57
-)
58
-zero_states = state_shapes.map(lambda s: np.zeros(s, complex))
59
-invalid_states = st.one_of(invalid_shape_states, invalid_norm_states, zero_states)
60 15
 
61 16
 
62 17
 @given(classical_bitstrings)
63 18
 def test_from_classical(bitstring):
64
-
65 19
     state = qsim.state.from_classical(bitstring)
66 20
 
67 21
     i = int(bitstring, base=2)
... ...
@@ -74,7 +28,6 @@ def test_from_classical(bitstring):
74 28
 
75 29
 @given(classical_bitstrings)
76 30
 def test_from_classical_works_on_integer_lists(bitstring):
77
-
78 31
     int_list = [int(b) for b in bitstring]
79 32
 
80 33
     assert np.all(
... ...
@@ -88,6 +41,11 @@ def test_from_classical_raises_on_bad_input(bitstring):
88 41
         qsim.state.from_classical(bitstring)
89 42
 
90 43
 
44
+@given(n_qubits)
45
+def test_zero(n):
46
+    assert np.array_equal(qsim.state.zero(n), qsim.state.from_classical("0" * n))
47
+
48
+
91 49
 @given(classical_bitstrings)
92 50
 def test_num_qubits(s):
93 51
     state = qsim.state.from_classical(s)