Browse code

Refactor common testing components

Joseph Weston authored on 01/11/2023 04:35:17
Showing 1 changed files
... ...
@@ -1,67 +1,21 @@
1
-import string
2
-from functools import partial
3
-
4 1
 from hypothesis import given
5
-from hypothesis.extra import numpy as hnp
6
-import hypothesis.strategies as st
7 2
 import numpy as np
8 3
 import pytest
9 4
 
10 5
 import qsim.state
11 6
 
12
-
13
-MAX_QUBITS = 5
14
-
15
-
16
-def is_unit(c: complex):
17
-    return np.isclose(abs(c), 1)
18
-
19
-
20
-n_qubits = st.integers(1, MAX_QUBITS)
21
-state_dimensions = n_qubits.map(lambda n: 2 ** n)
22
-state_shapes = state_dimensions.map(lambda x: (x,))
23
-classical_bitstrings = n_qubits.flatmap(
24
-    lambda n: st.integers(0, 2 ** n).map(partial("{:0{n}b}".format, n=n))
7
+from .common import (
8
+    classical_bitstrings,
9
+    invalid_bitstrings,
10
+    invalid_states,
11
+    invalid_norm_states,
12
+    n_qubits,
13
+    valid_states,
25 14
 )
26
-invalid_bitstrings = n_qubits.flatmap(
27
-    lambda n: st.text(alphabet=string.digits, min_size=n, max_size=n)
28
-).filter(lambda s: any(b not in "01" for b in s))
29
-
30
-invalid_state_dimensions = st.integers(0, 2 ** MAX_QUBITS).filter(
31
-    lambda n: not np.log2(n).is_integer()
32
-)
33
-
34
-
35
-valid_complex = st.complex_numbers(
36
-    max_magnitude=1e10, allow_infinity=False, allow_nan=False
37
-)
38
-nonunit_complex = valid_complex.filter(lambda c: not is_unit(c))
39
-
40
-
41
-def ket(n_qubits):
42
-    return normalized_array(1 << n_qubits)
43
-
44
-
45
-def normalized_array(size):
46
-    return (
47
-        hnp.arrays(complex, (size,), elements=valid_complex)
48
-        .filter(lambda v: np.linalg.norm(v) > 0)  # vectors must be normalizable
49
-        .map(lambda v: v / np.linalg.norm(v))
50
-    )
51
-
52
-
53
-valid_states = n_qubits.flatmap(ket)
54
-invalid_shape_states = invalid_state_dimensions.flatmap(normalized_array)
55
-invalid_norm_states = valid_states.flatmap(
56
-    lambda x: nonunit_complex.map(lambda c: c * x)
57
-)
58
-zero_states = state_shapes.map(lambda s: np.zeros(s, complex))
59
-invalid_states = st.one_of(invalid_shape_states, invalid_norm_states, zero_states)
60 15
 
61 16
 
62 17
 @given(classical_bitstrings)
63 18
 def test_from_classical(bitstring):
64
-
65 19
     state = qsim.state.from_classical(bitstring)
66 20
 
67 21
     i = int(bitstring, base=2)
... ...
@@ -74,7 +28,6 @@ def test_from_classical(bitstring):
74 28
 
75 29
 @given(classical_bitstrings)
76 30
 def test_from_classical_works_on_integer_lists(bitstring):
77
-
78 31
     int_list = [int(b) for b in bitstring]
79 32
 
80 33
     assert np.all(
... ...
@@ -88,6 +41,11 @@ def test_from_classical_raises_on_bad_input(bitstring):
88 41
         qsim.state.from_classical(bitstring)
89 42
 
90 43
 
44
+@given(n_qubits)
45
+def test_zero(n):
46
+    assert np.array_equal(qsim.state.zero(n), qsim.state.from_classical("0" * n))
47
+
48
+
91 49
 @given(classical_bitstrings)
92 50
 def test_num_qubits(s):
93 51
     state = qsim.state.from_classical(s)
Browse code

Add more constructors and tests to the "state" module

Joseph Weston authored on 15/09/2021 17:23:38
Showing 1 changed files
... ...
@@ -2,6 +2,7 @@ import string
2 2
 from functools import partial
3 3
 
4 4
 from hypothesis import given
5
+from hypothesis.extra import numpy as hnp
5 6
 import hypothesis.strategies as st
6 7
 import numpy as np
7 8
 import pytest
... ...
@@ -11,7 +12,14 @@ import qsim.state
11 12
 
12 13
 MAX_QUBITS = 5
13 14
 
15
+
16
+def is_unit(c: complex):
17
+    return np.isclose(abs(c), 1)
18
+
19
+
14 20
 n_qubits = st.integers(1, MAX_QUBITS)
21
+state_dimensions = n_qubits.map(lambda n: 2 ** n)
22
+state_shapes = state_dimensions.map(lambda x: (x,))
15 23
 classical_bitstrings = n_qubits.flatmap(
16 24
     lambda n: st.integers(0, 2 ** n).map(partial("{:0{n}b}".format, n=n))
17 25
 )
... ...
@@ -19,6 +27,37 @@ invalid_bitstrings = n_qubits.flatmap(
19 27
     lambda n: st.text(alphabet=string.digits, min_size=n, max_size=n)
20 28
 ).filter(lambda s: any(b not in "01" for b in s))
21 29
 
30
+invalid_state_dimensions = st.integers(0, 2 ** MAX_QUBITS).filter(
31
+    lambda n: not np.log2(n).is_integer()
32
+)
33
+
34
+
35
+valid_complex = st.complex_numbers(
36
+    max_magnitude=1e10, allow_infinity=False, allow_nan=False
37
+)
38
+nonunit_complex = valid_complex.filter(lambda c: not is_unit(c))
39
+
40
+
41
+def ket(n_qubits):
42
+    return normalized_array(1 << n_qubits)
43
+
44
+
45
+def normalized_array(size):
46
+    return (
47
+        hnp.arrays(complex, (size,), elements=valid_complex)
48
+        .filter(lambda v: np.linalg.norm(v) > 0)  # vectors must be normalizable
49
+        .map(lambda v: v / np.linalg.norm(v))
50
+    )
51
+
52
+
53
+valid_states = n_qubits.flatmap(ket)
54
+invalid_shape_states = invalid_state_dimensions.flatmap(normalized_array)
55
+invalid_norm_states = valid_states.flatmap(
56
+    lambda x: nonunit_complex.map(lambda c: c * x)
57
+)
58
+zero_states = state_shapes.map(lambda s: np.zeros(s, complex))
59
+invalid_states = st.one_of(invalid_shape_states, invalid_norm_states, zero_states)
60
+
22 61
 
23 62
 @given(classical_bitstrings)
24 63
 def test_from_classical(bitstring):
... ...
@@ -47,3 +86,31 @@ def test_from_classical_works_on_integer_lists(bitstring):
47 86
 def test_from_classical_raises_on_bad_input(bitstring):
48 87
     with pytest.raises(ValueError):
49 88
         qsim.state.from_classical(bitstring)
89
+
90
+
91
+@given(classical_bitstrings)
92
+def test_num_qubits(s):
93
+    state = qsim.state.from_classical(s)
94
+    assert qsim.state.num_qubits(state) == len(s)
95
+
96
+
97
+@given(invalid_states)
98
+def test_num_qubits_raises_exception(state):
99
+    with pytest.raises(ValueError):
100
+        qsim.state.num_qubits(state)
101
+
102
+
103
+@given(invalid_norm_states)
104
+def test_is_not_normalized(state):
105
+    assert not qsim.state.is_normalized(state)
106
+
107
+
108
+@given(valid_states)
109
+def test_is_normalized(state):
110
+    assert qsim.state.is_normalized(state)
111
+
112
+
113
+@given(n_qubits)
114
+def test_zero(n):
115
+    z = qsim.state.zero(n)
116
+    assert qsim.state.num_qubits(z) == n
Browse code

Add a function to produce a state from a classical bitstring

also add relevant tests.

Joseph Weston authored on 09/11/2019 00:48:19
Showing 1 changed files
1 1
new file mode 100644
... ...
@@ -0,0 +1,49 @@
1
+import string
2
+from functools import partial
3
+
4
+from hypothesis import given
5
+import hypothesis.strategies as st
6
+import numpy as np
7
+import pytest
8
+
9
+import qsim.state
10
+
11
+
12
+MAX_QUBITS = 5
13
+
14
+n_qubits = st.integers(1, MAX_QUBITS)
15
+classical_bitstrings = n_qubits.flatmap(
16
+    lambda n: st.integers(0, 2 ** n).map(partial("{:0{n}b}".format, n=n))
17
+)
18
+invalid_bitstrings = n_qubits.flatmap(
19
+    lambda n: st.text(alphabet=string.digits, min_size=n, max_size=n)
20
+).filter(lambda s: any(b not in "01" for b in s))
21
+
22
+
23
+@given(classical_bitstrings)
24
+def test_from_classical(bitstring):
25
+
26
+    state = qsim.state.from_classical(bitstring)
27
+
28
+    i = int(bitstring, base=2)
29
+
30
+    assert np.issubdtype(state.dtype, np.dtype(complex))
31
+    assert state.shape == (2 ** len(bitstring),)
32
+    assert np.linalg.norm(state) == 1
33
+    assert abs(state[i]) == 1
34
+
35
+
36
+@given(classical_bitstrings)
37
+def test_from_classical_works_on_integer_lists(bitstring):
38
+
39
+    int_list = [int(b) for b in bitstring]
40
+
41
+    assert np.all(
42
+        qsim.state.from_classical(bitstring) == qsim.state.from_classical(int_list)
43
+    )
44
+
45
+
46
+@given(invalid_bitstrings)
47
+def test_from_classical_raises_on_bad_input(bitstring):
48
+    with pytest.raises(ValueError):
49
+        qsim.state.from_classical(bitstring)