...
|
...
|
@@ -759,17 +759,23 @@ def _normalize_matrix_blocks(blocks, expected_shape, *, calling_function=None):
|
759
|
759
|
raise ValueError(
|
760
|
760
|
"Matrix elements declared with incompatible shapes."
|
761
|
761
|
) from None
|
|
762
|
+ original_shape = blocks.shape
|
|
763
|
+ was_broadcast = True # Did the shape get broadcasted to a more general one?
|
762
|
764
|
if len(blocks.shape) == 0: # scalar → broadcast to vector of 1x1 matrices
|
763
|
765
|
blocks = np.tile(blocks, (expected_shape[0], 1, 1))
|
764
|
766
|
elif len(blocks.shape) == 1: # vector → interpret as vector of 1x1 matrices
|
765
|
767
|
blocks = blocks.reshape(-1, 1, 1)
|
766
|
768
|
elif len(blocks.shape) == 2: # matrix → broadcast to vector of matrices
|
767
|
769
|
blocks = np.tile(blocks, (expected_shape[0], 1, 1))
|
|
770
|
+ else:
|
|
771
|
+ was_broadcast = False
|
768
|
772
|
|
769
|
773
|
if blocks.shape != expected_shape:
|
770
|
774
|
msg = (
|
771
|
775
|
"Expected values of shape {}, but received values of shape {}"
|
772
|
776
|
.format(expected_shape, blocks.shape),
|
|
777
|
+ "(broadcasted from shape {})".format(original_shape)
|
|
778
|
+ if was_broadcast else "",
|
773
|
779
|
"when evaluating {}".format(calling_function.__name__)
|
774
|
780
|
if callable(calling_function) else "",
|
775
|
781
|
)
|