143 lines
4.8 KiB
Python
143 lines
4.8 KiB
Python
|
from .matexpr import MatrixExpr
|
||
|
from .special import Identity
|
||
|
from sympy.core import S
|
||
|
from sympy.core.expr import ExprBuilder
|
||
|
from sympy.core.cache import cacheit
|
||
|
from sympy.core.power import Pow
|
||
|
from sympy.core.sympify import _sympify
|
||
|
from sympy.matrices import MatrixBase
|
||
|
from sympy.matrices.common import NonSquareMatrixError
|
||
|
|
||
|
|
||
|
class MatPow(MatrixExpr):
    """Symbolic power ``base ** exp`` of a matrix expression.

    The base is stored as ``args[0]`` and the exponent as ``args[1]``;
    the shape of the power is the shape of its base.
    """

    def __new__(cls, base, exp, evaluate=False, **options):
        """Create a matrix power after validating the base.

        Parameters
        ----------
        base
            Object to raise to a power. It is sympified and must report
            ``is_Matrix``.
        exp
            Exponent; it is sympified.
        evaluate
            When true, the freshly built object is passed through a shallow
            ``doit(deep=False)`` and the result of that call is returned.
        **options
            Accepted but not used by this constructor.

        Raises
        ------
        TypeError
            If ``base`` is not a matrix.
        NonSquareMatrixError
            If ``base`` is known to be non-square.
        """
        base = _sympify(base)
        if not base.is_Matrix:
            raise TypeError("MatPow base should be a matrix")

        # Reject only when squareness is definitely False; an undetermined
        # value (e.g. None) is let through.
        if base.is_square is False:
            raise NonSquareMatrixError("Power of non-square matrix %s" % base)

        exp = _sympify(exp)
        obj = super().__new__(cls, base, exp)

        if evaluate:
            obj = obj.doit(deep=False)

        return obj

    @property
    def base(self):
        """The matrix being raised to a power (``args[0]``)."""
        return self.args[0]

    @property
    def exp(self):
        """The exponent (``args[1]``)."""
        return self.args[1]

    @property
    def shape(self):
        """Shape of the power, taken directly from the base."""
        return self.base.shape

    @cacheit
    def _get_explicit_matrix(self):
        """Return the explicit form of the base raised to ``exp`` (cached)."""
        return self.base.as_explicit()**self.exp

    def _entry(self, i, j, **kwargs):
        """Return the ``(i, j)`` entry of this power.

        The power is first simplified with ``doit()``. If the result is still
        a ``MatPow``:

        * a positive Integer exponent is expanded into an explicit ``MatMul``
          of repeated bases, which is then indexed;
        * otherwise, when ``_is_shape_symbolic()`` is false, the entry is
          read from the explicit matrix power;
        * otherwise an unevaluated ``MatrixElement(self, i, j)`` is returned.
        """
        # Imported locally, as in the rest of this module (presumably to
        # avoid a circular import -- confirm before hoisting).
        from sympy.matrices.expressions import MatMul
        A = self.doit()
        if isinstance(A, MatPow):
            # We still have a MatPow, make an explicit MatMul out of it.
            if A.exp.is_Integer and A.exp.is_positive:
                A = MatMul(*[A.base for _ in range(A.exp)])
            elif not self._is_shape_symbolic():
                return A._get_explicit_matrix()[i, j]
            else:
                # Leave the expression unevaluated:
                from sympy.matrices.expressions.matexpr import MatrixElement
                return MatrixElement(self, i, j)
        return A[i, j]

    def doit(self, **hints):
        """Simplify the power.

        With ``deep`` true (the default) both arguments are themselves
        evaluated with ``doit(**hints)`` first. Nested powers are then
        flattened, an explicit ``MatrixBase`` base is delegated to its own
        ``**``, the exponents 1, 0 and -1 are handled directly, and finally
        the base's ``_eval_power`` hook is used when it exists. If none of
        these apply a new ``MatPow(base, exp)`` is returned.
        """
        if hints.get('deep', True):
            base, exp = (arg.doit(**hints) for arg in self.args)
        else:
            base, exp = self.args

        # combine all powers, e.g. (A ** 2) ** 3 -> A ** 6
        while isinstance(base, MatPow):
            exp *= base.args[1]
            base = base.args[0]

        if isinstance(base, MatrixBase):
            # Delegate
            return base ** exp

        # Handle simple cases so that _eval_power() in MatrixExpr sub-classes can ignore them
        if exp == S.One:
            return base
        if exp == S.Zero:
            return Identity(base.rows)
        if exp == S.NegativeOne:
            from sympy.matrices.expressions import Inverse
            return Inverse(base).doit(**hints)

        eval_power = getattr(base, '_eval_power', None)
        if eval_power is not None:
            return eval_power(exp)

        return MatPow(base, exp)

    def _eval_transpose(self):
        """Transpose by transposing the base and keeping the exponent."""
        base, exp = self.args
        return MatPow(base.T, exp)

    def _eval_derivative(self, x):
        """Differentiate with respect to ``x`` by reusing ``Pow``'s rule."""
        return Pow._eval_derivative(self, x)

    def _eval_derivative_matrix_lines(self, x):
        """Return the matrix-derivative lines of this power with respect to ``x``.

        A ``(1, 1)`` base whose exponent does not contain ``x`` is handled
        directly: every line produced by the base is rewritten in place into
        an ``ArrayContraction`` that includes the factor
        ``exp*base**(exp - 1)``.

        Otherwise only Integer exponents are supported: the power is
        rewritten as a product of copies of the base (positive exponent) or
        of its inverse (negative exponent), or evaluated with ``doit()``
        (zero exponent), and the derivative of that expression is returned.

        Raises
        ------
        NotImplementedError
            If the exponent is not an Integer (fractional, floating point,
            symbolic or non-real). Such exponents cannot be expanded into a
            finite product.
        """
        from sympy.tensor.array.expressions.array_expressions import ArrayContraction
        from ...tensor.array.expressions.array_expressions import ArrayTensorProduct
        from .matmul import MatMul
        from .inverse import Inverse
        exp = self.exp
        if self.base.shape == (1, 1) and not exp.has(x):
            lr = self.base._eval_derivative_matrix_lines(x)
            for line in lr:
                subexpr = ExprBuilder(
                    ArrayContraction,
                    [
                        ExprBuilder(
                            ArrayTensorProduct,
                            [
                                Identity(1),
                                line._lines[0],
                                exp*self.base**(exp-1),
                                line._lines[1],
                                Identity(1),
                            ]
                        ),
                        (0, 3, 4), (5, 7, 8)
                    ],
                    validator=ArrayContraction._validate
                )
                # Re-point the line at the two Identity(1) slots (positions
                # 0 and 4) of the tensor product built above.
                line._first_pointer_parent = subexpr.args[0].args
                line._first_pointer_index = 0
                line._second_pointer_parent = subexpr.args[0].args
                line._second_pointer_index = 4
                line._lines = [subexpr]
            return lr

        # Only a concrete Integer exponent can be expanded into a product.
        # Checking this first keeps ``range`` from failing with an unrelated
        # TypeError on e.g. Rational(1, 2) or a positive Symbol, and avoids
        # comparing non-real exponents with zero.
        if not exp.is_Integer:
            raise NotImplementedError("cannot evaluate %s derived by %s" % (self, x))

        if exp.is_zero:
            return self.doit()._eval_derivative_matrix_lines(x)
        if exp == S.NegativeOne:
            return Inverse(self.base)._eval_derivative_matrix_lines(x)

        if exp.is_positive:
            factors = [self.base] * int(exp)
        else:
            factors = [Inverse(self.base)] * int(-exp)
        newexpr = MatMul.fromiter(factors)
        return newexpr._eval_derivative_matrix_lines(x)

    def _eval_inverse(self):
        """Invert by negating the exponent."""
        return MatPow(self.base, -self.exp)