Set row of csr_matrix

旧巷少年郎 2020-12-18 12:04

I have a sparse csr_matrix, and I want to change the values of a single row to different values. I can't find an easy and efficient implementation however. This is what it

5 answers
  •  温柔的废话
    2020-12-18 12:52

    In physicalattraction's answer, len(new_row) must equal A.shape[1], which is inconvenient when the new row is itself sparse.

    So, based on that answer, I came up with a method that sets a row in a CSR matrix while preserving its sparsity. I've also added a helper that converts a dense array to a sparse representation (a list of data values and a list of column indices).

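    import numpy as np
    from scipy.sparse import isspmatrix_csr

    def to_sparse(dense_arr):
        # Keep only the nonzero entries as (value, column index) pairs
        sparse = [(data, index) for index, data in enumerate(dense_arr) if data != 0]

        # An all-zero row has no entries, so there is nothing to unzip
        if not sparse:
            return [], []

        # Convert list of tuples to lists
        sparse = list(map(list, zip(*sparse)))

        # Return data and indices
        return sparse[0], sparse[1]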
    
    def set_row_csr_unbounded(A, row_idx, new_row_data, new_row_indices):
        '''
        Replace a row in a CSR sparse matrix A.
    
        Parameters
        ----------
        A: csr_matrix
            Matrix to change
        row_idx: int
            index of the row to be changed
        new_row_data: np.array
            list of new values for the row of A
        new_row_indices: np.array
            list of indices for new row
    
        Returns
        -------
        None (the matrix A is changed in place)
    
        Prerequisites
        -------------
        The row index shall be smaller than the number of rows in A
        Row data and row indices must have the same size
        '''
        assert isspmatrix_csr(A), 'A shall be a csr_matrix'
        assert row_idx < A.shape[0], \
                'The row index ({0}) shall be smaller than the number of rows in A ({1})' \
                .format(row_idx, A.shape[0])
    
        try:
            N_elements_new_row = len(new_row_data)
        except TypeError:
            msg = 'Argument new_row_data shall be a list or numpy array, is now a {0}'\
            .format(type(new_row_data))
            raise AssertionError(msg)
    
        try:
            assert N_elements_new_row == len(new_row_indices), \
                    'new_row_data and new_row_indices must have the same size'
        except TypeError:
            msg = 'Argument new_row_indices shall be a list or numpy array, is now a {0}'\
            .format(type(new_row_indices))
            raise AssertionError(msg)
    
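        idx_start_row = A.indptr[row_idx]
        idx_end_row = A.indptr[row_idx + 1]
        # Later rows shift by the change in stored entries, not by the new row's size
        shift = N_elements_new_row - (idx_end_row - idx_start_row)

        A.data = np.r_[A.data[:idx_start_row], new_row_data, A.data[idx_end_row:]]
        # Cast back so an empty or int64 index list does not change the index dtype
        A.indices = np.r_[A.indices[:idx_start_row], new_row_indices, A.indices[idx_end_row:]].astype(A.indices.dtype)
        A.indptr = np.r_[A.indptr[:row_idx + 1], A.indptr[(row_idx + 1):] + shift].astype(A.indptr.dtype)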
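
A small usage sketch may help show what replacing a row does to the three CSR arrays. The helper `replace_row` below is a hypothetical, condensed variant written only for this illustration (it is not the function above verbatim, and it skips the argument checks). It assumes the column indices are passed in ascending order and that later `indptr` entries should move by the change in the number of stored entries.

```python
import numpy as np
from scipy.sparse import csr_matrix


def replace_row(A, row_idx, data, indices):
    """Hypothetical helper: overwrite one row of a CSR matrix in place."""
    start, end = A.indptr[row_idx], A.indptr[row_idx + 1]
    # Rows after row_idx move by (new entry count - old entry count).
    shift = int(len(data) - (end - start))
    new_data = np.concatenate(
        [A.data[:start], np.asarray(data, dtype=A.data.dtype), A.data[end:]])
    new_indices = np.concatenate(
        [A.indices[:start], np.asarray(indices, dtype=A.indices.dtype), A.indices[end:]])
    new_indptr = A.indptr.copy()
    new_indptr[row_idx + 1:] += shift
    A.data, A.indices, A.indptr = new_data, new_indices, new_indptr


A = csr_matrix(np.array([[1.0, 0.0, 2.0],
                         [0.0, 3.0, 0.0],
                         [4.0, 0.0, 5.0]]))

# Row 1 stores one entry (3.0 at column 1); replace it with two entries.
replace_row(A, 1, [7.0, 8.0], [0, 2])

result = A.toarray()
print(result)          # row 1 is now [7, 0, 8]; rows 0 and 2 are unchanged
print(A.indptr)        # [0 2 4 6]
```

Because the row grows from one stored entry to two, every `indptr` value after row 1 increases by one. A row that shrinks would move them the other way.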
    
