Question
I am trying to understand how scipy CSR works.
https://docs.scipy.org/doc/scipy/reference/sparse.html
For example, for the following matrix from https://en.wikipedia.org/wiki/Sparse_matrix
( 0 0 0 0 )
( 5 8 0 0 )
( 0 0 3 0 )
( 0 6 0 0 )
it says the CSR representation is the following:
V = [ 5 8 3 6 ]
COL_INDEX = [ 0 1 2 1 ]
ROW_INDEX = [ 0 0 2 3 4 ]
Must V list the rows one after another, with the non-zero elements of each row listed from left to right?
I can understand that COL_INDEX is the column index (column 1 is indexed as 0) corresponding to the elements in V.
I don't understand ROW_INDEX. Could anybody show me how ROW_INDEX was created from the original matrix? Thanks.
Answer 1:
coo format
I think it's best to start with the coo definition. It's easier to understand, and widely used. (The sessions below assume import numpy as np and from scipy import sparse.)
In [90]: A = np.array([[0,0,0,0],[5,8,0,0],[0,0,3,0],[0,6,0,0]])
In [91]: M = sparse.coo_matrix(A)
The values are stored in 3 attributes:
In [92]: M.row
Out[92]: array([1, 1, 2, 3], dtype=int32)
In [93]: M.col
Out[93]: array([0, 1, 2, 1], dtype=int32)
In [94]: M.data
Out[94]: array([5, 8, 3, 6])
We can make a new matrix from those 3 arrays:
In [95]: sparse.coo_matrix((_94, (_92, _93))).A
Out[95]:
array([[0, 0, 0],
       [5, 8, 0],
       [0, 0, 3],
       [0, 6, 0]])
Oops, I need to specify the shape, since the last column is all zeros:
In [96]: sparse.coo_matrix((_94, (_92, _93)), shape=(4,4)).A
Out[96]:
array([[0, 0, 0, 0],
       [5, 8, 0, 0],
       [0, 0, 3, 0],
       [0, 6, 0, 0]])
Another way to display this matrix:
In [97]: print(M)
(1, 0) 5
(1, 1) 8
(2, 2) 3
(3, 1) 6
np.where(A) gives the same non-zero coordinates:
In [108]: np.where(A)
Out[108]: (array([1, 1, 2, 3]), array([0, 1, 2, 1]))
conversion to csr
Once we have coo, we can easily convert it to csr. In fact sparse often does that for us:
In [98]: Mr = M.tocsr()
In [99]: Mr.data
Out[99]: array([5, 8, 3, 6], dtype=int64)
In [100]: Mr.indices
Out[100]: array([0, 1, 2, 1], dtype=int32)
In [101]: Mr.indptr
Out[101]: array([0, 0, 2, 3, 4], dtype=int32)
During the conversion sparse does several things: it sorts the indices, sums duplicates, and replaces the row array with an indptr array. Here indptr is actually longer than the original row array (5 values versus 4), but in general it will be shorter, since it has just one value per row (plus 1), whereas row has one value per non-zero element. Perhaps more importantly, most of the fast calculation routines, especially matrix multiplication, have been written using the csr format.
I've used this package a lot, and MATLAB as well, where the default definition is in the coo style but the internal storage is csc (though not as exposed to users as in scipy). But I've never tried to derive indptr from scratch. I could, but I don't need to.
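For readers who do want to see where indptr comes from, here is a minimal sketch of one way to derive it from the coo row array. It is not how scipy does it internally; it assumes the entries are already sorted by row with no duplicates, and the names n_rows and counts are introduced only for this illustration.

```python
import numpy as np

row = np.array([1, 1, 2, 3])  # same values as M.row in the coo example
n_rows = 4                    # number of rows in the matrix

# Count the non-zero entries in each row, including empty rows.
counts = np.bincount(row, minlength=n_rows)

# indptr is the running total of those counts, with a leading 0.
indptr = np.concatenate(([0], np.cumsum(counts)))
print(counts)
print(indptr)
```

The counts come out as [0 2 1 1], and their running total with a leading 0 is [0 0 2 3 4], which matches ROW_INDEX in the question.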
csr_matrix accepts inputs in the coo format, but also in the (data, indices, indptr) format. I wouldn't recommend the latter unless you already have those inputs calculated (say, from another matrix). It's more error-prone, and probably not much faster.
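As a quick illustration of the second input style, the three arrays from the question can be passed to csr_matrix directly. This is only a sketch; the names Mr2 and dense2 are chosen for the example.

```python
import numpy as np
from scipy import sparse

data = np.array([5, 8, 3, 6])        # V
indices = np.array([0, 1, 2, 1])     # COL_INDEX
indptr = np.array([0, 0, 2, 3, 4])   # ROW_INDEX

# The shape is given explicitly because the last column holds no values.
Mr2 = sparse.csr_matrix((data, indices, indptr), shape=(4, 4))
dense2 = Mr2.toarray()
print(dense2)
```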
Iteration with indptr
However, sometimes it is useful to iterate on indptr and perform calculations directly on the data. Often this is faster than working with the provided methods.
For example we can list the nonzero values by row:
In [104]: for i in range(Mr.shape[0]):
...:     pt = slice(Mr.indptr[i], Mr.indptr[i+1])
...:     print(i, Mr.indices[pt], Mr.data[pt])
...:
0 [] []
1 [0 1] [5 8]
2 [2] [3]
3 [1] [6]
Keeping the initial 0 makes this iteration easier. When the matrix is (10000,90000), there's not much incentive to reduce the size of indptr by 1.
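The same indptr slicing can feed simple per-row calculations. The sketch below computes the number of stored values and the sum of each row; the names nnz_per_row and row_sums are illustrative, and no claim is made here about speed relative to the built-in methods.

```python
import numpy as np
from scipy import sparse

A = np.array([[0, 0, 0, 0], [5, 8, 0, 0], [0, 0, 3, 0], [0, 6, 0, 0]])
Mr = sparse.csr_matrix(A)

# Consecutive differences of indptr give the stored-value count per row.
nnz_per_row = np.diff(Mr.indptr)

# Each row's values live in data[indptr[i]:indptr[i+1]]; an empty slice sums to 0.
row_sums = np.array([Mr.data[Mr.indptr[i]:Mr.indptr[i + 1]].sum()
                     for i in range(Mr.shape[0])])
print(nnz_per_row)
print(row_sums)
```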
lil format
The lil format stores the matrix in a similar manner:
In [105]: Ml = M.tolil()
In [106]: Ml.data
Out[106]: array([list([]), list([5, 8]), list([3]), list([6])], dtype=object)
In [107]: Ml.rows
Out[107]: array([list([]), list([0, 1]), list([2]), list([1])], dtype=object)
In [110]: for i,(r,d) in enumerate(zip(Ml.rows, Ml.data)):
...:     print(i, r, d)
...:
0 [] []
1 [0, 1] [5, 8]
2 [2] [3]
3 [1] [6]
Because of how rows are stored, lil actually allows us to fetch a view of a row:
In [167]: Ml.getrowview(2)
Out[167]:
<1x4 sparse matrix of type '<class 'numpy.longlong'>'
with 1 stored elements in List of Lists format>
In [168]: for i in range(Ml.shape[0]):
...:     print(Ml.getrowview(i))
...:
(0, 0) 5
(0, 1) 8
(0, 2) 3
(0, 1) 6
(Each view is a 1x4 matrix, so every entry is reported with row index 0; the empty first row contributes no entries.)
Answer 2:
From the scipy manual:
csr_matrix((data, indices, indptr), [shape=(M, N)]) is the standard CSR representation where the column indices for row i are stored in indices[indptr[i]:indptr[i+1]] and their corresponding values are stored in data[indptr[i]:indptr[i+1]]. If the shape parameter is not supplied, the matrix dimensions are inferred from the index arrays.
indptr is the same as ROW_INDEX and indices is the same as COL_INDEX.
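To make the slicing rule from the manual concrete, here is a minimal sketch that rebuilds the dense matrix from the three arrays in the question using plain loops. The variable n_cols is an assumption supplied by hand, since the number of columns cannot always be inferred from the index arrays.

```python
import numpy as np

V = [5, 8, 3, 6]
COL_INDEX = [0, 1, 2, 1]
ROW_INDEX = [0, 0, 2, 3, 4]
n_cols = 4  # assumed known; the last column has no entries

n_rows = len(ROW_INDEX) - 1
dense = np.zeros((n_rows, n_cols), dtype=int)
for i in range(n_rows):
    # Row i owns positions ROW_INDEX[i] .. ROW_INDEX[i+1]-1 of V and COL_INDEX.
    for k in range(ROW_INDEX[i], ROW_INDEX[i + 1]):
        dense[i, COL_INDEX[k]] = V[k]
print(dense)
```

Row 0 has ROW_INDEX[0] == ROW_INDEX[1], so its inner loop does not run and the row stays all zeros.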
Here is an example of a naive way to create the index and value arrays. Essentially, ROW_INDICES[i + 1] is the total number of non-zero entries from row 0 to row i inclusive, so the last entry is the total number of non-zero entries in the matrix.
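import numpy as np

# Dense input matrix (the example from the question)
m = np.array([[0, 0, 0, 0], [5, 8, 0, 0], [0, 0, 3, 0], [0, 6, 0, 0]])
num_rows, num_cols = m.shape

ROW_INDICES = [0]
COL_INDICES = []
VALS = []
for i in range(num_rows):
    # Start row i's running count from the total so far
    ROW_INDICES.append(ROW_INDICES[i])
    for j in range(num_cols):
        if m[i, j] != 0:  # != 0 rather than > 0, so negative entries are kept
            ROW_INDICES[i + 1] += 1
            COL_INDICES.append(j)
            VALS.append(m[i, j])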
Source: https://stackoverflow.com/questions/59959379/understand-the-csr-format