Browse code

update '_normalize_matrix_blocks' to check norbs as well as length

Joseph Weston authored on 11/12/2019 13:18:48
Showing 3 changed files
... ...
@@ -2122,7 +2122,12 @@ class _VectorizedFinalizedBuilderMixin(_FinalizedBuilderMixin):
2122 2122
             except Exception as exc:
2123 2123
                 _raise_user_error(exc, val)
2124 2124
 
2125
-        ham = system._normalize_matrix_blocks(ham, len(to_site_array))
2125
+        expected_shape = (
2126
+            len(to_site_array),
2127
+            to_family.norbs,
2128
+            from_family.norbs if not is_onsite else to_family.norbs,
2129
+        )
2130
+        ham = system._normalize_matrix_blocks(ham, expected_shape)
2126 2131
 
2127 2132
         return ham
2128 2133
 
... ...
@@ -2446,8 +2451,17 @@ def _make_onsite_terms(builder, sites, site_arrays, term_offset):
2446 2451
     # Normalize any constant values and check that the shapes are consistent.
2447 2452
     onsite_term_values = [
2448 2453
         val if callable(val)
2449
-        else system._normalize_matrix_blocks(val, len(val))
2450
-        for val in onsite_term_values
2454
+        else
2455
+            system._normalize_matrix_blocks(
2456
+                val,
2457
+                (
2458
+                    len(val),
2459
+                    site_arrays[sa].family.norbs,
2460
+                    site_arrays[sa].family.norbs,
2461
+                ),
2462
+            )
2463
+        for (_, sa), val in
2464
+        zip(onsite_to_term_nr.keys(), onsite_term_values)
2451 2465
     ]
2452 2466
     # Sort the sites in each term, and also sort the values
2453 2467
     # in the same way if they are a constant (as opposed to a callable).
... ...
@@ -2560,8 +2574,17 @@ def _make_hopping_terms(builder, graph, sites, site_arrays, cell_size, term_offs
2560 2574
         # Normalize any constant values and check that the shapes are consistent.
2561 2575
         hopping_term_values = [
2562 2576
             val if callable(val)
2563
-            else system._normalize_matrix_blocks(val, len(val))
2564
-            for val in hopping_term_values
2577
+            else
2578
+                system._normalize_matrix_blocks(
2579
+                    val,
2580
+                    (
2581
+                        len(val),
2582
+                        site_arrays[to_sa].family.norbs,
2583
+                        site_arrays[from_sa].family.norbs,
2584
+                    ),
2585
+                )
2586
+            for (_, to_sa, from_sa), val in
2587
+            zip(hopping_to_term_nr.keys(), hopping_term_values)
2565 2588
         ]
2566 2589
         # Sort the hoppings in each term, and also sort the values
2567 2590
         # in the same way if they are a constant (as opposed to a callable).
... ...
@@ -327,12 +327,14 @@ class _FunctionalOnsite:
327 327
 
328 328
     def __call__(self, site_range, site_offsets, *args):
329 329
         sites = self.sites
330
-        offset = self.site_ranges[site_range][0]
330
+        offset, norbs, _ = self.site_ranges[site_range]
331 331
         try:
332 332
             ret = [self.onsite(sites[offset + i], *args) for i in site_offsets]
333 333
         except Exception as exc:
334 334
             _raise_user_error(exc, self.onsite, vectorized=False)
335
-        return _normalize_matrix_blocks(ret, len(site_offsets))
335
+
336
+        expected_shape = (len(site_offsets), norbs, norbs)
337
+        return _normalize_matrix_blocks(ret, expected_shape)
336 338
 
337 339
 
338 340
 class _VectorizedFunctionalOnsite:
... ...
@@ -349,7 +351,9 @@ class _VectorizedFunctionalOnsite:
349 351
             ret = self.onsite(sites, *args)
350 352
         except Exception as exc:
351 353
             _raise_user_error(exc, self.onsite, vectorized=True)
352
-        return _normalize_matrix_blocks(ret, len(site_offsets))
354
+
355
+        expected_shape = (len(sites), sites.family.norbs, sites.family.norbs)
356
+        return _normalize_matrix_blocks(ret, expected_shape)
353 357
 
354 358
 
355 359
 class _FunctionalOnsiteNoTransform:
... ...
@@ -359,12 +363,15 @@ class _FunctionalOnsiteNoTransform:
359 363
         self.site_ranges = site_ranges
360 364
 
361 365
     def __call__(self, site_range, site_offsets, *args):
362
-        site_ids = self.site_ranges[site_range][0] + site_offsets
366
+        offset, norbs, _ = self.site_ranges[site_range]
367
+        site_ids = offset + site_offsets
363 368
         try:
364 369
             ret = [self.onsite(id, *args) for id in site_ids]
365 370
         except Exception as exc:
366 371
             _raise_user_error(exc, self.onsite, vectorized=False)
367
-        return _normalize_matrix_blocks(ret, len(site_offsets))
372
+
373
+        expected_shape = (len(site_offsets), norbs, norbs)
374
+        return _normalize_matrix_blocks(ret, expected_shape)
368 375
 
369 376
 
370 377
 class _DictOnsite:
... ...
@@ -376,7 +383,9 @@ class _DictOnsite:
376 383
     def __call__(self, site_range, site_offsets, *args):
377 384
         fam = self.range_family_map[site_range]
378 385
         ret = [self.onsite[fam]] * len(site_offsets)
379
-        return _normalize_matrix_blocks(ret, len(site_offsets))
386
+
387
+        expected_shape = (len(site_offsets), fam.norbs, fam.norbs)
388
+        return _normalize_matrix_blocks(ret, expected_shape)
380 389
 
381 390
 
382 391
 def _normalize_onsite(syst, onsite, check_hermiticity):
... ...
@@ -993,11 +1002,13 @@ cdef class _LocalOperator:
993 1002
                         syst.hamiltonian(where[i, 0], where[i, 1], *args, params=params)
994 1003
                         for i in which
995 1004
                     ]
996
-                data = _normalize_matrix_blocks(data, len(which))
997
-                # Checks for data consistency
1005
+
998 1006
                 (to_sr, from_sr) = term_id
999
-                to_norbs = syst.site_ranges[to_sr][1]
1000
-                from_norbs = syst.site_ranges[from_sr][1]
1007
+                to_norbs, from_norbs = syst.site_ranges[to_sr][1], syst.site_ranges[from_sr][1]
1008
+                expected_shape = (len(which), to_norbs, from_norbs)
1009
+                data = _normalize_matrix_blocks(data, expected_shape)
1010
+                # Checks for data consistency
1011
+
1001 1012
                 _check_hams(data, to_norbs, from_norbs, is_onsite and check_hermiticity)
1002 1013
 
1003 1014
                 return data
... ...
@@ -742,13 +742,13 @@ def is_vectorized(syst):
742 742
     return isinstance(syst, (FiniteVectorizedSystem, InfiniteVectorizedSystem))
743 743
 
744 744
 
745
-def _normalize_matrix_blocks(blocks, expected_length):
745
+def _normalize_matrix_blocks(blocks, expected_shape):
746 746
     """Normalize a sequence of matrices into a single 3D numpy array
747 747
 
748 748
     Parameters
749 749
     ----------
750 750
     blocks : sequence of complex array-like
751
-    expected_length : int
751
+    expected_shape : (int, int, int)
752 752
     """
753 753
     try:
754 754
         blocks = np.asarray(blocks, dtype=complex)
... ...
@@ -757,21 +757,19 @@ def _normalize_matrix_blocks(blocks, expected_length):
757 757
             "Matrix elements declared with incompatible shapes."
758 758
         ) from None
759 759
     if len(blocks.shape) == 0:  # scalar → broadcast to vector of 1x1 matrices
760
-        blocks = np.tile(blocks, (expected_length, 1, 1))
760
+        blocks = np.tile(blocks, (expected_shape[0], 1, 1))
761 761
     elif len(blocks.shape) == 1:  # vector → interpret as vector of 1x1 matrices
762 762
         blocks = blocks.reshape(-1, 1, 1)
763 763
     elif len(blocks.shape) == 2:  # matrix → broadcast to vector of matrices
764
-        blocks = np.tile(blocks, (expected_length, 1, 1))
764
+        blocks = np.tile(blocks, (expected_shape[0], 1, 1))
765 765
 
766
-    if len(blocks.shape) != 3:
766
+    if blocks.shape != expected_shape:
767 767
         msg = (
768
-            "Vectorized value functions must return an array of"
769
-            "scalars or an array of matrices."
768
+            "Expected values of shape {}, but received values of shape {}"
769
+            .format(expected_shape, blocks.shape)
770 770
         )
771 771
         raise ValueError(msg)
772
-    if blocks.shape[0] != expected_length:
773
-        raise ValueError("Value functions must return a single value per "
774
-                         "onsite/hopping.")
772
+
775 773
     return blocks
776 774
 
777 775