3RNN/Lib/site-packages/scipy/linalg/_matfuncs_sqrtm.py

215 lines
6.7 KiB
Python
Raw Normal View History

2024-05-26 19:49:15 +02:00
"""
Matrix square root for general matrices and for upper triangular matrices.
This module exists to avoid cyclic imports.
"""
__all__ = ['sqrtm']
import numpy as np
from scipy._lib._util import _asarray_validated
# Local imports
from ._misc import norm
from .lapack import ztrsyl, dtrsyl
from ._decomp_schur import schur, rsf2csf
class SqrtmError(np.linalg.LinAlgError):
pass
from ._matfuncs_sqrtm_triu import within_block_loop # noqa: E402
def _sqrtm_triu(T, blocksize=64):
"""
Matrix square root of an upper triangular matrix.
This is a helper function for `sqrtm` and `logm`.
Parameters
----------
T : (N, N) array_like upper triangular
Matrix whose square root to evaluate
blocksize : int, optional
If the blocksize is not degenerate with respect to the
size of the input array, then use a blocked algorithm. (Default: 64)
Returns
-------
sqrtm : (N, N) ndarray
Value of the sqrt function at `T`
References
----------
.. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
"Blocked Schur Algorithms for Computing the Matrix Square Root,
Lecture Notes in Computer Science, 7782. pp. 171-182.
"""
T_diag = np.diag(T)
keep_it_real = np.isrealobj(T) and np.min(T_diag) >= 0
# Cast to complex as necessary + ensure double precision
if not keep_it_real:
T = np.asarray(T, dtype=np.complex128, order="C")
T_diag = np.asarray(T_diag, dtype=np.complex128)
else:
T = np.asarray(T, dtype=np.float64, order="C")
T_diag = np.asarray(T_diag, dtype=np.float64)
R = np.diag(np.sqrt(T_diag))
# Compute the number of blocks to use; use at least one block.
n, n = T.shape
nblocks = max(n // blocksize, 1)
# Compute the smaller of the two sizes of blocks that
# we will actually use, and compute the number of large blocks.
bsmall, nlarge = divmod(n, nblocks)
blarge = bsmall + 1
nsmall = nblocks - nlarge
if nsmall * bsmall + nlarge * blarge != n:
raise Exception('internal inconsistency')
# Define the index range covered by each block.
start_stop_pairs = []
start = 0
for count, size in ((nsmall, bsmall), (nlarge, blarge)):
for i in range(count):
start_stop_pairs.append((start, start + size))
start += size
# Within-block interactions (Cythonized)
try:
within_block_loop(R, T, start_stop_pairs, nblocks)
except RuntimeError as e:
raise SqrtmError(*e.args) from e
# Between-block interactions (Cython would give no significant speedup)
for j in range(nblocks):
jstart, jstop = start_stop_pairs[j]
for i in range(j-1, -1, -1):
istart, istop = start_stop_pairs[i]
S = T[istart:istop, jstart:jstop]
if j - i > 1:
S = S - R[istart:istop, istop:jstart].dot(R[istop:jstart,
jstart:jstop])
# Invoke LAPACK.
# For more details, see the solve_sylvester implementation
# and the fortran dtrsyl and ztrsyl docs.
Rii = R[istart:istop, istart:istop]
Rjj = R[jstart:jstop, jstart:jstop]
if keep_it_real:
x, scale, info = dtrsyl(Rii, Rjj, S)
else:
x, scale, info = ztrsyl(Rii, Rjj, S)
R[istart:istop, jstart:jstop] = x * scale
# Return the matrix square root.
return R
def sqrtm(A, disp=True, blocksize=64):
"""
Matrix square root.
Parameters
----------
A : (N, N) array_like
Matrix whose square root to evaluate
disp : bool, optional
Print warning if error in the result is estimated large
instead of returning estimated error. (Default: True)
blocksize : integer, optional
If the blocksize is not degenerate with respect to the
size of the input array, then use a blocked algorithm. (Default: 64)
Returns
-------
sqrtm : (N, N) ndarray
Value of the sqrt function at `A`. The dtype is float or complex.
The precision (data size) is determined based on the precision of
input `A`. When the dtype is float, the precision is the same as `A`.
When the dtype is complex, the precision is double that of `A`. The
precision might be clipped by each dtype precision range.
errest : float
(if disp == False)
Frobenius norm of the estimated error, ||err||_F / ||A||_F
References
----------
.. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
"Blocked Schur Algorithms for Computing the Matrix Square Root,
Lecture Notes in Computer Science, 7782. pp. 171-182.
Examples
--------
>>> import numpy as np
>>> from scipy.linalg import sqrtm
>>> a = np.array([[1.0, 3.0], [1.0, 4.0]])
>>> r = sqrtm(a)
>>> r
array([[ 0.75592895, 1.13389342],
[ 0.37796447, 1.88982237]])
>>> r.dot(r)
array([[ 1., 3.],
[ 1., 4.]])
"""
byte_size = np.asarray(A).dtype.itemsize
A = _asarray_validated(A, check_finite=True, as_inexact=True)
if len(A.shape) != 2:
raise ValueError("Non-matrix input to matrix function.")
if blocksize < 1:
raise ValueError("The blocksize should be at least 1.")
keep_it_real = np.isrealobj(A)
if keep_it_real:
T, Z = schur(A)
d0 = np.diagonal(T)
d1 = np.diagonal(T, -1)
eps = np.finfo(T.dtype).eps
needs_conversion = abs(d1) > eps * (abs(d0[1:]) + abs(d0[:-1]))
if needs_conversion.any():
T, Z = rsf2csf(T, Z)
else:
T, Z = schur(A, output='complex')
failflag = False
try:
R = _sqrtm_triu(T, blocksize=blocksize)
ZH = np.conjugate(Z).T
X = Z.dot(R).dot(ZH)
if not np.iscomplexobj(X):
# float byte size range: f2 ~ f16
X = X.astype(f"f{np.clip(byte_size, 2, 16)}", copy=False)
else:
# complex byte size range: c8 ~ c32.
# c32(complex256) might not be supported in some environments.
if hasattr(np, 'complex256'):
X = X.astype(f"c{np.clip(byte_size*2, 8, 32)}", copy=False)
else:
X = X.astype(f"c{np.clip(byte_size*2, 8, 16)}", copy=False)
except SqrtmError:
failflag = True
X = np.empty_like(A)
X.fill(np.nan)
if disp:
if failflag:
print("Failed to find a square root.")
return X
else:
try:
arg2 = norm(X.dot(X) - A, 'fro')**2 / norm(A, 'fro')
except ValueError:
# NaNs in matrix
arg2 = np.inf
return X, arg2