Browse code

improve error message in _normalize_matrix_blocks

Joseph Weston authored on 12/12/2019 15:13:34
Showing 1 changed files
... ...
@@ -759,17 +759,23 @@ def _normalize_matrix_blocks(blocks, expected_shape, *, calling_function=None):
759 759
         raise ValueError(
760 760
             "Matrix elements declared with incompatible shapes."
761 761
         ) from None
762
+    original_shape = blocks.shape
763
+    was_broadcast = True  # Did the shape get broadcasted to a more general one?
762 764
     if len(blocks.shape) == 0:  # scalar → broadcast to vector of 1x1 matrices
763 765
         blocks = np.tile(blocks, (expected_shape[0], 1, 1))
764 766
     elif len(blocks.shape) == 1:  # vector → interpret as vector of 1x1 matrices
765 767
         blocks = blocks.reshape(-1, 1, 1)
766 768
     elif len(blocks.shape) == 2:  # matrix → broadcast to vector of matrices
767 769
         blocks = np.tile(blocks, (expected_shape[0], 1, 1))
770
+    else:
771
+        was_broadcast = False
768 772
 
769 773
     if blocks.shape != expected_shape:
770 774
         msg = (
771 775
             "Expected values of shape {}, but received values of shape {}"
772 776
                 .format(expected_shape, blocks.shape),
777
+            "(broadcasted from shape {})".format(original_shape)
778
+                if was_broadcast else "",
773 779
             "when evaluating {}".format(calling_function.__name__)
774 780
                 if callable(calling_function) else "",
775 781
         )