Markov chain stationary distributions with scipy.sparse?

前端 未结 4 1703
挽巷
挽巷 2020-12-16 07:32

I have a Markov chain given as a large sparse scipy matrix A. (I've constructed the matrix in scipy.sparse.dok_matrix format, but con… [question text truncated in the source]

4条回答
  •  清酒与你
    2020-12-16 07:56

    Several articles summarizing possible approaches can be found on Google Scholar; here's one: http://www.ima.umn.edu/preprints/pp1992/932.pdf

    What's done below is a combination of the suggestion by @Helge Dietert above on splitting into strongly connected components first, and approach #4 in the paper linked above.

    import numpy as np
    import time
    
    # NB. Scipy >= 0.14.0 probably required
    import scipy
    from scipy.sparse.linalg import gmres, spsolve
    from scipy.sparse import csgraph
    from scipy import sparse 
    
    
    def markov_stationary_components(P, tol=1e-12):
        """
        Split the chain into strongly connected components and solve the
        stationary state for the smallest closed one.

        Only closed components -- those with no stored transition leading
        out of them -- are candidates; components that leak probability
        into other components are skipped. The returned vector is supported
        on the smallest closed component only, so when the chain has several
        closed components it is one of the stationary distributions, not a
        unique one.

        Parameters
        ----------
        P : sparse matrix or array-like, shape (n, n)
            Transition matrix; ``P[i, j]`` is treated as the weight of the
            transition from state i to state j. The caller's object is not
            modified.
        tol : float
            Tolerance forwarded to ``markov_stationary_one``.

        Returns
        -------
        p : ndarray, shape (n,)
            Stationary vector, zero outside the chosen component.
        """
        # 0. Work on a private CSR copy so the caller's matrix is never
        #    mutated, and drop explicitly stored zeros so they are not
        #    treated as transitions.
        P = sparse.csr_matrix(P, copy=True)
        P.eliminate_zeros()
        n = P.shape[0]

        # 1. Separate into strongly connected components.
        n_components, labels = csgraph.connected_components(
            P, directed=True, connection='strong')

        # A component is closed iff no edge leaves it. Mark the source
        # component of every edge that crosses between components in a
        # single O(nnz) pass. This also works for an irreducible chain,
        # where there is no "rest of the chain" to slice against.
        edges = P.tocoo()
        crossing = labels[edges.row] != labels[edges.col]
        leaking = np.zeros(n_components, dtype=bool)
        leaking[labels[edges.row[crossing]]] = True
        closed = np.flatnonzero(~leaking)
        n_components = closed.size

        # 2. Pick the smallest closed component (the first one on ties).
        sizes = np.bincount(labels, minlength=leaking.size)[closed]
        min_j = np.argmin(sizes)
        indices = np.flatnonzero(labels == closed[min_j])

        print("Solving for component {0}/{1} of size {2}".format(min_j, n_components, indices.size))

        # 3. Solve the stationary state on that component only.
        p = np.zeros(n)
        if indices.size == 1:
            # A closed singleton is absorbing: all mass sits on it.
            p[indices] = 1
        else:
            p[indices] = markov_stationary_one(P[indices, :][:, indices], tol=tol)

        return p
    
    
    def markov_stationary_one(P, tol=1e-12, direct=False):
        """
        Solve stationary state of Markov chain by replacing the first
        equation by the normalization condition.

        Solves ``(P^T - I) p = 0`` where the first balance equation is
        replaced by ``sum(p) = 1``.

        Parameters
        ----------
        P : sparse matrix, shape (n, n)
            Transition matrix with ``P[i, j]`` the probability of moving
            from state i to state j (rows sum to one).
        tol : float
            Relative tolerance for GMRES; unused when ``direct=True``.
        direct : bool
            Use the sparse direct solver ``spsolve`` instead of GMRES. This
            requires the system to have a unique solution.

        Returns
        -------
        p : ndarray, shape (n,)
            Stationary vector.

        Raises
        ------
        RuntimeError
            If GMRES does not converge.
        """
        if P.shape == (1, 1):
            return np.array([1.0])

        n = P.shape[0]
        dP = P - sparse.eye(n)
        # Row 0 is the normalization sum(p) = 1. The rows of (P^T - I) sum to
        # zero for a row-stochastic P, so dropping the first one loses
        # nothing. An explicit 1 x n sparse row keeps vstack well-defined,
        # and CSR output is what spsolve/gmres want.
        normalization = sparse.csr_matrix(np.ones((1, n)))
        A = sparse.vstack([normalization, dP.T[1:, :]], format='csr')
        rhs = np.zeros((n,))
        rhs[0] = 1

        if direct:
            # Requires that the solution is unique
            return spsolve(A, rhs)

        # GMRES does not care whether the solution is unique or not, it
        # will pick the first one it finds in the Krylov subspace.
        # SciPy 1.12 renamed gmres' ``tol`` to ``rtol`` and 1.14 removed
        # ``tol``; try the new keyword first and fall back for older SciPy.
        try:
            p, info = gmres(A, rhs, rtol=tol)
        except TypeError:
            p, info = gmres(A, rhs, tol=tol)
        if info != 0:
            raise RuntimeError("gmres didn't converge")
        return p
    
    
    def main():
        """Run the solver on a large synthetic chain and on a tiny hand-made one."""
        # Case 1: random sparse chain. The unit diagonal and the +/-1 band
        # link every state to its neighbours, so before the cuts below every
        # state can reach every other one.
        dim = 100000
        np.random.seed(1234)
        band = sparse.diags([1, 1], [-1, 1], shape=(dim, dim))
        P = sparse.rand(dim, dim, 1e-3) + sparse.eye(dim) + band

        # Zero the blocks coupling states 0..999 and 10000..10999 to the rest
        # of the chain, splitting them off as separate components.
        head, mid_lo, mid_hi = 1000, 10000, 11000
        cuts = (
            (slice(None, head), slice(head, None)),
            (slice(head, None), slice(None, head)),
            (slice(mid_lo, mid_hi), slice(None, mid_lo)),
            (slice(mid_lo, mid_hi), slice(mid_hi, None)),
            (slice(None, mid_lo), slice(mid_lo, mid_hi)),
            (slice(mid_hi, None), slice(mid_lo, mid_hi)),
        )
        P = P.tolil()
        for rows, cols in cuts:
            P[rows, cols] = 0

        # Row-normalise so each row sums to one.
        P = P.tocsr()
        inv_row_sums = sparse.csr_matrix(1/P.sum(1).A)
        P = P.multiply(inv_row_sums)

        print("*** Case 1")
        doit(P)

        # Case 2: state 0 is absorbing, state 1 drains into it, and states
        # 2 and 3 only exchange mass with each other.
        print("*** Case 2")
        small = np.array([[1.0, 0.0, 0.0, 0.0],
                          [0.5, 0.5, 0.0, 0.0],
                          [0.0, 0.0, 0.5, 0.5],
                          [0.0, 0.0, 0.5, 0.5]])
        doit(sparse.csr_matrix(small))
    
    def doit(P):
        """
        Sanity-check *P*, run ``markov_stationary_components`` on it and
        print the wall-clock time, the residual ``||P^T p - p||``, the
        smallest and largest entries of p, and the sum of p.
        """
        assert isinstance(P, sparse.csr_matrix)
        assert np.isfinite(P.data).all()

        print("Construction finished!")

        def check_solution(method):
            print("\n\n-- {0}".format(method.__name__))
            t_start = time.time()
            p = method(P)
            elapsed = time.time() - t_start
            print("time: {0}".format(elapsed))

            # A stationary p satisfies P^T p = p, so this should be ~0.
            residual = np.linalg.norm(P.T.dot(p) - p)
            print("error: {0}".format(residual))

            lowest, highest = p.min(), p.max()
            print("min(p)/max(p): {0}, {1}".format(lowest, highest))
            print("sum(p): {0}".format(p.sum()))

        check_solution(markov_stationary_components)
    
    
    # Run the demo only when executed as a script, not when imported.
    if __name__ == "__main__":
        main()
    

    EDIT: spotted a bug — csgraph.connected_components also returns purely decaying (transient) components, i.e. components that leak probability into other components; these must be filtered out.

提交回复
热议问题