3RNN/Lib/site-packages/scipy/linalg/_decomp_lu.py
2024-05-26 19:49:15 +02:00

375 lines
12 KiB
Python

"""LU decomposition functions."""
from warnings import warn
from numpy import asarray, asarray_chkfinite
import numpy as np
from itertools import product
# Local imports
from ._misc import _datacopied, LinAlgWarning
from .lapack import get_lapack_funcs
from ._decomp_lu_cython import lu_dispatcher
lapack_cast_dict = {x: ''.join([y for y in 'fdFD' if np.can_cast(x, y)])
for x in np.typecodes['All']}
__all__ = ['lu', 'lu_solve', 'lu_factor']
def lu_factor(a, overwrite_a=False, check_finite=True):
"""
Compute pivoted LU decomposition of a matrix.
The decomposition is::
A = P L U
where P is a permutation matrix, L lower triangular with unit
diagonal elements, and U upper triangular.
Parameters
----------
a : (M, N) array_like
Matrix to decompose
overwrite_a : bool, optional
Whether to overwrite data in A (may increase performance)
check_finite : bool, optional
Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
Returns
-------
lu : (M, N) ndarray
Matrix containing U in its upper triangle, and L in its lower triangle.
The unit diagonal elements of L are not stored.
piv : (K,) ndarray
Pivot indices representing the permutation matrix P:
row i of matrix was interchanged with row piv[i].
Of shape ``(K,)``, with ``K = min(M, N)``.
See Also
--------
lu : gives lu factorization in more user-friendly format
lu_solve : solve an equation system using the LU factorization of a matrix
Notes
-----
This is a wrapper to the ``*GETRF`` routines from LAPACK. Unlike
:func:`lu`, it outputs the L and U factors into a single array
and returns pivot indices instead of a permutation matrix.
While the underlying ``*GETRF`` routines return 1-based pivot indices, the
``piv`` array returned by ``lu_factor`` contains 0-based indices.
Examples
--------
>>> import numpy as np
>>> from scipy.linalg import lu_factor
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
>>> lu, piv = lu_factor(A)
>>> piv
array([2, 2, 3, 3], dtype=int32)
Convert LAPACK's ``piv`` array to NumPy index and test the permutation
>>> def pivot_to_permutation(piv):
... perm = np.arange(len(piv))
... for i in range(len(piv)):
... perm[i], perm[piv[i]] = perm[piv[i]], perm[i]
... return perm
...
>>> p_inv = pivot_to_permutation(piv)
>>> p_inv
array([2, 0, 3, 1])
>>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
>>> np.allclose(A[p_inv] - L @ U, np.zeros((4, 4)))
True
The P matrix in P L U is defined by the inverse permutation and
can be recovered using argsort:
>>> p = np.argsort(p_inv)
>>> p
array([1, 3, 0, 2])
>>> np.allclose(A - L[p] @ U, np.zeros((4, 4)))
True
or alternatively:
>>> P = np.eye(4)[p]
>>> np.allclose(A - P @ L @ U, np.zeros((4, 4)))
True
"""
if check_finite:
a1 = asarray_chkfinite(a)
else:
a1 = asarray(a)
overwrite_a = overwrite_a or (_datacopied(a1, a))
getrf, = get_lapack_funcs(('getrf',), (a1,))
lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
if info < 0:
raise ValueError('illegal value in %dth argument of '
'internal getrf (lu_factor)' % -info)
if info > 0:
warn("Diagonal number %d is exactly zero. Singular matrix." % info,
LinAlgWarning, stacklevel=2)
return lu, piv
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
"""Solve an equation system, a x = b, given the LU factorization of a
Parameters
----------
(lu, piv)
Factorization of the coefficient matrix a, as given by lu_factor.
In particular piv are 0-indexed pivot indices.
b : array
Right-hand side
trans : {0, 1, 2}, optional
Type of system to solve:
===== =========
trans system
===== =========
0 a x = b
1 a^T x = b
2 a^H x = b
===== =========
overwrite_b : bool, optional
Whether to overwrite data in b (may increase performance)
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
Returns
-------
x : array
Solution to the system
See Also
--------
lu_factor : LU factorize a matrix
Examples
--------
>>> import numpy as np
>>> from scipy.linalg import lu_factor, lu_solve
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
>>> b = np.array([1, 1, 1, 1])
>>> lu, piv = lu_factor(A)
>>> x = lu_solve((lu, piv), b)
>>> np.allclose(A @ x - b, np.zeros((4,)))
True
"""
(lu, piv) = lu_and_piv
if check_finite:
b1 = asarray_chkfinite(b)
else:
b1 = asarray(b)
overwrite_b = overwrite_b or _datacopied(b1, b)
if lu.shape[0] != b1.shape[0]:
raise ValueError(f"Shapes of lu {lu.shape} and b {b1.shape} are incompatible")
getrs, = get_lapack_funcs(('getrs',), (lu, b1))
x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
if info == 0:
return x
raise ValueError('illegal value in %dth argument of internal gesv|posv'
% -info)
def lu(a, permute_l=False, overwrite_a=False, check_finite=True,
p_indices=False):
"""
Compute LU decomposition of a matrix with partial pivoting.
The decomposition satisfies::
A = P @ L @ U
where ``P`` is a permutation matrix, ``L`` lower triangular with unit
diagonal elements, and ``U`` upper triangular. If `permute_l` is set to
``True`` then ``L`` is returned already permuted and hence satisfying
``A = L @ U``.
Parameters
----------
a : (M, N) array_like
Array to decompose
permute_l : bool, optional
Perform the multiplication P*L (Default: do not permute)
overwrite_a : bool, optional
Whether to overwrite data in a (may improve performance)
check_finite : bool, optional
Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
p_indices : bool, optional
If ``True`` the permutation information is returned as row indices.
The default is ``False`` for backwards-compatibility reasons.
Returns
-------
**(If `permute_l` is ``False``)**
p : (..., M, M) ndarray
Permutation arrays or vectors depending on `p_indices`
l : (..., M, K) ndarray
Lower triangular or trapezoidal array with unit diagonal.
``K = min(M, N)``
u : (..., K, N) ndarray
Upper triangular or trapezoidal array
**(If `permute_l` is ``True``)**
pl : (..., M, K) ndarray
Permuted L matrix.
``K = min(M, N)``
u : (..., K, N) ndarray
Upper triangular or trapezoidal array
Notes
-----
Permutation matrices are costly since they are nothing but row reorder of
``L`` and hence indices are strongly recommended to be used instead if the
permutation is required. The relation in the 2D case then becomes simply
``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l`
to avoid complicated indexing tricks.
In 2D case, if one has the indices however, for some reason, the
permutation matrix is still needed then it can be constructed by
``np.eye(M)[P, :]``.
Examples
--------
>>> import numpy as np
>>> from scipy.linalg import lu
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
>>> p, l, u = lu(A)
>>> np.allclose(A, p @ l @ u)
True
>>> p # Permutation matrix
array([[0., 1., 0., 0.], # Row index 1
[0., 0., 0., 1.], # Row index 3
[1., 0., 0., 0.], # Row index 0
[0., 0., 1., 0.]]) # Row index 2
>>> p, _, _ = lu(A, p_indices=True)
>>> p
array([1, 3, 0, 2]) # as given by row indices above
>>> np.allclose(A, l[p, :] @ u)
True
We can also use nd-arrays, for example, a demonstration with 4D array:
>>> rng = np.random.default_rng()
>>> A = rng.uniform(low=-4, high=4, size=[3, 2, 4, 8])
>>> p, l, u = lu(A)
>>> p.shape, l.shape, u.shape
((3, 2, 4, 4), (3, 2, 4, 4), (3, 2, 4, 8))
>>> np.allclose(A, p @ l @ u)
True
>>> PL, U = lu(A, permute_l=True)
>>> np.allclose(A, PL @ U)
True
"""
a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a)
if a1.ndim < 2:
raise ValueError('The input array must be at least two-dimensional.')
# Also check if dtype is LAPACK compatible
if a1.dtype.char not in 'fdFD':
dtype_char = lapack_cast_dict[a1.dtype.char]
if not dtype_char: # No casting possible
raise TypeError(f'The dtype {a1.dtype} cannot be cast '
'to float(32, 64) or complex(64, 128).')
a1 = a1.astype(dtype_char[0]) # makes a copy, free to scratch
overwrite_a = True
*nd, m, n = a1.shape
k = min(m, n)
real_dchar = 'f' if a1.dtype.char in 'fF' else 'd'
# Empty input
if min(*a1.shape) == 0:
if permute_l:
PL = np.empty(shape=[*nd, m, k], dtype=a1.dtype)
U = np.empty(shape=[*nd, k, n], dtype=a1.dtype)
return PL, U
else:
P = (np.empty([*nd, 0], dtype=np.int32) if p_indices else
np.empty([*nd, 0, 0], dtype=real_dchar))
L = np.empty(shape=[*nd, m, k], dtype=a1.dtype)
U = np.empty(shape=[*nd, k, n], dtype=a1.dtype)
return P, L, U
# Scalar case
if a1.shape[-2:] == (1, 1):
if permute_l:
return np.ones_like(a1), (a1 if overwrite_a else a1.copy())
else:
P = (np.zeros(shape=[*nd, m], dtype=int) if p_indices
else np.ones_like(a1))
return P, np.ones_like(a1), (a1 if overwrite_a else a1.copy())
# Then check overwrite permission
if not _datacopied(a1, a): # "a" still alive through "a1"
if not overwrite_a:
# Data belongs to "a" so make a copy
a1 = a1.copy(order='C')
# else: Do nothing we'll use "a" if possible
# else: a1 has its own data thus free to scratch
# Then layout checks, might happen that overwrite is allowed but original
# array was read-only or non-contiguous.
if not (a1.flags['C_CONTIGUOUS'] and a1.flags['WRITEABLE']):
a1 = a1.copy(order='C')
if not nd: # 2D array
p = np.empty(m, dtype=np.int32)
u = np.zeros([k, k], dtype=a1.dtype)
lu_dispatcher(a1, u, p, permute_l)
P, L, U = (p, a1, u) if m > n else (p, u, a1)
else: # Stacked array
# Prepare the contiguous data holders
P = np.empty([*nd, m], dtype=np.int32) # perm vecs
if m > n: # Tall arrays, U will be created
U = np.zeros([*nd, k, k], dtype=a1.dtype)
for ind in product(*[range(x) for x in a1.shape[:-2]]):
lu_dispatcher(a1[ind], U[ind], P[ind], permute_l)
L = a1
else: # Fat arrays, L will be created
L = np.zeros([*nd, k, k], dtype=a1.dtype)
for ind in product(*[range(x) for x in a1.shape[:-2]]):
lu_dispatcher(a1[ind], L[ind], P[ind], permute_l)
U = a1
# Convert permutation vecs to permutation arrays
# permute_l=False needed to enter here to avoid wasted efforts
if (not p_indices) and (not permute_l):
if nd:
Pa = np.zeros([*nd, m, m], dtype=real_dchar)
# An unreadable index hack - One-hot encoding for perm matrices
nd_ix = np.ix_(*([np.arange(x) for x in nd]+[np.arange(m)]))
Pa[(*nd_ix, P)] = 1
P = Pa
else: # 2D case
Pa = np.zeros([m, m], dtype=real_dchar)
Pa[np.arange(m), P] = 1
P = Pa
return (L, U) if permute_l else (P, L, U)