NumPy: merge sorted arrays into a new array?

有刺的猬 2020-12-11 14:58

Is there any way to do something like the merge step of mergesort using a NumPy function?

Some function like merge:

a = np.array([1,3,5])
b = np.array([2,4         
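For reference, the merge step of mergesort that the question refers to can be written as a plain two-pointer loop over the two sorted inputs. The sketch below is only an illustration: the function name `two_pointer_merge` and the sample inputs are hypothetical, not from the post. It shows what a vectorized NumPy version has to reproduce; because it loops element by element in Python, it is slow for large arrays.

```python
import numpy as np

def two_pointer_merge(a, b):
    # Hypothetical reference implementation of mergesort's merge step:
    # walk both sorted arrays with one index each and copy the smaller head.
    out = np.empty(len(a) + len(b), dtype=np.result_type(a, b))
    i = j = k = 0
    while i < len(a) and j < len(b):
        if a[i] <= b[j]:  # "<=" takes from a first on ties (stable merge)
            out[k] = a[i]
            i += 1
        else:
            out[k] = b[j]
            j += 1
        k += 1
    # At most one input still has elements left; copy both tails as slices.
    out[k:k + len(a) - i] = a[i:]
    k += len(a) - i
    out[k:] = b[j:]
    return out

print(two_pointer_merge(np.array([1, 4, 9]), np.array([2, 3, 10])))
```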


        
  •  旧时难觅i
    2020-12-11 15:41

    When one array is considerably larger than the other, a decent speedup (about 5-fold on my PC) can be obtained with np.searchsorted. Instead of sorting the concatenation, we only look up where each element of the smaller array belongs in the larger one, and this search for insertion indices is what primarily limits the speed:

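    import numpy as np
    
    def classic_merge(a, b):
        # Concatenate, then sort; kind='mergesort' requests a stable sort.
        c = np.concatenate((a, b))
        c.sort(kind='mergesort')
        return c
    
    def new_merge(a, b):
        # Make `a` the larger array and `b` the smaller one.
        if len(a) < len(b):
            b, a = a, b
        # np.result_type avoids truncating values when the dtypes differ.
        c = np.empty(len(a) + len(b), dtype=np.result_type(a, b))
        # Final slot of b[i] = i (earlier b elements) plus the number of
        # a elements smaller than b[i] (its insertion index in a).
        b_indices = np.arange(len(b)) + np.searchsorted(a, b)
        # Every remaining slot is filled by a, in order.
        a_indices = np.ones(len(c), dtype=bool)
        a_indices[b_indices] = False
        c[b_indices] = b
        c[a_indices] = a
        return c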
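To make the index arithmetic in `new_merge` concrete, here is the same computation on a tiny, hand-checkable example (the arrays are illustrative, not from the post). With its default `side='left'`, `np.searchsorted` returns, for each element of the smaller array, how many elements of the larger array are strictly smaller; adding the element's own position in the smaller array gives its slot in the output.

```python
import numpy as np

# Illustrative inputs: a is the larger array, b the smaller one.
a = np.array([1, 3, 5, 7, 9])
b = np.array([4, 8])

# Number of a-values strictly smaller than each b-value.
ins = np.searchsorted(a, b)              # array([2, 4])
# Add i for the i earlier b-values that also precede b[i] in the output.
b_indices = np.arange(len(b)) + ins      # array([2, 5])

# Every slot not taken by b belongs to a, filled in order.
mask = np.ones(len(a) + len(b), dtype=bool)
mask[b_indices] = False

c = np.empty(len(a) + len(b), dtype=a.dtype)
c[b_indices] = b
c[mask] = a
print(c)
```

Because `side='left'` is used, an element of `b` equal to some element of `a` lands before it; the merged values are the same either way.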
    

    Timing gives:

    from timeit import timeit as t
    
    results = []
    for size_digits in range(2, 8):
        size = 10**size_digits
        # the factor-10 size difference between a and b is what matters here!
        # (np.int was removed in NumPy 1.24; the builtin int is equivalent)
        a = np.arange(size // 10, dtype=int)
        b = np.arange(size, dtype=int)
        classic = t(lambda: classic_merge(a, b), number=10)
        new = t(lambda: new_merge(a, b), number=10)
        results.append((size_digits, classic, new))
    
    text_format = " ".join(["{:<15}"] * 3)
    print(text_format.format("log10(size)", "Classic", "New"))
    table_format = "         ".join(["{:.5f}"] * 3)
    for result in results:
        print(table_format.format(*result))
    
    log10(size)     Classic         New            
    2.00000         0.00009         0.00027
    3.00000         0.00021         0.00030
    4.00000         0.00233         0.00082
    5.00000         0.02827         0.00601
    6.00000         0.33322         0.06059
    7.00000         4.40571         0.86764
    

    When a and b are roughly the same length, the differences are smaller:

    from timeit import timeit as t
    
    results = []
    for size_digits in range(2, 8):
        size = 10**size_digits
        # same size
        a = np.arange(size, dtype=int)
        b = np.arange(size, dtype=int)
        classic = t(lambda: classic_merge(a, b), number=10)
        new = t(lambda: new_merge(a, b), number=10)
        results.append((size_digits, classic, new))
    
    text_format = " ".join(["{:<15}"] * 3)
    print(text_format.format("log10(size)", "Classic", "New"))
    table_format = "         ".join(["{:.5f}"] * 3)
    for result in results:
        print(table_format.format(*result))
    
    log10(size)     Classic         New            
    2.00000         0.00026         0.00087
    3.00000         0.00108         0.00182
    4.00000         0.01257         0.01243
    5.00000         0.16333         0.12692
    6.00000         1.05006         0.49186
    7.00000         8.35967         5.93732
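Before relying on the timings, it may be worth confirming that the searchsorted-based merge returns exactly what a plain sort of the concatenation returns. A minimal randomized check could look like this; the seed, sizes and value range are arbitrary choices, and `new_merge` is repeated from the answer (with the output dtype taken from `np.result_type`) so the snippet runs on its own.

```python
import numpy as np

def new_merge(a, b):
    # searchsorted-based merge, repeated from the answer above.
    if len(a) < len(b):
        a, b = b, a
    c = np.empty(len(a) + len(b), dtype=np.result_type(a, b))
    b_indices = np.arange(len(b)) + np.searchsorted(a, b)
    a_mask = np.ones(len(c), dtype=bool)
    a_mask[b_indices] = False
    c[b_indices] = b
    c[a_mask] = a
    return c

rng = np.random.default_rng(0)
for _ in range(200):
    # Random sorted inputs of unequal (possibly zero) length, many duplicates.
    a = np.sort(rng.integers(0, 50, size=rng.integers(0, 30)))
    b = np.sort(rng.integers(0, 50, size=rng.integers(0, 300)))
    expected = np.sort(np.concatenate((a, b)))
    assert np.array_equal(new_merge(a, b), expected)
print("all random checks passed")
```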
    
