Browse code

separate hamiltonian value normalization/verification from sorting

Previously these 2 steps were done in one pass. Now we separate them
out in order to facilitate augnmenting the verification to include
checking the number of orbitals (in future commits).

Joseph Weston authored on 11/12/2019 17:06:43
Showing 1 changed files
... ...
@@ -2365,7 +2365,6 @@ def _sort_term(term, value):
2365 2365
     term = np.asarray(term)
2366 2366
 
2367 2367
     if not callable(value):
2368
-        value = system._normalize_matrix_blocks(value, len(term))
2369 2368
         # Ensure that values still correspond to the correct
2370 2369
         # sites in 'term' once the latter has been sorted.
2371 2370
         value = value[term.argsort()]
... ...
@@ -2444,8 +2443,14 @@ def _make_onsite_terms(builder, sites, site_arrays, term_offset):
2444 2443
         if const_val:
2445 2444
             vals = onsite_term_values[onsite_to_term_nr[key]]
2446 2445
             vals.append(val)
2447
-    # Sort the sites in each term, and normalize any constant
2448
-    # values to arrays of the correct dtype and shape.
2446
+    # Normalize any constant values and check that the shapes are consistent.
2447
+    onsite_term_values = [
2448
+        val if callable(val)
2449
+        else system._normalize_matrix_blocks(val, len(val))
2450
+        for val in onsite_term_values
2451
+    ]
2452
+    # Sort the sites in each term, and also sort the values
2453
+    # in the same way if they are a constant (as opposed to a callable).
2449 2454
     onsite_subgraphs, onsite_term_values = zip(*(
2450 2455
         _sort_term(term, value)
2451 2456
         for term, value in
... ...
@@ -2550,9 +2555,16 @@ def _make_hopping_terms(builder, graph, sites, site_arrays, cell_size, term_offs
2550 2555
             if const_val:
2551 2556
                 vals = hopping_term_values[hopping_to_term_nr[key]]
2552 2557
                 vals.append(val)
2553
-    # Sort the hoppings in each term, and normalize any constant
2554
-    # values to arrays of the correct dtype and shape.
2558
+
2555 2559
     if hopping_subgraphs:
2560
+        # Normalize any constant values and check that the shapes are consistent.
2561
+        hopping_term_values = [
2562
+            val if callable(val)
2563
+            else system._normalize_matrix_blocks(val, len(val))
2564
+            for val in hopping_term_values
2565
+        ]
2566
+        # Sort the hoppings in each term, and also sort the values
2567
+        # in the same way if they are a constant (as opposed to a callable).
2556 2568
         hopping_subgraphs, hopping_term_values = zip(*(
2557 2569
             _sort_hopping_term(term, value)
2558 2570
             for term, value in