Finding the point of intersection of two line graphs drawn in matplotlib

后端 未结 2 1591
自闭症患者
自闭症患者 2020-12-19 20:13

Is there a way to find the point of intersection of two line graphs in matplotlib?

Consider the code

import numpy as np
import matplotlib.pyplot as plt


        
相关标签:
2条回答
  • 2020-12-19 20:56

    I've extended @SparkAndShine's solution to work with 3D data and made some performance improvements using a KD-tree. The full solution is posted here: https://stackoverflow.com/a/51145981/4212158

    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    from scipy.spatial import cKDTree
    from scipy import interpolate
    
    # Create the figure, then add one 3D axes whose rectangle (left, bottom,
    # width, height in figure fractions) spans the entire canvas.
    fig = plt.figure()
    ax = fig.add_axes((0, 0, 1, 1), projection='3d')
    # Turn the axis decorations off so only the plotted lines are drawn.
    ax.axis('off')
    
    def upsample_coords(coord_list, num_points=100):
        """Resample a polyline at evenly spaced spline-parameter values.

        A parametric spline is fitted through the given vertices with
        ``interpolate.splprep`` using degree ``k=1`` (linear) and smoothing
        ``s=0.0`` (interpolating), and is then evaluated with
        ``interpolate.splev`` at ``num_points`` parameter values spread evenly
        over ``[0, 1]``.

        Args:
            coord_list: Sequence of per-dimension coordinate sequences, e.g.
                ``[xs, ys, zs]``, passed straight to ``interpolate.splprep``.
            num_points: Number of samples to generate. Defaults to 100, the
                value that was previously hard-coded.

        Returns:
            The value returned by ``interpolate.splev``: one array of
            ``num_points`` coordinates per input dimension.
        """
        # s is the smoothness (0 -> pass through every vertex);
        # k is the spline degree (1 -> straight segments between vertices).
        # The second return value (the parameter values of the input
        # vertices) is not needed here.
        tck, _ = interpolate.splprep(coord_list, k=1, s=0.0)
        return interpolate.splev(np.linspace(0, 1, num_points), tck)
    
    # target line
    x_targ = [1, 2, 3, 4, 5, 6, 7, 8]
    y_targ = [20, 100, 50, 120, 55, 240, 50, 25]
    z_targ = [20, 100, 50, 120, 55, 240, 50, 25]
    targ_upsampled = upsample_coords([x_targ, y_targ, z_targ])
    # one row per upsampled point, one column per dimension
    targ_coords = np.column_stack(targ_upsampled)

    # KD-tree for nearest neighbor search
    targ_kdtree = cKDTree(targ_coords)

    # line two
    x2 = [3, 4, 5, 6, 7, 8, 9]
    y2 = [25, 35, 14, 67, 88, 44, 120]
    z2 = [25, 35, 14, 67, 88, 44, 120]
    l2_upsampled = upsample_coords([x2, y2, z2])
    l2_coords = np.column_stack(l2_upsampled)

    # plot both lines
    ax.plot(x_targ, y_targ, z_targ, color='black', linewidth=0.5)
    ax.plot(x2, y2, z2, color='darkgreen', linewidth=0.5)

    # find intersections: for every upsampled point of line two, look up the
    # nearest upsampled point of the target line within a radius of 0.5.
    # All points are queried in a single vectorized call; every sample of
    # line two (including the first) is checked.
    distances, close_indices = targ_kdtree.query(l2_coords, distance_upper_bound=.5)

    # query() signals "no neighbour within distance_upper_bound" with an
    # infinite distance (and an out-of-range index), so those entries must be
    # masked out before they are used to index the tree data.
    has_neighbor = np.isfinite(distances)
    targ_hits = targ_kdtree.data[close_indices[has_neighbor]]
    l2_hits = l2_coords[has_neighbor]

    # Colour and marker are given as keywords: the fourth positional argument
    # of the 3D scatter is `zdir`, not a format string such as 'gx'.
    # plot ground truth that was activated
    ax.scatter(targ_hits[:, 0], targ_hits[:, 1], targ_hits[:, 2], color='g', marker='x')
    # plot the cross point
    ax.scatter(l2_hits[:, 0], l2_hits[:, 1], l2_hits[:, 2], color='r', marker='x')


    plt.show()
    
    0 讨论(0)
  • 2020-12-19 21:08

    Here is an ugly solution (an improved version is at the bottom). After plotting, we can see that the two line graphs cross somewhere in the x-range (6, 7).

    Now we plot this crossing point with the following source code:

    import numpy as np
    import matplotlib.pyplot as plt
    
    # One figure with a single 2D axes.
    fig, ax = plt.subplots()

    # Vertices of the two line graphs.
    x1 = [1, 2, 3, 4, 5, 6, 7, 8]
    y1 = [20, 100, 50, 120, 55, 240, 50, 25]
    x2 = [3, 4, 5, 6, 7, 8, 9]
    y2 = [25, 35, 14, 67, 88, 44, 120]

    ax.plot(x1, y1, color='lightblue', linewidth=3)
    ax.plot(x2, y2, color='darkgreen', marker='^')

    # Mark the crossing point.
    # The crossing range (6, 7) was read off the plot by eye; both segments
    # are densely sampled over that range so they can be compared pointwise.
    x3 = np.linspace(6, 7, 1000)
    y1_new = np.linspace(240, 50, 1000)  # y1 goes from 240 to 50 over (6, 7)
    y2_new = np.linspace(67, 88, 1000)   # y2 goes from 67 to 88 over (6, 7)

    # Indices of the samples at which the two segments are (almost) equal.
    idx = np.flatnonzero(np.isclose(y1_new, y2_new, atol=0.1))
    ax.plot(x3[idx], y2_new[idx], 'ro')

    plt.show()
    

    End users would not be happy to enter the crossing range manually. Here is an improved version that loops over each pair of corresponding segments, though it may be time-consuming.

    import numpy as np
    import matplotlib.pyplot as plt
    
    fig = plt.figure()
    ax = fig.add_subplot(111)

    x1 = [1, 2, 3, 4, 5, 6, 7, 8]
    y1 = [20, 100, 50, 120, 55, 240, 50, 25]
    x2 = [3, 4, 5, 6, 7, 8, 9]
    y2 = [25, 35, 14, 67, 88, 44, 120]

    ax.plot(x1, y1, color='lightblue', linewidth=3)
    ax.plot(x2, y2, color='darkgreen', marker='^')

    # Get the common range, from `max(x1[0], x2[0])` to `min(x1[-1], x2[-1])`
    x_begin = max(x1[0], x2[0])     # 3
    x_end = min(x1[-1], x2[-1])     # 8

    # Keep only the vertices that fall inside the common x-range.
    points1 = [t for t in zip(x1, y1) if x_begin <= t[0] <= x_end]  # [(3, 50), (4, 120), (5, 55), (6, 240), (7, 50), (8, 25)]
    points2 = [t for t in zip(x2, y2) if x_begin <= t[0] <= x_end]  # [(3, 25), (4, 35), (5, 14), (6, 67), (7, 88), (8, 44)]

    # Walk both polylines one segment at a time, pairing the i-th segment of
    # line 1 with the i-th segment of line 2. The x samples are taken from
    # line 1 only, so this relies on both lines having their vertices at the
    # same x positions inside the common range. zip() stops at the shorter
    # list, so a length mismatch cannot index out of range.
    segment_pairs = zip(points1, points1[1:], points2, points2[1:])
    for (x_left, y1_left), (x_right, y1_right), (_, y2_left), (_, y2_right) in segment_pairs:
        x3 = np.linspace(x_left, x_right, 1000)        # e.g., (6, 7) intersection range
        y1_new = np.linspace(y1_left, y1_right, 1000)  # e.g., (6, 7) corresponding to (240, 50) in y1
        y2_new = np.linspace(y2_left, y2_right, 1000)  # e.g., (6, 7) corresponding to (67, 88) in y2

        # Sample indices at which the two segments are (almost) equal.
        tmp_idx = np.flatnonzero(np.isclose(y1_new, y2_new, atol=0.1))
        # Test the number of matches explicitly: the truth value of an index
        # array is ambiguous for several matches and is False for a single
        # match at index 0.
        if tmp_idx.size:
            ax.plot(x3[tmp_idx], y2_new[tmp_idx], 'ro')  # Plot the cross point

    plt.show()
    
    0 讨论(0)
提交回复
热议问题