implement vectorized value broadcasting
Closes #342
See merge request kwant/kwant!346
... | ... |
@@ -742,33 +742,37 @@ def is_vectorized(syst): |
742 | 742 |
return isinstance(syst, (FiniteVectorizedSystem, InfiniteVectorizedSystem)) |
743 | 743 |
|
744 | 744 |
|
745 |
-def _normalize_matrix_blocks(matrix_blocks, expected_length): |
|
745 |
+def _normalize_matrix_blocks(blocks, expected_length): |
|
746 | 746 |
"""Normalize a sequence of matrices into a single 3D numpy array |
747 | 747 |
|
748 | 748 |
Parameters |
749 | 749 |
---------- |
750 |
- matrix_blocks : sequence of complex array-like |
|
750 |
+ blocks : sequence of complex array-like |
|
751 | 751 |
expected_length : int |
752 | 752 |
""" |
753 | 753 |
try: |
754 |
- matrix_blocks = np.asarray(matrix_blocks, dtype=complex) |
|
754 |
+ blocks = np.asarray(blocks, dtype=complex) |
|
755 | 755 |
except TypeError: |
756 | 756 |
raise ValueError( |
757 | 757 |
"Matrix elements declared with incompatible shapes." |
758 | 758 |
) from None |
759 |
- # Upgrade to vector of matrices |
|
760 |
- if len(matrix_blocks.shape) == 1: |
|
761 |
- matrix_blocks = matrix_blocks[:, np.newaxis, np.newaxis] |
|
762 |
- if len(matrix_blocks.shape) != 3: |
|
759 |
+ if len(blocks.shape) == 0: # scalar → broadcast to vector of 1x1 matrices |
|
760 |
+ blocks = np.tile(blocks, (expected_length, 1, 1)) |
|
761 |
+ elif len(blocks.shape) == 1: # vector → interpret as vector of 1x1 matrices |
|
762 |
+ blocks = blocks.reshape(-1, 1, 1) |
|
763 |
+ elif len(blocks.shape) == 2: # matrix → broadcast to vector of matrices |
|
764 |
+ blocks = np.tile(blocks, (expected_length, 1, 1)) |
|
765 |
+ |
|
766 |
+ if len(blocks.shape) != 3: |
|
763 | 767 |
msg = ( |
764 | 768 |
"Vectorized value functions must return an array of" |
765 | 769 |
"scalars or an array of matrices." |
766 | 770 |
) |
767 | 771 |
raise ValueError(msg) |
768 |
- if matrix_blocks.shape[0] != expected_length: |
|
772 |
+ if blocks.shape[0] != expected_length: |
|
769 | 773 |
raise ValueError("Value functions must return a single value per " |
770 | 774 |
"onsite/hopping.") |
771 |
- return matrix_blocks |
|
775 |
+ return blocks |
|
772 | 776 |
|
773 | 777 |
|
774 | 778 |
|
... | ... |
@@ -740,6 +740,46 @@ def test_vectorized_hamiltonian_evaluation(): |
740 | 740 |
) |
741 | 741 |
|
742 | 742 |
|
743 |
+def test_vectorized_value_normalization(): |
|
744 |
+ # Here we test whether all legal shapes for values for vectorized Builders |
|
745 |
+ # are accepted: |
|
746 |
+ # + single scalars are interpreted as 1x1 matrices, and broadcast into an |
|
747 |
+ # (N, 1, 1) array |
|
748 |
+ # + single (n, m) matrices are broadcast into an (N, n, m) array |
|
749 |
+ # + a shape (N,) array is interpreted as an (N, 1, 1) array |
|
750 |
+ |
|
751 |
+ sz = np.array([[1, 0], [0, -1]]) |
|
752 |
+ |
|
753 |
+ scalars = [ |
|
754 |
+ 2, |
|
755 |
+ lambda s: 2, |
|
756 |
+ lambda s: [2] * len(s) |
|
757 |
+ ] |
|
758 |
+ matrices = [ |
|
759 |
+ sz, |
|
760 |
+ lambda s: sz, |
|
761 |
+ lambda s: [sz] * len(s) |
|
762 |
+ ] |
|
763 |
+ inter_lattice_matrices =[ |
|
764 |
+ [[1, 0]], |
|
765 |
+ lambda a, b: [[1, 0]], |
|
766 |
+ lambda a, b: [[[1, 0]]] * len(a) |
|
767 |
+ ] |
|
768 |
+ |
|
769 |
+ lat_a = kwant.lattice.chain(norbs=1) |
|
770 |
+ lat_b = kwant.lattice.chain(norbs=2) |
|
771 |
+ |
|
772 |
+ hams = [] |
|
773 |
+ for s, m, v in it.product(scalars, matrices, inter_lattice_matrices): |
|
774 |
+ syst = builder.Builder(vectorize=True) |
|
775 |
+ syst[map(lat_a, range(10))] = s |
|
776 |
+ syst[map(lat_b, range(10))] = m |
|
777 |
+ syst[((lat_a(i), lat_b(i)) for i in range(10))] = v |
|
778 |
+ h = syst.finalized().hamiltonian_submatrix() |
|
779 |
+ hams.append(h) |
|
780 |
+ assert np.all(hams == hams[0]) |
|
781 |
+ |
|
782 |
+ |
|
743 | 783 |
@pytest.mark.parametrize("sym", [ |
744 | 784 |
builder.NoSymmetry(), |
745 | 785 |
kwant.TranslationalSymmetry([-1]), |