Browse code

Implement application of a gate to a state.

Joseph Weston authored on 15/11/2019 20:41:21
Showing 1 changed files
... ...
@@ -7,6 +7,7 @@ matrix written in the computational basis.
7 7
 import numpy as np
8 8
 
9 9
 __all__ = [
10
+    "apply",
10 11
     "n_qubits",
11 12
     "controlled",
12 13
     # -- Single qubit gates --
... ...
@@ -28,6 +29,38 @@ __all__ = [
28 29
 ]  # type: ignore
29 30
 
30 31
 
32
+def apply(gate, qubits, state):
33
+    n_gate_qubits = gate.shape[0].bit_length() - 1
34
+    n_state_qubits = state.shape[0].bit_length() - 1
35
+    assert len(qubits) == n_gate_qubits
36
+
37
+    # We can view an n-qubit gate as a 2*n-tensor (n contravariant and n contravariant
38
+    # indices) and an n-qubit state as an n-tensor (contravariant indices)
39
+    # with each axis having length 2 (the state space of a single qubit).
40
+    gate = gate.reshape((2,) * 2 * n_gate_qubits)
41
+    state = state.reshape((2,) * n_state_qubits)
42
+
43
+    # Our qubits are labeled from least significant to most significant, i.e. our
44
+    # computational basis (for e.g. 2 qubits) is ordered like |00⟩, |01⟩, |10⟩, |11⟩.
45
+    # We represent the state as a tensor in *row-major* order; this means that the
46
+    # axis ordering is *backwards* compared to the qubit ordering (the least significant
47
+    # qubit corresponds to the *last* axis in the tensor etc.)
48
+    qubit_axes = tuple(n_state_qubits - 1 - np.asarray(qubits))
49
+
50
+    # Applying the gate to the state vector is then the tensor product over the appropriate axes
51
+    axes = (np.arange(n_gate_qubits, 2 * n_gate_qubits), qubit_axes)
52
+    new_state = np.tensordot(gate, state, axes=axes)
53
+
54
+    # tensordot effectively re-orders the qubits such that the qubits we operated
55
+    # on are in the most-significant positions (i.e. their axes come first). We
56
+    # thus need to transpose the axes to place them back into their original positions.
57
+    untouched_axes = tuple(
58
+        idx for idx in range(n_state_qubits) if idx not in qubit_axes
59
+    )
60
+    inverse_permutation = np.argsort(qubit_axes + untouched_axes)
61
+    return np.transpose(new_state, inverse_permutation).reshape(-1)
62
+
63
+
31 64
 def _check_valid_gate(gate):
32 65
     if not (
33 66
         # is an array