This includes the following changes:
+ Group the elements of 'where' according to the term/site-families
they belong to. This is pre-computed and stored in '_terms'.
What specifically is stored depends on whether vectorization
is enabled.
+ BlockSparseMatrix now takes its matrix elements as a sequence
of pairs 'which, data', where 'which' indexes 'where'. This
leverages the values stored in '_terms'.
+ Normalize 'onsite' to a function that takes 'site_range' (int)
and 'site_offsets' (array of int). This facilitates vectorization
when it is enabled. Previously 'onsite' was normalized to a
function
taking 'site' (int), which was not vectorizable.
+ '_eval_onsite' is modified to compose components of '_terms' with
'onsite'.
+ '_eval_hamiltonian' is modified to compose compoenents of '_terms'
with 'hamiltonian_term' when vectorized, or 'hamiltonian' when
not vectorized.
... | ... |
@@ -7,6 +7,8 @@ cdef gint _bisect(gint[:] a, gint x) |
7 | 7 |
cdef int _is_herm_conj(complex[:, :] a, complex[:, :] b, |
8 | 8 |
double atol=*, double rtol=*) except -1 |
9 | 9 |
|
10 |
+cdef _select(gint[:, :] arr, gint[:] indexes) |
|
11 |
+ |
|
10 | 12 |
cdef int _check_onsite(complex[:, :] M, gint norbs, |
11 | 13 |
int check_hermiticity) except -1 |
12 | 14 |
|
... | ... |
@@ -28,7 +30,7 @@ cdef class BlockSparseMatrix: |
28 | 30 |
|
29 | 31 |
cdef class _LocalOperator: |
30 | 32 |
cdef public int check_hermiticity, sum |
31 |
- cdef public object syst, onsite, _onsite_param_names |
|
33 |
+ cdef public object syst, onsite, _onsite_param_names, _terms |
|
32 | 34 |
cdef public gint[:, :] where, _site_ranges |
33 | 35 |
cdef public BlockSparseMatrix _bound_onsite, _bound_hamiltonian |
34 | 36 |
|
... | ... |
@@ -1,4 +1,4 @@ |
1 |
-# Copyright 2011-2017 Kwant authors. |
|
1 |
+# Copyright 2011-2019 Kwant authors. |
|
2 | 2 |
# |
3 | 3 |
# This file is part of Kwant. It is subject to the license terms in the file |
4 | 4 |
# LICENSE.rst found in the top-level directory of this distribution and at |
... | ... |
@@ -10,6 +10,7 @@ |
10 | 10 |
__all__ = ['Density', 'Current', 'Source'] |
11 | 11 |
|
12 | 12 |
import cython |
13 |
+import itertools |
|
13 | 14 |
import functools as ft |
14 | 15 |
import collections |
15 | 16 |
import warnings |
... | ... |
@@ -24,7 +25,9 @@ from .graph.core cimport EdgeIterator |
24 | 25 |
from .graph.core import DisabledFeatureError, NodeDoesNotExistError |
25 | 26 |
from .graph.defs cimport gint |
26 | 27 |
from .graph.defs import gint_dtype |
27 |
-from .system import is_infinite, Site |
|
28 |
+from .system import ( |
|
29 |
+ is_infinite, is_vectorized, Site, SiteArray, _normalize_matrix_blocks |
|
30 |
+) |
|
28 | 31 |
from ._common import ( |
29 | 32 |
UserCodeError, KwantDeprecationWarning, get_parameters, deprecate_args |
30 | 33 |
) |
... | ... |
@@ -75,6 +78,20 @@ cdef int _is_herm_conj(complex[:, :] a, complex[:, :] b, |
75 | 78 |
return True |
76 | 79 |
|
77 | 80 |
|
81 |
+ |
|
82 |
+@cython.boundscheck(False) |
|
83 |
+@cython.wraparound(False) |
|
84 |
+cdef _select(gint[:, :] arr, gint[:] indexes): |
|
85 |
+ ret = np.empty((indexes.shape[0], arr.shape[1]), dtype=gint_dtype) |
|
86 |
+ cdef gint[:, :] ret_view = ret |
|
87 |
+ cdef gint i, j |
|
88 |
+ |
|
89 |
+ for i in range(indexes.shape[0]): |
|
90 |
+ for j in range(arr.shape[1]): |
|
91 |
+ ret_view[i, j] = arr[indexes[i], j] |
|
92 |
+ |
|
93 |
+ return ret |
|
94 |
+ |
|
78 | 95 |
################ Helper functions |
79 | 96 |
|
80 | 97 |
_shape_msg = ('{0} matrix dimensions do not match ' |
... | ... |
@@ -250,23 +267,78 @@ def _normalize_hopping_where(syst, where): |
250 | 267 |
return where |
251 | 268 |
|
252 | 269 |
|
253 |
-## These two classes are here to avoid using closures, as these will |
|
270 |
+## These four classes are here to avoid using closures, as these will |
|
254 | 271 |
## break pickle support. These are only used inside '_normalize_onsite'. |
255 | 272 |
|
273 |
+def _raise_user_error(exc, func, vectorized): |
|
274 |
+ msg = [ |
|
275 |
+ 'Error occurred in user-supplied onsite function "{0}".', |
|
276 |
+ 'Did you remember to vectorize "{0}"?' if vectorized else '', |
|
277 |
+ 'See the upper part of the above backtrace for more information.', |
|
278 |
+ ] |
|
279 |
+ msg = '\n'.join(line for line in msg if line).format(func.__name__) |
|
280 |
+ raise UserCodeError(msg) from exc |
|
281 |
+ |
|
282 |
+ |
|
256 | 283 |
class _FunctionalOnsite: |
257 | 284 |
|
258 |
- def __init__(self, onsite, sites): |
|
285 |
+ def __init__(self, onsite, sites, site_ranges): |
|
259 | 286 |
self.onsite = onsite |
260 | 287 |
self.sites = sites |
288 |
+ self.site_ranges = site_ranges |
|
261 | 289 |
|
262 |
- def __call__(self, site_id, *args): |
|
263 |
- return self.onsite(self.sites[site_id], *args) |
|
290 |
+ def __call__(self, site_range, site_offsets, *args): |
|
291 |
+ sites = self.sites |
|
292 |
+ offset = self.site_ranges[site_range][0] |
|
293 |
+ try: |
|
294 |
+ ret = [self.onsite(sites[offset + i], *args) for i in site_offsets] |
|
295 |
+ except Exception as exc: |
|
296 |
+ _raise_user_error(exc, self.onsite, vectorized=False) |
|
297 |
+ return _normalize_matrix_blocks(ret, len(site_offsets)) |
|
264 | 298 |
|
265 | 299 |
|
266 |
-class _DictOnsite(_FunctionalOnsite): |
|
300 |
+class _VectorizedFunctionalOnsite: |
|
267 | 301 |
|
268 |
- def __call__(self, site_id, *args): |
|
269 |
- return self.onsite[self.sites[site_id].family] |
|
302 |
+ def __init__(self, onsite, site_arrays): |
|
303 |
+ self.onsite = onsite |
|
304 |
+ self.site_arrays = site_arrays |
|
305 |
+ |
|
306 |
+ def __call__(self, site_range, site_offsets, *args): |
|
307 |
+ site_array = self.site_arrays[site_range] |
|
308 |
+ tags = site_array.tags[site_offsets] |
|
309 |
+ sites = SiteArray(site_array.family, tags) |
|
310 |
+ try: |
|
311 |
+ ret = self.onsite(sites, *args) |
|
312 |
+ except Exception as exc: |
|
313 |
+ _raise_user_error(exc, self.onsite, vectorized=True) |
|
314 |
+ return _normalize_matrix_blocks(ret, len(site_offsets)) |
|
315 |
+ |
|
316 |
+ |
|
317 |
+class _FunctionalOnsiteNoTransform: |
|
318 |
+ |
|
319 |
+ def __init__(self, onsite, site_ranges): |
|
320 |
+ self.onsite = onsite |
|
321 |
+ self.site_ranges = site_ranges |
|
322 |
+ |
|
323 |
+ def __call__(self, site_range, site_offsets, *args): |
|
324 |
+ site_ids = self.site_ranges[site_range][0] + site_offsets |
|
325 |
+ try: |
|
326 |
+ ret = [self.onsite(id, *args) for id in site_ids] |
|
327 |
+ except Exception as exc: |
|
328 |
+ _raise_user_error(exc, self.onsite, vectorized=False) |
|
329 |
+ return _normalize_matrix_blocks(ret, len(site_offsets)) |
|
330 |
+ |
|
331 |
+ |
|
332 |
+class _DictOnsite: |
|
333 |
+ |
|
334 |
+ def __init__(self, onsite, range_family_map): |
|
335 |
+ self.onsite = onsite |
|
336 |
+ self.range_family_map = range_family_map |
|
337 |
+ |
|
338 |
+ def __call__(self, site_range, site_offsets, *args): |
|
339 |
+ fam = self.range_family_map[site_range] |
|
340 |
+ ret = [self.onsite[fam]] * len(site_offsets) |
|
341 |
+ return _normalize_matrix_blocks(ret, len(site_offsets)) |
|
270 | 342 |
|
271 | 343 |
|
272 | 344 |
def _normalize_onsite(syst, onsite, check_hermiticity): |
... | ... |
@@ -280,21 +352,29 @@ def _normalize_onsite(syst, onsite, check_hermiticity): |
280 | 352 |
if callable(onsite): |
281 | 353 |
# make 'onsite' compatible with hamiltonian value functions |
282 | 354 |
param_names = get_parameters(onsite)[1:] |
283 |
- try: |
|
284 |
- _onsite = _FunctionalOnsite(onsite, syst.sites) |
|
285 |
- except AttributeError: |
|
286 |
- _onsite = onsite |
|
287 |
- elif isinstance(onsite, collections.abc.Mapping): |
|
288 |
- if not hasattr(syst, 'sites'): |
|
289 |
- raise TypeError('Provide `onsite` as a value or a function for ' |
|
290 |
- 'systems that are not finalized Builders.') |
|
355 |
+ if is_vectorized(syst): |
|
356 |
+ _onsite = _VectorizedFunctionalOnsite(onsite, syst.site_arrays) |
|
357 |
+ elif hasattr(syst, "sites"): # probably a non-vectorized finalized Builder |
|
358 |
+ _onsite = _FunctionalOnsite(onsite, syst.sites, syst.site_ranges) |
|
359 |
+ else: # generic old-style system, therefore *not* vectorized. |
|
360 |
+ _onsite = _FunctionalOnsiteNoTransform(onsite, syst.site_ranges) |
|
291 | 361 |
|
362 |
+ elif isinstance(onsite, collections.abc.Mapping): |
|
292 | 363 |
# onsites known; immediately check for correct shape and hermiticity |
293 | 364 |
for fam, _onsite in onsite.items(): |
294 | 365 |
_onsite = ta.matrix(_onsite, complex) |
295 | 366 |
_check_onsite(_onsite, fam.norbs, check_hermiticity) |
296 | 367 |
|
297 |
- _onsite = _DictOnsite(onsite, syst.sites) |
|
368 |
+ if is_vectorized(syst): |
|
369 |
+ range_family_map = [arr.family for arr in syst.site_arrays] |
|
370 |
+ elif not hasattr(syst, 'sites'): |
|
371 |
+ raise TypeError('Provide `onsite` as a value or a function for ' |
|
372 |
+ 'systems that are not finalized Builders.') |
|
373 |
+ else: |
|
374 |
+ # The last entry in 'site_ranges' is just an end marker, so we remove it |
|
375 |
+ range_family_map = [syst.sites[r[0]].family for r in syst.site_ranges[:-1]] |
|
376 |
+ _onsite = _DictOnsite(onsite, range_family_map) |
|
377 |
+ |
|
298 | 378 |
else: |
299 | 379 |
# single onsite; immediately check for correct shape and hermiticity |
300 | 380 |
_onsite = ta.matrix(onsite, complex) |
... | ... |
@@ -320,6 +400,101 @@ def _normalize_onsite(syst, onsite, check_hermiticity): |
320 | 400 |
return _onsite, param_names |
321 | 401 |
|
322 | 402 |
|
403 |
+def _make_onsite_or_hopping_terms(site_ranges, where): |
|
404 |
+ |
|
405 |
+ terms = dict() |
|
406 |
+ |
|
407 |
+ cdef gint[:] site_offsets_ = np.asarray(site_ranges, dtype=gint_dtype)[:, 0] |
|
408 |
+ cdef gint i |
|
409 |
+ |
|
410 |
+ if where.shape[1] == 1: # onsite |
|
411 |
+ for i in range(where.shape[0]): |
|
412 |
+ a = _bisect(site_offsets_, where[i, 0]) - 1 |
|
413 |
+ terms.setdefault((a, a), []).append(i) |
|
414 |
+ else: # hopping |
|
415 |
+ for i in range(where.shape[0]): |
|
416 |
+ a = _bisect(site_offsets_, where[i, 0]) - 1 |
|
417 |
+ b = _bisect(site_offsets_, where[i, 1]) - 1 |
|
418 |
+ terms.setdefault((a, b), []).append(i) |
|
419 |
+ return [(a, None, b) for a, b in terms.items()] |
|
420 |
+ |
|
421 |
+ |
|
422 |
+def _vectorized_make_onsite_terms(syst, where): |
|
423 |
+ assert is_vectorized(syst) |
|
424 |
+ assert where.shape[1] == 1 |
|
425 |
+ site_offsets = [r[0] for r in syst.site_ranges] |
|
426 |
+ |
|
427 |
+ terms = {} |
|
428 |
+ term_by_id = syst._onsite_term_by_site_id |
|
429 |
+ for i in range(where.shape[0]): |
|
430 |
+ w = where[i, 0] |
|
431 |
+ terms.setdefault(term_by_id[w], []).append(i) |
|
432 |
+ |
|
433 |
+ ret = [] |
|
434 |
+ for term_id, which in terms.items(): |
|
435 |
+ term = syst.terms[term_id] |
|
436 |
+ ((term_sa, _), (term_sites, _)) = syst.subgraphs[term.subgraph] |
|
437 |
+ term_sites += site_offsets[term_sa] |
|
438 |
+ which = np.asarray(which, dtype=gint_dtype) |
|
439 |
+ sites = _select(where, which).reshape(-1) |
|
440 |
+ selector = np.searchsorted(term_sites, sites) |
|
441 |
+ ret.append((term_id, selector, which)) |
|
442 |
+ |
|
443 |
+ return ret |
|
444 |
+ |
|
445 |
+ |
|
446 |
+def _vectorized_make_hopping_terms(syst, where): |
|
447 |
+ assert is_vectorized(syst) |
|
448 |
+ assert where.shape[1] == 2 |
|
449 |
+ site_offsets = [r[0] for r in syst.site_ranges] |
|
450 |
+ |
|
451 |
+ terms = {} |
|
452 |
+ term_by_id = syst._hopping_term_by_edge_id |
|
453 |
+ for i in range(where.shape[0]): |
|
454 |
+ a, b = where[i, 0], where[i, 1] |
|
455 |
+ edge = syst.graph.first_edge_id(a, b) |
|
456 |
+ terms.setdefault(term_by_id[edge], []).append(i) |
|
457 |
+ |
|
458 |
+ ret = [] |
|
459 |
+ dtype = np.dtype([('f0', int), ('f1', int)]) |
|
460 |
+ for term_id, which in terms.items(): |
|
461 |
+ herm_conj = term_id < 0 |
|
462 |
+ if herm_conj: |
|
463 |
+ real_term_id = -term_id - 1 |
|
464 |
+ else: |
|
465 |
+ real_term_id = term_id |
|
466 |
+ which = np.asarray(which, dtype=gint_dtype) |
|
467 |
+ # Select out the hoppings and reverse them if we are |
|
468 |
+ # dealing with Hermitian conjugate hoppings |
|
469 |
+ hops = _select(where, which) |
|
470 |
+ if herm_conj: |
|
471 |
+ hops = hops[:, ::-1] |
|
472 |
+ # Force contiguous array to handle herm conj case also. |
|
473 |
+ # Needs to be contiguous to cast to compound dtype |
|
474 |
+ hops = np.ascontiguousarray(hops, dtype=int) |
|
475 |
+ hops = hops.view(dtype).reshape(-1) |
|
476 |
+ term = syst.terms[real_term_id] |
|
477 |
+ # Get array of pairs |
|
478 |
+ ((to_sa, from_sa), term_hoppings) = syst.subgraphs[term.subgraph] |
|
479 |
+ term_hoppings = term_hoppings.transpose() + (site_offsets[to_sa], site_offsets[from_sa]) |
|
480 |
+ term_hoppings = term_hoppings.view(dtype).reshape(-1) |
|
481 |
+ |
|
482 |
+ selector = np.recarray.searchsorted(term_hoppings, hops) |
|
483 |
+ |
|
484 |
+ ret.append((term_id, selector, which)) |
|
485 |
+ |
|
486 |
+ return ret |
|
487 |
+ |
|
488 |
+ |
|
489 |
+def _make_matrix_elements(evaluate_term, terms): |
|
490 |
+ matrix_elements = [] |
|
491 |
+ for (term_id, term_selector, which) in terms: |
|
492 |
+ which = np.asarray(which, dtype=gint_dtype) |
|
493 |
+ data = evaluate_term(term_id, term_selector, which) |
|
494 |
+ matrix_elements.append((which, data)) |
|
495 |
+ return matrix_elements |
|
496 |
+ |
|
497 |
+ |
|
323 | 498 |
cdef class BlockSparseMatrix: |
324 | 499 |
"""A sparse matrix stored as dense blocks. |
325 | 500 |
|
... | ... |
@@ -334,11 +509,10 @@ cdef class BlockSparseMatrix: |
334 | 509 |
in the sparse matrix: ``(row_offset, col_offset)``. |
335 | 510 |
block_shapes : gint[:, :] |
336 | 511 |
``Nx2`` array: the shapes of each block, ``(n_rows, n_cols)``. |
337 |
- f : callable |
|
338 |
- evaluates matrix blocks. Has signature ``(a, n_rows, b, n_cols)`` |
|
339 |
- where all the arguments are integers and |
|
340 |
- ``a`` and ``b`` are the contents of ``where``. This function |
|
341 |
- must return a matrix of shape ``(n_rows, n_cols)``. |
|
512 |
+ matrix_elements : sequence of pairs (where_indices, data) |
|
513 |
+ 'data' is a 3D complex array; a vector of matrices. |
|
514 |
+ 'where_indices' is a 1D array of indices for 'where'; |
|
515 |
+ 'data[i]' should be placed at 'where[where_indices[i]]'. |
|
342 | 516 |
|
343 | 517 |
Attributes |
344 | 518 |
---------- |
... | ... |
@@ -357,7 +531,7 @@ cdef class BlockSparseMatrix: |
357 | 531 |
@cython.boundscheck(False) |
358 | 532 |
@cython.wraparound(False) |
359 | 533 |
def __init__(self, gint[:, :] where, gint[:, :] block_offsets, |
360 |
- gint[:, :] block_shapes, f): |
|
534 |
+ gint[:, :] block_shapes, matrix_elements): |
|
361 | 535 |
if (block_offsets.shape[0] != where.shape[0] or |
362 | 536 |
block_shapes.shape[0] != where.shape[0]): |
363 | 537 |
raise ValueError('Arrays should be the same length along ' |
... | ... |
@@ -372,20 +546,19 @@ cdef class BlockSparseMatrix: |
372 | 546 |
data_size += block_shapes[w, 0] * block_shapes[w, 1] |
373 | 547 |
### Populate data array |
374 | 548 |
self.data = np.empty((data_size,), dtype=complex) |
375 |
- cdef complex[:, :] mat |
|
376 |
- cdef gint i, j, off, a, b, a_norbs, b_norbs |
|
377 |
- for w in range(where.shape[0]): |
|
378 |
- off = self.data_offsets[w] |
|
379 |
- a_norbs = self.block_shapes[w, 0] |
|
380 |
- b_norbs = self.block_shapes[w, 1] |
|
381 |
- a = where[w, 0] |
|
382 |
- b = a if where.shape[1] == 1 else where[w, 1] |
|
383 |
- # call the function that gives the matrix |
|
384 |
- mat = f(a, a_norbs, b, b_norbs) |
|
385 |
- # Copy data |
|
386 |
- for i in range(a_norbs): |
|
387 |
- for j in range(b_norbs): |
|
388 |
- self.data[off + i * b_norbs + j] = mat[i, j] |
|
549 |
+ cdef complex[:, :, :] data |
|
550 |
+ cdef gint[:] where_indexes |
|
551 |
+ cdef gint i, j, k, off, a, b, a_norbs, b_norbs |
|
552 |
+ for where_indexes, data in matrix_elements: |
|
553 |
+ for i in range(where_indexes.shape[0]): |
|
554 |
+ w = where_indexes[i] |
|
555 |
+ off = self.data_offsets[w] |
|
556 |
+ a_norbs = self.block_shapes[w, 0] |
|
557 |
+ b_norbs = self.block_shapes[w, 1] |
|
558 |
+ # Copy data |
|
559 |
+ for j in range(a_norbs): |
|
560 |
+ for k in range(b_norbs): |
|
561 |
+ self.data[off + j * b_norbs + k] = data[i, j, k] |
|
389 | 562 |
|
390 | 563 |
cdef complex* get(self, gint block_idx): |
391 | 564 |
return <complex*> &self.data[0] + self.data_offsets[block_idx] |
... | ... |
@@ -470,6 +643,40 @@ cdef class _LocalOperator: |
470 | 643 |
self._bound_onsite = None |
471 | 644 |
self._bound_hamiltonian = None |
472 | 645 |
|
646 |
+ # Here we pre-compute the datastructures that will enable us to evaluate |
|
647 |
+ # the Hamiltonian and onsite functions in a vectorized way. Essentially |
|
648 |
+ # we group the sites/hoppings in 'where' by what term of 'syst' they are |
|
649 |
+ # in (for vectorized systems), or by the site family(s) (for |
|
650 |
+ # non-vectorized systems). If the system is vectorized we store a list |
|
651 |
+ # of triples: |
|
652 |
+ # |
|
653 |
+ # (term_id, term_selector, which) |
|
654 |
+ # |
|
655 |
+ # otherwise |
|
656 |
+ # |
|
657 |
+ # ((to_site_range, from_site_range), None, which) |
|
658 |
+ # |
|
659 |
+ # Where: |
|
660 |
+ # |
|
661 |
+ # 'term_id' → integer: term index, may be negative (indicates herm. conj.) |
|
662 |
+ # 'term_selector' → 1D integer array: selects which elements from the |
|
663 |
+ # subgraph of term number 'term_id' should be evaluated. |
|
664 |
+ # 'which' → 1D integer array: selects which elements of 'where' this |
|
665 |
+ # vectorized evaluation corresponds to. |
|
666 |
+ # 'to/from_site_range' → integer: the site ranges that the elements of |
|
667 |
+ # 'where' indexed by 'which' correspond to. |
|
668 |
+ # |
|
669 |
+ # Note that all sites/hoppings indicated by 'which' belong to the *same* |
|
670 |
+ # pair of site families by construction. This is what allows for |
|
671 |
+ # vectorized evaluation. |
|
672 |
+ if is_vectorized(syst): |
|
673 |
+ if self.where.shape[1] == 1: |
|
674 |
+ self._terms = _vectorized_make_onsite_terms(syst, where) |
|
675 |
+ else: |
|
676 |
+ self._terms = _vectorized_make_hopping_terms(syst, where) |
|
677 |
+ else: |
|
678 |
+ self._terms = _make_onsite_or_hopping_terms(self._site_ranges, where) |
|
679 |
+ |
|
473 | 680 |
@cython.embedsignature |
474 | 681 |
def __call__(self, bra, ket=None, args=(), *, params=None): |
475 | 682 |
r"""Return the matrix elements of the operator. |
... | ... |
@@ -665,9 +872,10 @@ cdef class _LocalOperator: |
665 | 872 |
"""Evaluate the onsite matrices on all elements of `where`""" |
666 | 873 |
assert callable(self.onsite) |
667 | 874 |
assert not (args and params) |
668 |
- matrix = ta.matrix |
|
669 |
- onsite = self.onsite |
|
670 | 875 |
check_hermiticity = self.check_hermiticity |
876 |
+ syst = self.syst |
|
877 |
+ |
|
878 |
+ _is_vectorized = is_vectorized(syst) |
|
671 | 879 |
|
672 | 880 |
if params: |
673 | 881 |
try: |
... | ... |
@@ -679,38 +887,80 @@ cdef class _LocalOperator: |
679 | 887 |
', '.join(map('"{}"'.format, missing))) |
680 | 888 |
raise TypeError(''.join(msg)) |
681 | 889 |
|
682 |
- def get_onsite(a, a_norbs, b, b_norbs): |
|
683 |
- mat = matrix(onsite(a, *args), complex) |
|
684 |
- _check_onsite(mat, a_norbs, check_hermiticity) |
|
685 |
- return mat |
|
686 |
- |
|
890 |
+ # Evaluate many onsites at once. See _LocalOperator.__init__ |
|
891 |
+ # for an explanation the parameters. |
|
892 |
+ def eval_onsite(term_id, term_selector, which): |
|
893 |
+ if _is_vectorized: |
|
894 |
+ if term_id >= 0: |
|
895 |
+ (sr, _), _ = syst.subgraphs[syst.terms[term_id].subgraph] |
|
896 |
+ else: |
|
897 |
+ (_, sr), _ = syst.subgraphs[syst.terms[-term_id - 1].subgraph] |
|
898 |
+ else: |
|
899 |
+ sr, _ = term_id |
|
900 |
+ start_site, norbs, _ = self.syst.site_ranges[sr] |
|
901 |
+ # All sites selected by 'which' are part of the same site family. |
|
902 |
+ site_offsets = _select(self.where, which)[:, 0] - start_site |
|
903 |
+ data = self.onsite(sr, site_offsets, *args) |
|
904 |
+ return data |
|
905 |
+ |
|
906 |
+ matrix_elements = _make_matrix_elements(eval_onsite, self._terms) |
|
687 | 907 |
offsets, norbs = _get_all_orbs(self.where, self._site_ranges) |
688 |
- return BlockSparseMatrix(self.where, offsets, norbs, get_onsite) |
|
908 |
+ return BlockSparseMatrix(self.where, offsets, norbs, matrix_elements) |
|
909 |
+ |
|
689 | 910 |
|
690 | 911 |
cdef BlockSparseMatrix _eval_hamiltonian(self, args, params): |
691 | 912 |
"""Evaluate the Hamiltonian on all elements of `where`.""" |
692 |
- matrix = ta.matrix |
|
693 |
- hamiltonian = self.syst.hamiltonian |
|
913 |
+ |
|
914 |
+ where = self.where |
|
915 |
+ syst = self.syst |
|
916 |
+ is_onsite = self.where.shape[1] == 1 |
|
694 | 917 |
check_hermiticity = self.check_hermiticity |
695 | 918 |
|
696 |
- def get_ham(a, a_norbs, b, b_norbs): |
|
697 |
- mat = matrix(hamiltonian(a, b, *args, params=params), complex) |
|
698 |
- _check_ham(mat, hamiltonian, args, params, |
|
699 |
- a, a_norbs, b, b_norbs, check_hermiticity) |
|
700 |
- return mat |
|
919 |
+ if is_vectorized(self.syst): |
|
701 | 920 |
|
702 |
- offsets, norbs = _get_all_orbs(self.where, self._site_ranges) |
|
703 |
- # TODO: update operators to use 'hamiltonian_term' rather than |
|
704 |
- # 'hamiltonian'. |
|
705 |
- with warnings.catch_warnings(): |
|
706 |
- warnings.simplefilter("ignore", category=KwantDeprecationWarning) |
|
707 |
- return BlockSparseMatrix(self.where, offsets, norbs, get_ham) |
|
921 |
+ # Evaluate many Hamiltonian elements at once. |
|
922 |
+ # See _LocalOperator.__init__ for an explanation the parameters. |
|
923 |
+ def eval_hamiltonian(term_id, term_selector, which): |
|
924 |
+ herm_conj = term_id < 0 |
|
925 |
+ assert not is_onsite or (is_onsite and not herm_conj) # onsite terms are never hermitian conjugate |
|
926 |
+ if herm_conj: |
|
927 |
+ term_id = -term_id - 1 |
|
928 |
+ data = syst.hamiltonian_term(term_id, term_selector, |
|
929 |
+ args=args, params=params) |
|
930 |
+ if herm_conj: |
|
931 |
+ data = data.conjugate().transpose(0, 2, 1) |
|
932 |
+ |
|
933 |
+ return data |
|
934 |
+ |
|
935 |
+ else: |
|
936 |
+ |
|
937 |
+ # Evaluate many Hamiltonian elements at once. |
|
938 |
+ # See _LocalOperator.__init__ for an explanation the parameters. |
|
939 |
+ def eval_hamiltonian(term_id, term_selector, which): |
|
940 |
+ if is_onsite: |
|
941 |
+ data = [ |
|
942 |
+ syst.hamiltonian(where[i, 0], where[i, 0], *args, params=params) |
|
943 |
+ for i in which |
|
944 |
+ ] |
|
945 |
+ else: |
|
946 |
+ data = [ |
|
947 |
+ syst.hamiltonian(where[i, 0], where[i, 1], *args, params=params) |
|
948 |
+ for i in which |
|
949 |
+ ] |
|
950 |
+ data = _normalize_matrix_blocks(data, len(which)) |
|
951 |
+ |
|
952 |
+ return data |
|
953 |
+ |
|
954 |
+ matrix_elements = _make_matrix_elements(eval_hamiltonian, self._terms) |
|
955 |
+ offsets, norbs = _get_all_orbs(where, self._site_ranges) |
|
956 |
+ return BlockSparseMatrix(where, offsets, norbs, matrix_elements) |
|
708 | 957 |
|
709 | 958 |
def __getstate__(self): |
710 | 959 |
return ( |
711 | 960 |
(self.check_hermiticity, self.sum), |
712 | 961 |
(self.syst, self.onsite, self._onsite_param_names), |
713 | 962 |
tuple(map(np.asarray, (self.where, self._site_ranges))), |
963 |
+ (self._terms,), |
|
714 | 964 |
(self._bound_onsite, self._bound_hamiltonian), |
715 | 965 |
) |
716 | 966 |
|
... | ... |
@@ -718,6 +968,7 @@ cdef class _LocalOperator: |
718 | 968 |
((self.check_hermiticity, self.sum), |
719 | 969 |
(self.syst, self.onsite, self._onsite_param_names), |
720 | 970 |
(self.where, self._site_ranges), |
971 |
+ (self._terms,), |
|
721 | 972 |
(self._bound_onsite, self._bound_hamiltonian), |
722 | 973 |
) = state |
723 | 974 |
|