from collections.abc import Iterable
from functools import singledispatch

from sympy.core.expr import Expr
from sympy.core.mul import Mul
from sympy.core.singleton import S
from sympy.core.sympify import sympify
from sympy.core.parameters import global_parameters


class TensorProduct(Expr):
    """
    Generic class for tensor products.

    When evaluated, array-like arguments (iterables, matrices and
    N-dimensional arrays) are converted with ``Array`` and multiplied
    eagerly via ``tensorproduct``. Matrix expressions are kept as symbolic
    arguments, and every other argument is folded into a scalar factor.
    If no matrix expression remains, the evaluated product itself is
    returned instead of a ``TensorProduct``.
    """
    is_number = False

    def __new__(cls, *args, **kwargs):
        from sympy.tensor.array import NDimArray, tensorproduct, Array
        from sympy.matrices.expressions.matexpr import MatrixExpr
        from sympy.matrices.matrices import MatrixBase
        from sympy.strategies import flatten

        args = [sympify(arg) for arg in args]
        evaluate = kwargs.get("evaluate", global_parameters.evaluate)

        if not evaluate:
            # Keep the (sympified) arguments exactly as given.
            obj = Expr.__new__(cls, *args)
            return obj

        arrays = []
        other = []
        scalar = S.One
        for arg in args:
            if isinstance(arg, (Iterable, MatrixBase, NDimArray)):
                arrays.append(Array(arg))
            elif isinstance(arg, (MatrixExpr,)):
                other.append(arg)
            else:
                scalar *= arg

        coeff = scalar*tensorproduct(*arrays)
        if len(other) == 0:
            return coeff
        if coeff != 1:
            newargs = [coeff] + other
        else:
            newargs = other
        # Keyword options (``evaluate``) were consumed above. They are not
        # forwarded: ``Expr.__new__`` (``Basic.__new__``) takes positional
        # args only, so passing e.g. ``evaluate=True`` raised TypeError.
        obj = Expr.__new__(cls, *newargs)
        return flatten(obj)

    def rank(self):
        """Return the number of indices, i.e. ``len(self.shape)``."""
        return len(self.shape)

    def _get_args_shapes(self):
        """
        Return a list with the shape of each argument.

        Arguments without a ``shape`` attribute are wrapped in ``Array`` to
        obtain one.
        """
        from sympy.tensor.array import Array
        return [i.shape if hasattr(i, "shape") else Array(i).shape for i in self.args]

    @property
    def shape(self):
        """The concatenation of the shapes of all arguments."""
        shape_list = self._get_args_shapes()
        return sum(shape_list, ())

    def __getitem__(self, index):
        """
        Return the component at *index*.

        *index* must supply exactly one entry per axis of ``self.shape``.
        It is split into consecutive groups, one group per argument, and
        the indexed components of the arguments are multiplied together.

        Raises
        ======

        IndexError : if the number of indices differs from ``self.rank()``.
        """
        index = tuple(index)
        shapes = self._get_args_shapes()
        expected = sum(len(shp) for shp in shapes)
        if len(index) != expected:
            # Previously: too few indices surfaced as RuntimeError
            # (StopIteration inside a generator), too many were ignored.
            raise IndexError(
                "TensorProduct of rank %d indexed with %d indices"
                % (expected, len(index)))
        factors = []
        pos = 0
        for arg, shp in zip(self.args, shapes):
            sub = index[pos:pos + len(shp)]
            pos += len(shp)
            if not shp and not hasattr(arg, "__getitem__"):
                # A shapeless, non-indexable factor (such as the scalar
                # coefficient ``__new__`` may place in args) multiplies
                # every component as-is.
                factors.append(arg)
            else:
                factors.append(arg[sub])
        return Mul.fromiter(factors)


@singledispatch
def shape(expr):
    """
    Return the shape of the *expr* as a tuple. *expr* should represent
    a suitable object such as a matrix or an array.

    Parameters
    ==========

    expr : SymPy object having ``MatrixKind`` or ``ArrayKind``.

    Raises
    ======

    NoShapeError : Raised when *expr* has no ``shape`` attribute and its
        type is not registered to ``shape()``.

    Examples
    ========

    This function returns the shape of any object representing a matrix or
    an array.

    >>> from sympy import shape, Array, ImmutableDenseMatrix, Integral
    >>> from sympy.abc import x
    >>> A = Array([1, 2])
    >>> shape(A)
    (2,)
    >>> shape(Integral(A, x))
    (2,)
    >>> M = ImmutableDenseMatrix([1, 2])
    >>> shape(M)
    (2, 1)
    >>> shape(Integral(M, x))
    (2, 1)

    New types can be supported by registering an implementation.

    >>> from sympy import Expr
    >>> class NewExpr(Expr):
    ...     pass
    >>> @shape.register(NewExpr)
    ... def _(expr):
    ...     return shape(expr.args[0])
    >>> shape(NewExpr(M))
    (2, 1)

    If an unsuitable expression is passed, ``NoShapeError()`` will be raised.

    >>> shape(Integral(x, x))
    Traceback (most recent call last):
      ...
    sympy.tensor.functions.NoShapeError: Integral(x, x) does not have shape, or its type is not registered to shape().

    Notes
    =====

    Array-like classes (such as ``Matrix`` or ``NDimArray``) have a ``shape``
    property which returns their shape, but it cannot be used for non-array
    classes containing arrays. This function returns the shape of any
    registered object representing an array.
    """
    # Fallback for unregistered types: objects that already expose a
    # ``shape`` attribute (matrices, arrays) need no registration.
    if hasattr(expr, "shape"):
        return expr.shape
    raise NoShapeError(
        "%s does not have shape, or its type is not registered to shape()." % expr)


class NoShapeError(Exception):
    """
    Raised when ``shape()`` is called on an object that has no ``shape``
    attribute and whose type is not registered to ``shape()``.

    This error can be imported from ``sympy.tensor.functions``.

    Examples
    ========

    >>> from sympy import shape
    >>> from sympy.abc import x
    >>> shape(x)
    Traceback (most recent call last):
      ...
    sympy.tensor.functions.NoShapeError: x does not have shape, or its type is not registered to shape().
    """
    pass