Browse code

Refactor gate module into operator module

This will be necessary for defining measurements also

Joseph Weston authored on 15/09/2021 17:22:20
Showing 3 changed files
... ...
@@ -6,6 +6,8 @@ matrix written in the computational basis.
6 6
 
7 7
 import numpy as np
8 8
 
9
+from . import operator
10
+
9 11
 __all__ = [
10 12
     "apply",
11 13
     "n_qubits",
... ...
@@ -34,7 +36,8 @@ def apply(gate, qubits, state):
34 36
 
35 37
     Parameters
36 38
     ----------
37
-    gate : ndarray[complex]
39
+    gate: ndarray[complex]
40
+        The gate to apply.
38 41
     qubits : sequence of int
39 42
         The qubits on which to act. Qubit 0 is the least significant qubit.
40 43
     state : ndarray[complex]
... ...
@@ -43,62 +46,16 @@ def apply(gate, qubits, state):
43 46
     -------
44 47
     new_state : ndarray[complex]
45 48
     """
46
-    n_gate_qubits = gate.shape[0].bit_length() - 1
47
-    n_state_qubits = state.shape[0].bit_length() - 1
48
-    assert len(qubits) == n_gate_qubits
49
-
50
-    # We can view an n-qubit gate as a 2*n-tensor (n contravariant and n contravariant
51
-    # indices) and an n-qubit state as an n-tensor (contravariant indices)
52
-    # with each axis having length 2 (the state space of a single qubit).
53
-    gate = gate.reshape((2,) * 2 * n_gate_qubits)
54
-    state = state.reshape((2,) * n_state_qubits)
55
-
56
-    # Our qubits are labeled from least significant to most significant, i.e. our
57
-    # computational basis (for e.g. 2 qubits) is ordered like |00⟩, |01⟩, |10⟩, |11⟩.
58
-    # We represent the state as a tensor in *row-major* order; this means that the
59
-    # axis ordering is *backwards* compared to the qubit ordering (the least significant
60
-    # qubit corresponds to the *last* axis in the tensor etc.)
61
-    qubit_axes = tuple(n_state_qubits - 1 - np.asarray(qubits))
62
-
63
-    # Applying the gate to the state vector is then the tensor product over the appropriate axes
64
-    axes = (np.arange(n_gate_qubits, 2 * n_gate_qubits), qubit_axes)
65
-    new_state = np.tensordot(gate, state, axes=axes)
66
-
67
-    # tensordot effectively re-orders the qubits such that the qubits we operated
68
-    # on are in the most-significant positions (i.e. their axes come first). We
69
-    # thus need to transpose the axes to place them back into their original positions.
70
-    untouched_axes = tuple(
71
-        idx for idx in range(n_state_qubits) if idx not in qubit_axes
72
-    )
73
-    inverse_permutation = np.argsort(qubit_axes + untouched_axes)
74
-    return np.transpose(new_state, inverse_permutation).reshape(-1)
49
+    _check_valid_gate(gate)
50
+    return operator.apply(gate, qubits, state)
51
+
52
+
53
+n_qubits = operator.n_qubits
75 54
 
76 55
 
77 56
 def _check_valid_gate(gate):
78
-    if not (
79
-        # is an array
80
-        isinstance(gate, np.ndarray)
81
-        # is complex
82
-        and np.issubdtype(gate.dtype, np.complex128)
83
-        # is square
84
-        and gate.shape[0] == gate.shape[1]
85
-        # has size 2**n, n > 1
86
-        and np.log2(gate.shape[0]).is_integer()
87
-        and gate.shape[0].bit_length() > 1
88
-        # is unitary
89
-        and np.allclose(gate @ gate.conjugate().transpose(), np.identity(gate.shape[0]))
90
-    ):
91
-        raise ValueError("Gate is not valid")
92
-
93
-
94
-def n_qubits(gate):
95
-    """Return the number of qubits that a gate acts on.
96
-
97
-    Raises ValueError if 'gate' does not have a shape that is
98
-    an integer power of 2.
99
-    """
100
-    _check_valid_gate(gate)
101
-    return gate.shape[0].bit_length() - 1
57
+    if not operator.is_unitary(gate):
58
+        raise ValueError("Gate is invalid")
102 59
 
103 60
 
104 61
 def controlled(gate):
105 62
new file mode 100644
... ...
@@ -0,0 +1,126 @@
1
+import numpy as np
2
+
3
+from .state import num_qubits
4
+
5
+__all__ = ["apply", "is_hermitian", "is_unitary", "is_valid", "n_qubits"]
6
+
7
+
8
+def apply(op, qubits, state):
9
+    """Apply an operator to the specified qubits of a state
10
+
11
+    Parameters
12
+    ----------
13
+    op: ndarray[complex]
14
+        The operator to apply.
15
+    qubits : sequence of int
16
+        The qubits on which to act. Qubit 0 is the least significant qubit.
17
+    state : ndarray[complex]
18
+
19
+    Returns
20
+    -------
21
+    new_state : ndarray[complex]
22
+    """
23
+    _check_apply(op, qubits, state)
24
+
25
+    n_op_qubits = n_qubits(op)
26
+    n_state_qubits = num_qubits(state)
27
+
28
+    # We can view an n-qubit op as a 2*n-tensor (n contravariant and n contravariant
29
+    # indices) and an n-qubit state as an n-tensor (contravariant indices)
30
+    # with each axis having length 2 (the state space of a single qubit).
31
+    op = op.reshape((2,) * 2 * n_op_qubits)
32
+    state = state.reshape((2,) * n_state_qubits)
33
+
34
+    # Our qubits are labeled from least significant to most significant, i.e. our
35
+    # computational basis (for e.g. 2 qubits) is ordered like |00⟩, |01⟩, |10⟩, |11⟩.
36
+    # We represent the state as a tensor in *row-major* order; this means that the
37
+    # axis ordering is *backwards* compared to the qubit ordering (the least significant
38
+    # qubit corresponds to the *last* axis in the tensor etc.)
39
+    qubit_axes = tuple(n_state_qubits - 1 - np.asarray(qubits))
40
+
41
+    # Applying the op to the state vector is then the tensor product over the appropriate axes
42
+    axes = (np.arange(n_op_qubits, 2 * n_op_qubits), qubit_axes)
43
+    new_state = np.tensordot(op, state, axes=axes)
44
+
45
+    # tensordot effectively re-orders the qubits such that the qubits we operated
46
+    # on are in the most-significant positions (i.e. their axes come first). We
47
+    # thus need to transpose the axes to place them back into their original positions.
48
+    untouched_axes = tuple(
49
+        idx for idx in range(n_state_qubits) if idx not in qubit_axes
50
+    )
51
+    inverse_permutation = np.argsort(qubit_axes + untouched_axes)
52
+    return np.transpose(new_state, inverse_permutation).reshape(-1)
53
+
54
+
55
+def _all_distinct(elements):
56
+    if not elements:
57
+        return True
58
+    elements = iter(elements)
59
+    fst = next(elements)
60
+    return all(fst != x for x in elements)
61
+
62
+
63
+def _check_apply(op, qubits, state):
64
+    if not _all_distinct(qubits):
65
+        raise ValueError("Cannot apply an operator to repeated qubits.")
66
+
67
+    n_op_qubits = n_qubits(op)
68
+    if n_op_qubits != len(qubits):
69
+        raise ValueError(
70
+            f"Cannot apply an {n_op_qubits}-qubit operator to {len(qubits)} qubits."
71
+        )
72
+
73
+    n_state_qubits = num_qubits(state)
74
+
75
+    if n_op_qubits > n_state_qubits:
76
+        raise ValueError(
77
+            f"Cannot apply an {n_op_qubits}-qubit operator "
78
+            f"to an {n_state_qubits}-qubit state."
79
+        )
80
+
81
+    invalid_qubits = [q for q in qubits if q >= n_state_qubits]
82
+    if invalid_qubits:
83
+        raise ValueError(
84
+            f"Cannot apply operator to qubits {invalid_qubits} "
85
+            f"of an {n_state_qubits}-qubit state."
86
+        )
87
+
88
+
89
+def is_hermitian(op: np.ndarray) -> bool:
90
+    """Return True if and only if 'op' is a valid Hermitian operator."""
91
+    return is_valid(op) and np.allclose(op, op.conj().T)
92
+
93
+
94
+def is_unitary(op: np.ndarray) -> bool:
95
+    """Return True if and only if 'op' is a valid unitary operator."""
96
+    return is_valid(op) and np.allclose(op @ op.conj().T, np.identity(op.shape[0]))
97
+
98
+
99
+def is_valid(op: np.ndarray) -> bool:
100
+    """Return True if and only if 'op' is a valid operator."""
101
+    return (
102
+        # is an array
103
+        isinstance(op, np.ndarray)
104
+        # is complex
105
+        and np.issubdtype(op.dtype, np.complex128)
106
+        # is square
107
+        and op.shape[0] == op.shape[1]
108
+        # has size 2**n, n > 1
109
+        and np.log2(op.shape[0]).is_integer()
110
+        and op.shape[0].bit_length() > 1
111
+    )
112
+
113
+
114
+def _check_valid_operator(op):
115
+    if not is_valid(op):
116
+        raise ValueError("Operator is invalid")
117
+
118
+
119
+def n_qubits(op):
120
+    """Return the number of qubits that the operator acts on.
121
+
122
+    Raises ValueError if 'o' does not have a shape that is
123
+    an integer power of 2.
124
+    """
125
+    _check_valid_operator(op)
126
+    return op.shape[0].bit_length() - 1
... ...
@@ -28,7 +28,9 @@ def select_n_qubits(gate_size):
28 28
     return _strat
29 29
 
30 30
 
31
-valid_complex = st.complex_numbers(allow_infinity=False, allow_nan=False)
31
+valid_complex = st.complex_numbers(
32
+    max_magnitude=1e10, allow_infinity=False, allow_nan=False
33
+)
32 34
 phases = st.floats(
33 35
     min_value=0, max_value=2 * np.pi, allow_nan=False, allow_infinity=False
34 36
 )
... ...
@@ -89,11 +91,6 @@ def test_n_qubits_invalid(gate):
89 91
     # Not size 2**n, n > 0
90 92
     with pytest.raises(ValueError):
91 93
         qsim.gate.n_qubits(gate[:-1, :-1])
92
-    # Not unitary
93
-    nonunitary_part = np.zeros_like(gate)
94
-    nonunitary_part[0, -1] = 1j
95
-    with pytest.raises(ValueError):
96
-        qsim.gate.n_qubits(gate + nonunitary_part)
97 94
 
98 95
 
99 96
 @given(n_qubits, n_qubit_gates)