Matplotlib imshow/matshow: display values on the plot

后端 未结 3 580
-上瘾入骨i
-上瘾入骨i 2020-12-05 03:31

I am trying to create a 10x10 grid using either imshow or matshow in Matplotlib. The function below takes a NumPy array as input and plots the grid …

相关标签:
3条回答
  • 2020-12-05 03:53

    For your graph, you should try pyplot.table:

    import matplotlib.pyplot as plt
    import numpy as np
    
    # 10x10 board of zeros whose first three cells in row 0 are 1, -1, 1.
    board = np.zeros((10, 10))
    board[0, :3] = (1, -1, 1)
    def visBoard(board):
        """Draw *board* as a table of 'X' / 'O' / blank cells and show it.

        Cells equal to 1 are rendered as 'X', cells equal to -1 as 'O', and
        every other cell is left blank. The table fills the whole axes
        (``bbox=[0, 0, 1, 1]``) and the axis frame is hidden.

        Parameters
        ----------
        board : 2-D numpy array
            Numeric board; any number of rows and columns.
        """
        # ``np.str`` was only an alias for the builtin ``str`` and was removed
        # in NumPy 1.24; ``str`` gives the same one-character unicode dtype.
        data = np.empty(board.shape, dtype=str)
        data[:, :] = ' '
        data[board == 1.0] = 'X'
        data[board == -1.0] = 'O'
        plt.axis('off')
        # plt.table needs one width per *column*, so size by shape[1]; sizing
        # by the row count fails for boards with more columns than rows.
        n_cols = board.shape[1]
        col_widths = np.ones(n_cols) / n_cols
        plt.table(cellText=data, loc='center', colWidths=col_widths,
                  cellLoc='center', bbox=[0, 0, 1, 1])
        plt.show()
    
    visBoard(board=board)  # render and display the sample board
    
    0 讨论(0)
  • 2020-12-05 04:00

    Here is an elaboration of @wflynny's code that turns it into a function taking a matrix of any size and marking each of its cells on the plot (as '+' or '-' depending on whether the value is below 0.5).

    import numpy as np
    import matplotlib.pyplot as plt
    
    # Random matrix dimensions in [1, 30): the column count is drawn first,
    # then the row count, followed by uniform [0, 1) entries.
    cols, rows = (np.random.randint(low=1, high=30) for _ in range(2))
    X = np.random.rand(rows, cols)
    
    def plotMat(X, *, threshold=0.5, below='+', above='-'):
        """Show matrix *X* with imshow and mark every cell with a symbol.

        Each cell is labelled with *below* when its value is less than
        *threshold* and with *above* otherwise (the defaults reproduce the
        original '+' / '-' at 0.5). Grid lines are drawn on the cell borders
        and tick labels are hidden.

        Parameters
        ----------
        X : 2-D array-like
            Matrix to display; any number of rows and columns.
        threshold : float, optional
            Value separating the two symbols.
        below, above : str, optional
            Symbols drawn for cells below / not below *threshold*.
        """
        X = np.asarray(X)
        n_rows, n_cols = X.shape
        _, ax = plt.subplots()
        ax.imshow(X, interpolation='nearest')
        # imshow centres cell (r, c) at data coordinates (x=c, y=r).
        for (r, c), value in np.ndenumerate(X):
            ax.text(c, r, below if value < threshold else above,
                    va='center', ha='center')
        # Cell borders lie at half-integer positions; ticks there let
        # ax.grid() draw the cell outlines.
        ax.set_xticks(np.arange(-0.5, n_cols - 0.5))
        ax.set_yticks(np.arange(-0.5, n_rows - 0.5))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xlim(-0.5, n_cols - 0.5)
        # Keep imshow's default origin='upper' (row 0 at the top); passing
        # the limits in increasing order would flip the matrix upside down.
        ax.set_ylim(n_rows - 0.5, -0.5)
        ax.grid()
        plt.show()
    
    plotMat(X=X)  # plot the randomly sized sample matrix
    
    0 讨论(0)
  • 2020-12-05 04:03

    Can you do something like:

    import numpy as np
    import matplotlib.pyplot as plt
    
    fig, ax = plt.subplots()

    min_val, max_val = 0, 10
    # Cell centres 0.5, 1.5, ..., max_val - 0.5, so each symbol sits in the
    # middle of a unit square bounded by the integer grid lines set below.
    ind_array = np.arange(min_val + 0.5, max_val + 0.5, 1.0)
    x, y = np.meshgrid(ind_array, ind_array)

    for i, (x_val, y_val) in enumerate(zip(x.flatten(), y.flatten())):
        # Alternating by flat index gives vertical stripes here, because the
        # row length (10) is even. For a checkerboard use
        # ``'x' if (x_val + y_val) % 2 else 'o'`` instead.
        c = 'x' if i % 2 else 'o'
        ax.text(x_val, y_val, c, va='center', ha='center')

    ax.set_xlim(min_val, max_val)
    ax.set_ylim(min_val, max_val)
    # Grid lines on the integer cell borders.
    ax.set_xticks(np.arange(max_val))
    ax.set_yticks(np.arange(max_val))
    ax.grid()
    # Without this the figure is never displayed when run as a script.
    plt.show()
    

    [Image: the resulting plot]


    Edit:

    Here is an updated example with an imshow background.

    import numpy as np
    import matplotlib.pyplot as plt
    
    fig, ax = plt.subplots()

    min_val, max_val, diff = 0., 10., 1.

    #imshow portion
    # np.random.rand requires integer dimensions; (max_val - min_val) / diff
    # is the float 10.0, which raises TypeError on current NumPy.
    N_points = int(round((max_val - min_val) / diff))
    imshow_data = np.random.rand(N_points, N_points)
    ax.imshow(imshow_data, interpolation='nearest')

    #text portion
    # imshow centres cell (row, col) at data coordinates (col, row), so the
    # integer positions below are the cell centres.
    ind_array = np.arange(min_val, max_val, diff)
    x, y = np.meshgrid(ind_array, ind_array)

    for x_val, y_val in zip(x.flatten(), y.flatten()):
        c = 'x' if (x_val + y_val) % 2 else 'o'
        ax.text(x_val, y_val, c, va='center', ha='center')

    #set tick marks for grid
    # Cell borders lie half a step from the centres.
    ax.set_xticks(np.arange(min_val - diff/2, max_val - diff/2))
    ax.set_yticks(np.arange(min_val - diff/2, max_val - diff/2))
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_xlim(min_val - diff/2, max_val - diff/2)
    # Keep imshow's default origin='upper' (row 0 at the top).
    ax.set_ylim(max_val - diff/2, min_val - diff/2)
    ax.grid()
    plt.show()
    

    [Image: the resulting plot]

    0 讨论(0)
提交回复
热议问题