... | ... |
@@ -2122,7 +2122,12 @@ class _VectorizedFinalizedBuilderMixin(_FinalizedBuilderMixin): |
2122 | 2122 |
except Exception as exc: |
2123 | 2123 |
_raise_user_error(exc, val) |
2124 | 2124 |
|
2125 |
- ham = system._normalize_matrix_blocks(ham, len(to_site_array)) |
|
2125 |
+ expected_shape = ( |
|
2126 |
+ len(to_site_array), |
|
2127 |
+ to_family.norbs, |
|
2128 |
+ from_family.norbs if not is_onsite else to_family.norbs, |
|
2129 |
+ ) |
|
2130 |
+ ham = system._normalize_matrix_blocks(ham, expected_shape) |
|
2126 | 2131 |
|
2127 | 2132 |
return ham |
2128 | 2133 |
|
... | ... |
@@ -2446,8 +2451,17 @@ def _make_onsite_terms(builder, sites, site_arrays, term_offset): |
2446 | 2451 |
# Normalize any constant values and check that the shapes are consistent. |
2447 | 2452 |
onsite_term_values = [ |
2448 | 2453 |
val if callable(val) |
2449 |
- else system._normalize_matrix_blocks(val, len(val)) |
|
2450 |
- for val in onsite_term_values |
|
2454 |
+ else |
|
2455 |
+ system._normalize_matrix_blocks( |
|
2456 |
+ val, |
|
2457 |
+ ( |
|
2458 |
+ len(val), |
|
2459 |
+ site_arrays[sa].family.norbs, |
|
2460 |
+ site_arrays[sa].family.norbs, |
|
2461 |
+ ), |
|
2462 |
+ ) |
|
2463 |
+ for (_, sa), val in |
|
2464 |
+ zip(onsite_to_term_nr.keys(), onsite_term_values) |
|
2451 | 2465 |
] |
2452 | 2466 |
# Sort the sites in each term, and also sort the values |
2453 | 2467 |
# in the same way if they are a constant (as opposed to a callable). |
... | ... |
@@ -2560,8 +2574,17 @@ def _make_hopping_terms(builder, graph, sites, site_arrays, cell_size, term_offs |
2560 | 2574 |
# Normalize any constant values and check that the shapes are consistent. |
2561 | 2575 |
hopping_term_values = [ |
2562 | 2576 |
val if callable(val) |
2563 |
- else system._normalize_matrix_blocks(val, len(val)) |
|
2564 |
- for val in hopping_term_values |
|
2577 |
+ else |
|
2578 |
+ system._normalize_matrix_blocks( |
|
2579 |
+ val, |
|
2580 |
+ ( |
|
2581 |
+ len(val), |
|
2582 |
+ site_arrays[to_sa].family.norbs, |
|
2583 |
+ site_arrays[from_sa].family.norbs, |
|
2584 |
+ ), |
|
2585 |
+ ) |
|
2586 |
+ for (_, to_sa, from_sa), val in |
|
2587 |
+ zip(hopping_to_term_nr.keys(), hopping_term_values) |
|
2565 | 2588 |
] |
2566 | 2589 |
# Sort the hoppings in each term, and also sort the values |
2567 | 2590 |
# in the same way if they are a constant (as opposed to a callable). |
... | ... |
@@ -327,12 +327,14 @@ class _FunctionalOnsite: |
327 | 327 |
|
328 | 328 |
def __call__(self, site_range, site_offsets, *args): |
329 | 329 |
sites = self.sites |
330 |
- offset = self.site_ranges[site_range][0] |
|
330 |
+ offset, norbs, _ = self.site_ranges[site_range] |
|
331 | 331 |
try: |
332 | 332 |
ret = [self.onsite(sites[offset + i], *args) for i in site_offsets] |
333 | 333 |
except Exception as exc: |
334 | 334 |
_raise_user_error(exc, self.onsite, vectorized=False) |
335 |
- return _normalize_matrix_blocks(ret, len(site_offsets)) |
|
335 |
+ |
|
336 |
+ expected_shape = (len(site_offsets), norbs, norbs) |
|
337 |
+ return _normalize_matrix_blocks(ret, expected_shape) |
|
336 | 338 |
|
337 | 339 |
|
338 | 340 |
class _VectorizedFunctionalOnsite: |
... | ... |
@@ -349,7 +351,9 @@ class _VectorizedFunctionalOnsite: |
349 | 351 |
ret = self.onsite(sites, *args) |
350 | 352 |
except Exception as exc: |
351 | 353 |
_raise_user_error(exc, self.onsite, vectorized=True) |
352 |
- return _normalize_matrix_blocks(ret, len(site_offsets)) |
|
354 |
+ |
|
355 |
+ expected_shape = (len(sites), sites.family.norbs, sites.family.norbs) |
|
356 |
+ return _normalize_matrix_blocks(ret, expected_shape) |
|
353 | 357 |
|
354 | 358 |
|
355 | 359 |
class _FunctionalOnsiteNoTransform: |
... | ... |
@@ -359,12 +363,15 @@ class _FunctionalOnsiteNoTransform: |
359 | 363 |
self.site_ranges = site_ranges |
360 | 364 |
|
361 | 365 |
def __call__(self, site_range, site_offsets, *args): |
362 |
- site_ids = self.site_ranges[site_range][0] + site_offsets |
|
366 |
+ offset, norbs, _ = self.site_ranges[site_range] |
|
367 |
+ site_ids = offset + site_offsets |
|
363 | 368 |
try: |
364 | 369 |
ret = [self.onsite(id, *args) for id in site_ids] |
365 | 370 |
except Exception as exc: |
366 | 371 |
_raise_user_error(exc, self.onsite, vectorized=False) |
367 |
- return _normalize_matrix_blocks(ret, len(site_offsets)) |
|
372 |
+ |
|
373 |
+ expected_shape = (len(site_offsets), norbs, norbs) |
|
374 |
+ return _normalize_matrix_blocks(ret, expected_shape) |
|
368 | 375 |
|
369 | 376 |
|
370 | 377 |
class _DictOnsite: |
... | ... |
@@ -376,7 +383,9 @@ class _DictOnsite: |
376 | 383 |
def __call__(self, site_range, site_offsets, *args): |
377 | 384 |
fam = self.range_family_map[site_range] |
378 | 385 |
ret = [self.onsite[fam]] * len(site_offsets) |
379 |
- return _normalize_matrix_blocks(ret, len(site_offsets)) |
|
386 |
+ |
|
387 |
+ expected_shape = (len(site_offsets), fam.norbs, fam.norbs) |
|
388 |
+ return _normalize_matrix_blocks(ret, expected_shape) |
|
380 | 389 |
|
381 | 390 |
|
382 | 391 |
def _normalize_onsite(syst, onsite, check_hermiticity): |
... | ... |
@@ -993,11 +1002,13 @@ cdef class _LocalOperator: |
993 | 1002 |
syst.hamiltonian(where[i, 0], where[i, 1], *args, params=params) |
994 | 1003 |
for i in which |
995 | 1004 |
] |
996 |
- data = _normalize_matrix_blocks(data, len(which)) |
|
997 |
- # Checks for data consistency |
|
1005 |
+ |
|
998 | 1006 |
(to_sr, from_sr) = term_id |
999 |
- to_norbs = syst.site_ranges[to_sr][1] |
|
1000 |
- from_norbs = syst.site_ranges[from_sr][1] |
|
1007 |
+ to_norbs, from_norbs = syst.site_ranges[to_sr][1], syst.site_ranges[from_sr][1] |
|
1008 |
+ expected_shape = (len(which), to_norbs, from_norbs) |
|
1009 |
+ data = _normalize_matrix_blocks(data, expected_shape) |
|
1010 |
+ # Checks for data consistency |
|
1011 |
+ |
|
1001 | 1012 |
_check_hams(data, to_norbs, from_norbs, is_onsite and check_hermiticity) |
1002 | 1013 |
|
1003 | 1014 |
return data |
... | ... |
@@ -742,13 +742,13 @@ def is_vectorized(syst): |
742 | 742 |
return isinstance(syst, (FiniteVectorizedSystem, InfiniteVectorizedSystem)) |
743 | 743 |
|
744 | 744 |
|
745 |
-def _normalize_matrix_blocks(blocks, expected_length): |
|
745 |
+def _normalize_matrix_blocks(blocks, expected_shape): |
|
746 | 746 |
"""Normalize a sequence of matrices into a single 3D numpy array |
747 | 747 |
|
748 | 748 |
Parameters |
749 | 749 |
---------- |
750 | 750 |
blocks : sequence of complex array-like |
751 |
- expected_length : int |
|
751 |
+ expected_shape : (int, int, int) |
|
752 | 752 |
""" |
753 | 753 |
try: |
754 | 754 |
blocks = np.asarray(blocks, dtype=complex) |
... | ... |
@@ -757,21 +757,19 @@ def _normalize_matrix_blocks(blocks, expected_length): |
757 | 757 |
"Matrix elements declared with incompatible shapes." |
758 | 758 |
) from None |
759 | 759 |
if len(blocks.shape) == 0: # scalar → broadcast to vector of 1x1 matrices |
760 |
- blocks = np.tile(blocks, (expected_length, 1, 1)) |
|
760 |
+ blocks = np.tile(blocks, (expected_shape[0], 1, 1)) |
|
761 | 761 |
elif len(blocks.shape) == 1: # vector → interpret as vector of 1x1 matrices |
762 | 762 |
blocks = blocks.reshape(-1, 1, 1) |
763 | 763 |
elif len(blocks.shape) == 2: # matrix → broadcast to vector of matrices |
764 |
- blocks = np.tile(blocks, (expected_length, 1, 1)) |
|
764 |
+ blocks = np.tile(blocks, (expected_shape[0], 1, 1)) |
|
765 | 765 |
|
766 |
- if len(blocks.shape) != 3: |
|
766 |
+ if blocks.shape != expected_shape: |
|
767 | 767 |
msg = ( |
768 |
- "Vectorized value functions must return an array of" |
|
769 |
- "scalars or an array of matrices." |
|
768 |
+ "Expected values of shape {}, but received values of shape {}" |
|
769 |
+ .format(expected_shape, blocks.shape) |
|
770 | 770 |
) |
771 | 771 |
raise ValueError(msg) |
772 |
- if blocks.shape[0] != expected_length: |
|
773 |
- raise ValueError("Value functions must return a single value per " |
|
774 |
- "onsite/hopping.") |
|
772 |
+ |
|
775 | 773 |
return blocks |
776 | 774 |
|
777 | 775 |
|