Previously these 2 steps were done in one pass. Now we separate them
out in order to facilitate augnmenting the verification to include
checking the number of orbitals (in future commits).
... | ... |
@@ -2365,7 +2365,6 @@ def _sort_term(term, value): |
2365 | 2365 |
term = np.asarray(term) |
2366 | 2366 |
|
2367 | 2367 |
if not callable(value): |
2368 |
- value = system._normalize_matrix_blocks(value, len(term)) |
|
2369 | 2368 |
# Ensure that values still correspond to the correct |
2370 | 2369 |
# sites in 'term' once the latter has been sorted. |
2371 | 2370 |
value = value[term.argsort()] |
... | ... |
@@ -2444,8 +2443,14 @@ def _make_onsite_terms(builder, sites, site_arrays, term_offset): |
2444 | 2443 |
if const_val: |
2445 | 2444 |
vals = onsite_term_values[onsite_to_term_nr[key]] |
2446 | 2445 |
vals.append(val) |
2447 |
- # Sort the sites in each term, and normalize any constant |
|
2448 |
- # values to arrays of the correct dtype and shape. |
|
2446 |
+ # Normalize any constant values and check that the shapes are consistent. |
|
2447 |
+ onsite_term_values = [ |
|
2448 |
+ val if callable(val) |
|
2449 |
+ else system._normalize_matrix_blocks(val, len(val)) |
|
2450 |
+ for val in onsite_term_values |
|
2451 |
+ ] |
|
2452 |
+ # Sort the sites in each term, and also sort the values |
|
2453 |
+ # in the same way if they are a constant (as opposed to a callable). |
|
2449 | 2454 |
onsite_subgraphs, onsite_term_values = zip(*( |
2450 | 2455 |
_sort_term(term, value) |
2451 | 2456 |
for term, value in |
... | ... |
@@ -2550,9 +2555,16 @@ def _make_hopping_terms(builder, graph, sites, site_arrays, cell_size, term_offs |
2550 | 2555 |
if const_val: |
2551 | 2556 |
vals = hopping_term_values[hopping_to_term_nr[key]] |
2552 | 2557 |
vals.append(val) |
2553 |
- # Sort the hoppings in each term, and normalize any constant |
|
2554 |
- # values to arrays of the correct dtype and shape. |
|
2558 |
+ |
|
2555 | 2559 |
if hopping_subgraphs: |
2560 |
+ # Normalize any constant values and check that the shapes are consistent. |
|
2561 |
+ hopping_term_values = [ |
|
2562 |
+ val if callable(val) |
|
2563 |
+ else system._normalize_matrix_blocks(val, len(val)) |
|
2564 |
+ for val in hopping_term_values |
|
2565 |
+ ] |
|
2566 |
+ # Sort the hoppings in each term, and also sort the values |
|
2567 |
+ # in the same way if they are a constant (as opposed to a callable). |
|
2556 | 2568 |
hopping_subgraphs, hopping_term_values = zip(*( |
2557 | 2569 |
_sort_hopping_term(term, value) |
2558 | 2570 |
for term, value in |