decomp_lu.py
6.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""LU decomposition functions."""
from warnings import warn
from numpy import asarray, asarray_chkfinite
# Local imports
from .misc import _datacopied, LinAlgWarning
from .lapack import get_lapack_funcs
from .flinalg import get_flinalg_funcs
# Public names exported by this module.
__all__ = ['lu', 'lu_solve', 'lu_factor']
def lu_factor(a, overwrite_a=False, check_finite=True):
    """
    Factor a square matrix into its pivoted LU form.

    The factorization computed is::

        A = P L U

    with P a permutation matrix, L lower triangular with a unit diagonal
    and U upper triangular.

    Parameters
    ----------
    a : (M, M) array_like
        Square matrix to factorize.
    overwrite_a : bool, optional
        Allow the data of `a` to be overwritten (may increase performance).
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or
        NaNs.

    Returns
    -------
    lu : (M, M) ndarray
        Packed factors: U occupies the upper triangle and L the strictly
        lower triangle. The unit diagonal of L is not stored.
    piv : (M,) ndarray
        Pivot indices representing the permutation matrix P:
        row i of the matrix was interchanged with row piv[i].

    Raises
    ------
    ValueError
        If `a` is not a square 2-D array, or if the LAPACK routine reports
        an illegal argument.

    Warns
    -----
    LinAlgWarning
        If the LAPACK routine reports an exactly zero diagonal entry
        (singular matrix).

    See also
    --------
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the ``*GETRF`` routines from LAPACK.

    Examples
    --------
    >>> from scipy.linalg import lu_factor
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> lu, piv = lu_factor(A)
    >>> piv
    array([2, 2, 3, 3], dtype=int32)

    Convert LAPACK's ``piv`` array to NumPy index and test the permutation

    >>> piv_py = [2, 0, 3, 1]
    >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
    >>> np.allclose(A[piv_py] - L @ U, np.zeros((4, 4)))
    True
    """
    # Pick the converter once instead of branching around two assignments.
    to_array = asarray_chkfinite if check_finite else asarray
    a1 = to_array(a)

    is_square = a1.ndim == 2 and a1.shape[0] == a1.shape[1]
    if not is_square:
        raise ValueError('expected square matrix')

    # If the conversion already produced a private copy, LAPACK may
    # overwrite it without touching the caller's data.
    overwrite_a = overwrite_a or (_datacopied(a1, a))

    (getrf,) = get_lapack_funcs(('getrf',), (a1,))
    factors, pivots, status = getrf(a1, overwrite_a=overwrite_a)

    if status < 0:
        raise ValueError(
            'illegal value in %dth argument of internal getrf (lu_factor)'
            % -status)
    if status > 0:
        # A positive status marks a singular matrix; the factors are still
        # returned, so only warn.
        warn("Diagonal number %d is exactly zero. Singular matrix." % status,
             LinAlgWarning, stacklevel=2)
    return factors, pivots
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, a x = b, given the LU factorization of a.

    Parameters
    ----------
    lu_and_piv : tuple (lu, piv)
        Factorization of the coefficient matrix a, as given by lu_factor
    b : array
        Right-hand side
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : bool, optional
        Whether to overwrite data in b (may increase performance)
    check_finite : bool, optional
        Whether to check that the right-hand side contains only finite
        numbers (only `b` is checked here; `lu` is used as given).
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or
        NaNs.

    Returns
    -------
    x : array
        Solution to the system

    Raises
    ------
    ValueError
        If the leading dimensions of `lu` and `b` differ, or if the
        LAPACK ``getrs`` routine reports a non-zero status.

    See also
    --------
    lu_factor : LU factorize a matrix

    Examples
    --------
    >>> from scipy.linalg import lu_factor, lu_solve
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = lu_factor(A)
    >>> x = lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    True
    """
    (lu, piv) = lu_and_piv
    if check_finite:
        b1 = asarray_chkfinite(b)
    else:
        b1 = asarray(b)
    # If the conversion already copied b, LAPACK may write into that copy
    # without touching the caller's data.
    overwrite_b = overwrite_b or _datacopied(b1, b)
    if lu.shape[0] != b1.shape[0]:
        raise ValueError("incompatible dimensions.")

    getrs, = get_lapack_funcs(('getrs',), (lu, b1))
    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info == 0:
        return x
    # The routine invoked here is getrs, so name it in the diagnostic
    # (the previous message referred to gesv|posv, which are never called).
    raise ValueError('illegal value in %dth argument of '
                     'internal getrs (lu_solve)' % -info)
def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
    """
    Compute the pivoted LU decomposition of a (possibly rectangular) matrix.

    The decomposition is::

        A = P L U

    with P a permutation matrix, L lower triangular with a unit diagonal
    and U upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Array to decompose
    permute_l : bool, optional
        Perform the multiplication P*L (Default: do not permute)
    overwrite_a : bool, optional
        Whether to overwrite data in a (may improve performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or
        NaNs.

    Returns
    -------
    **(If permute_l == False)**

    p : (M, M) ndarray
        Permutation matrix
    l : (M, K) ndarray
        Lower triangular or trapezoidal matrix with unit diagonal.
        K = min(M, N)
    u : (K, N) ndarray
        Upper triangular or trapezoidal matrix

    **(If permute_l == True)**

    pl : (M, K) ndarray
        Permuted L matrix.
        K = min(M, N)
    u : (K, N) ndarray
        Upper triangular or trapezoidal matrix

    Raises
    ------
    ValueError
        If `a` is not 2-D, or if the underlying routine reports an
        illegal argument.

    Notes
    -----
    This is a LU factorization routine written for SciPy.

    Examples
    --------
    >>> from scipy.linalg import lu
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> p, l, u = lu(A)
    >>> np.allclose(A - p @ l @ u, np.zeros((4, 4)))
    True
    """
    # Pick the converter once instead of branching around two assignments.
    to_array = asarray_chkfinite if check_finite else asarray
    a1 = to_array(a)
    if a1.ndim != 2:
        raise ValueError('expected matrix')

    # If the conversion already produced a private copy, the backend may
    # overwrite it without touching the caller's data.
    overwrite_a = overwrite_a or (_datacopied(a1, a))

    (flu,) = get_flinalg_funcs(('lu',), (a1,))
    perm, lower, upper, status = flu(a1, permute_l=permute_l,
                                     overwrite_a=overwrite_a)
    if status < 0:
        raise ValueError(
            'illegal value in %dth argument of internal lu.getrf' % -status)

    # With permute_l the second output is returned as the already-permuted
    # P*L factor, so the permutation matrix itself is dropped.
    return (lower, upper) if permute_l else (perm, lower, upper)