Is there a NumPy function that performs the merge step of mergesort, i.e. one that combines two already-sorted arrays into a single sorted array?
Something like a `merge` function:
# Two already-sorted inputs that a merge function should combine into
# array([1, 2, 3, 4, 5, 6]).
a = np.array([1, 3, 5])
# NOTE(review): the original snippet was truncated after "[2,4"; the
# trailing 6 is a reconstruction.
b = np.array([2, 4, 6])
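When one array is considerably larger than the other, a decent speedup (5-fold on my PC) can be obtained with `np.searchsorted`. The idea is to find the insertion index of each element of the smaller array in the larger one, then scatter both arrays into the output. Its speed is then limited mainly by the search for those insertion indices. Both inputs must already be sorted: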
import numpy as np
def classic_merge(a, b):
    """Merge two arrays by joining them and re-sorting the result.

    Baseline for comparison: it ignores any existing order in the inputs
    and simply sorts their concatenation.

    Parameters
    ----------
    a, b : numpy.ndarray
        1-D arrays to combine.

    Returns
    -------
    numpy.ndarray
        A new array holding every element of ``a`` and ``b`` in ascending
        order, with NumPy's usual concatenation dtype promotion.
    """
    combined = np.concatenate([a, b])
    # Sort the freshly allocated buffer in place rather than making a copy.
    combined.sort(kind="mergesort")
    return combined
def new_merge(a, b):
    """Merge two sorted 1-D arrays into one sorted array.

    The shorter input is binary-searched into the longer one with
    ``np.searchsorted``. Both inputs are then scattered into a
    preallocated output. This avoids re-sorting the concatenation.

    Parameters
    ----------
    a, b : array_like
        1-D sequences, each already sorted in ascending order. If either
        input is unsorted, the output is not sorted either.

    Returns
    -------
    numpy.ndarray
        Array of length ``len(a) + len(b)`` with all elements of both
        inputs in ascending order. Its dtype is ``np.result_type(a, b)``,
        the same promotion ``np.concatenate`` applies.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # Make `a` the longer array so the binary searches are done for the
    # (fewer) elements of the shorter one.
    if len(a) < len(b):
        a, b = b, a
    # Use the common dtype of both inputs. Taking a.dtype alone would
    # silently truncate, e.g. short float input merged into long int input.
    out = np.empty(len(a) + len(b), dtype=np.result_type(a, b))
    # b[i] is preceded by the i earlier elements of b plus the
    # searchsorted(a, b[i]) elements of a that are strictly smaller.
    b_positions = np.arange(len(b)) + np.searchsorted(a, b)
    # Every slot not claimed by b is filled from a, in order.
    a_mask = np.ones(len(out), dtype=bool)
    a_mask[b_positions] = False
    out[b_positions] = b
    out[a_mask] = a
    return out
Timing both functions (each entry is the total time in seconds for 10 calls) gives:
# Benchmark classic_merge vs. new_merge on inputs whose lengths differ by a
# factor of 10. Both functions must already be defined in this module.
# Each timing is the total seconds for 10 calls (number=10).
from timeit import timeit as t

results = []
for size_digits in range(2, 8):
    size = 10**size_digits
    # size difference of a factor 10 here makes the difference!
    # `np.int` was removed in NumPy 1.24; it was only an alias for `int`.
    a = np.arange(size // 10, dtype=int)
    b = np.arange(size, dtype=int)
    # The lambdas run immediately inside timeit, so closing over the
    # loop variables a and b is safe.
    classic = t(lambda: classic_merge(a, b), number=10)
    new = t(lambda: new_merge(a, b), number=10)
    results.append((size_digits, classic, new))

text_format = " ".join(["{:<15}"] * 3)
print(text_format.format("log10(size)", "Classic", "New"))
table_format = " ".join(["{:.5f}"] * 3)
for result in results:
    print(table_format.format(*result))
log10(size) Classic New
2.00000 0.00009 0.00027
3.00000 0.00021 0.00030
4.00000 0.00233 0.00082
5.00000 0.02827 0.00601
6.00000 0.33322 0.06059
7.00000 4.40571 0.86764
When `a` and `b` have roughly equal lengths, the differences are smaller:
# Benchmark classic_merge vs. new_merge on two inputs of equal length.
# Both functions must already be defined in this module.
# Each timing is the total seconds for 10 calls (number=10).
from timeit import timeit as t

results = []
for size_digits in range(2, 8):
    size = 10**size_digits
    # same size
    # `np.int` was removed in NumPy 1.24; it was only an alias for `int`.
    a = np.arange(size, dtype=int)
    b = np.arange(size, dtype=int)
    # The lambdas run immediately inside timeit, so closing over the
    # loop variables a and b is safe.
    classic = t(lambda: classic_merge(a, b), number=10)
    new = t(lambda: new_merge(a, b), number=10)
    results.append((size_digits, classic, new))

text_format = " ".join(["{:<15}"] * 3)
print(text_format.format("log10(size)", "Classic", "New"))
table_format = " ".join(["{:.5f}"] * 3)
for result in results:
    print(table_format.format(*result))
log10(size) Classic New
2.00000 0.00026 0.00087
3.00000 0.00108 0.00182
4.00000 0.01257 0.01243
5.00000 0.16333 0.12692
6.00000 1.05006 0.49186
7.00000 8.35967 5.93732