Browse code

add consistency checks when computing onsite and hamiltonian data

We get rid of '_is_herm_conj' in favour of '_is_hermitian', replace
'_check_ham' with '_check_hams' (which works on vectorized values)
and add '_check_onsites' (which also works on vectorized values).
Replace absolute value calculation with call to 'cabs' from 'complex.h'

Joseph Weston authored on 26/11/2019 12:44:43
Showing 2 changed files
... ...
@@ -4,17 +4,24 @@ from .graph.defs import gint_dtype
4 4
 
5 5
 cdef gint _bisect(gint[:] a, gint x)
6 6
 
7
-cdef int _is_herm_conj(complex[:, :] a, complex[:, :] b,
8
-                       double atol=*, double rtol=*) except -1
7
+cdef int _is_hermitian(
8
+    complex[:, :] a, double atol=*, double rtol=*
9
+) except -1
10
+
11
+cdef int _is_hermitian_3d(
12
+    complex[:, :, :] a, double atol=*, double rtol=*
13
+) except -1
9 14
 
10 15
 cdef _select(gint[:, :] arr, gint[:] indexes)
11 16
 
12 17
 cdef int _check_onsite(complex[:, :] M, gint norbs,
13 18
                        int check_hermiticity) except -1
14 19
 
15
-cdef int _check_ham(complex[:, :] H, ham, args, params,
16
-                    gint a, gint a_norbs, gint b, gint b_norbs,
17
-                    int check_hermiticity) except -1
20
+cdef int _check_onsites(complex[:, :, :] M, gint norbs,
21
+                        int check_hermiticity) except -1
22
+
23
+cdef int _check_hams(complex[:, :, :] H, gint to_norbs, gint from_norbs,
24
+                     int check_hermiticity) except -1
18 25
 
19 26
 cdef void _get_orbs(gint[:, :] site_ranges, gint site,
20 27
                     gint *start_orb, gint *norbs)
... ...
@@ -21,6 +21,9 @@ from scipy.sparse import coo_matrix
21 21
 
22 22
 from libc cimport math
23 23
 
24
+cdef extern from "complex.h":
25
+    double cabs(double complex)
26
+
24 27
 from .graph.core cimport EdgeIterator
25 28
 from .graph.core import DisabledFeatureError, NodeDoesNotExistError
26 29
 from .graph.defs cimport gint
... ...
@@ -51,32 +54,61 @@ cdef gint _bisect(gint[:] a, gint x):
51 54
 
52 55
 @cython.boundscheck(False)
53 56
 @cython.wraparound(False)
54
-cdef int _is_herm_conj(complex[:, :] a, complex[:, :] b,
55
-                       double atol=1e-300, double rtol=1e-13) except -1:
56
-    "Return True if `a` is the Hermitian conjugate of `b`."
57
-    assert a.shape[0] == b.shape[1]
58
-    assert a.shape[1] == b.shape[0]
57
+cdef int _is_hermitian(
58
+    complex[:, :] a, double atol=1e-300, double rtol=1e-13
59
+) except -1:
60
+    "Return True if 'a' is Hermitian"
61
+
62
+    if a.shape[0] != a.shape[1]:
63
+        return False
59 64
 
60 65
     # compute max(a)
61 66
     cdef double tmp, max_a = 0
62
-    cdef gint i, j
67
+    cdef gint i, j, k
63 68
     for i in range(a.shape[0]):
64 69
         for j in range(a.shape[1]):
65
-            tmp = a[i, j].real * a[i, j].real + a[i, j].imag * a[i, j].imag
70
+            tmp = cabs(a[i, j])
66 71
             if tmp > max_a:
67 72
                 max_a = tmp
68 73
     max_a = math.sqrt(max_a)
69 74
 
70 75
     cdef double tol = rtol * max_a + atol
71
-    cdef complex ctmp
72 76
     for i in range(a.shape[0]):
73
-        for j in range(a.shape[1]):
74
-            ctmp = a[i, j] - b[j, i].conjugate()
75
-            tmp = ctmp.real * ctmp.real + ctmp.imag * ctmp.imag
77
+        for j in range(i, a.shape[1]):
78
+            tmp = cabs(a[i, j] - a[j, i].conjugate())
76 79
             if tmp > tol:
77 80
                 return False
78 81
     return True
79 82
 
83
+@cython.boundscheck(False)
84
+@cython.wraparound(False)
85
+cdef int _is_hermitian_3d(
86
+    complex[:, :, :] a, double atol=1e-300, double rtol=1e-13
87
+) except -1:
88
+    "Return True if 'a' is Hermitian"
89
+
90
+    if a.shape[1] != a.shape[2]:
91
+        return False
92
+
93
+    # compute max(a)
94
+    cdef double tmp, max_a = 0
95
+    cdef gint i, j, k
96
+    for k in range(a.shape[0]):
97
+        for i in range(a.shape[1]):
98
+            for j in range(a.shape[2]):
99
+                tmp = cabs(a[k, i, j])
100
+                if tmp > max_a:
101
+                    max_a = tmp
102
+    max_a = math.sqrt(max_a)
103
+
104
+    cdef double tol = rtol * max_a + atol
105
+    for k in range(a.shape[0]):
106
+        for i in range(a.shape[1]):
107
+            for j in range(i, a.shape[2]):
108
+                tmp = cabs(a[k, i, j] - a[k, j, i].conjugate())
109
+                if tmp > tol:
110
+                    return False
111
+    return True
80 112
 
81 113
 
82 114
 @cython.boundscheck(False)
... ...
@@ -107,22 +139,28 @@ cdef int _check_onsite(complex[:, :] M, gint norbs,
107 139
         raise UserCodeError('Onsite matrix is not square')
108 140
     if M.shape[0] != norbs:
109 141
         raise UserCodeError(_shape_msg.format('Onsite'))
110
-    if check_hermiticity and not _is_herm_conj(M, M):
142
+    if check_hermiticity and not _is_hermitian(M):
111 143
         raise ValueError(_herm_msg.format('Onsite'))
112 144
     return 0
113 145
 
114 146
 
115
-cdef int _check_ham(complex[:, :] H, ham, args, params,
116
-                    gint a, gint a_norbs, gint b, gint b_norbs,
117
-                    int check_hermiticity) except -1:
118
-    "Check Hamiltonian matrix for correct shape and hermiticity."
119
-    if H.shape[0] != a_norbs and H.shape[1] != b_norbs:
147
+cdef int _check_onsites(complex[:, :, :] M, gint norbs,
148
+                       int check_hermiticity) except -1:
149
+    "Check onsite matrix for correct shape and hermiticity."
150
+    if M.shape[1] != M.shape[2]:
151
+        raise UserCodeError('Onsite matrix is not square')
152
+    if M.shape[1] != norbs:
153
+        raise UserCodeError(_shape_msg.format('Onsite'))
154
+    if check_hermiticity and not _is_hermitian_3d(M):
155
+        raise ValueError(_herm_msg.format('Onsite'))
156
+    return 0
157
+
158
+
159
+cdef int _check_hams(complex[:, :, :] H, gint to_norbs, gint from_norbs,
160
+                     int check_hermiticity) except -1:
161
+    if H.shape[1] != to_norbs or H.shape[2] != from_norbs:
120 162
         raise UserCodeError(_shape_msg.format('Hamiltonian'))
121
-    if check_hermiticity:
122
-        # call the "partner" element if we are not on the diagonal
123
-        H_conj = H if a == b else ta.matrix(ham(b, a, *args, params=params),
124
-                                                complex)
125
-        if not _is_herm_conj(H_conj, H):
163
+    if check_hermiticity and not _is_hermitian_3d(H):
126 164
             raise ValueError(_herm_msg.format('Hamiltonian'))
127 165
     return 0
128 166
 
... ...
@@ -901,6 +939,7 @@ cdef class _LocalOperator:
901 939
             # All sites selected by 'which' are part of the same site family.
902 940
             site_offsets = _select(self.where, which)[:, 0] - start_site
903 941
             data = self.onsite(sr, site_offsets, *args)
942
+            _check_onsites(data, norbs, self.check_hermiticity)
904 943
             return data
905 944
 
906 945
         matrix_elements = _make_matrix_elements(eval_onsite, self._terms)
... ...
@@ -929,6 +968,13 @@ cdef class _LocalOperator:
929 968
                                              args=args, params=params)
930 969
                 if herm_conj:
931 970
                     data = data.conjugate().transpose(0, 2, 1)
971
+                # Checks for data consistency
972
+                (to_sr, from_sr), _ = syst.subgraphs[syst.terms[term_id].subgraph]
973
+                to_norbs = syst.site_ranges[to_sr][1]
974
+                from_norbs = syst.site_ranges[from_sr][1]
975
+                if herm_conj:
976
+                    to_norbs, from_norbs = from_norbs, to_norbs
977
+                _check_hams(data, to_norbs, from_norbs, is_onsite and check_hermiticity)
932 978
 
933 979
                 return data
934 980
 
... ...
@@ -948,6 +994,11 @@ cdef class _LocalOperator:
948 994
                         for i in which
949 995
                     ]
950 996
                 data = _normalize_matrix_blocks(data, len(which))
997
+                # Checks for data consistency
998
+                (to_sr, from_sr) = term_id
999
+                to_norbs = syst.site_ranges[to_sr][1]
1000
+                from_norbs = syst.site_ranges[from_sr][1]
1001
+                _check_hams(data, to_norbs, from_norbs, is_onsite and check_hermiticity)
951 1002
 
952 1003
                 return data
953 1004