Counting inversions in an array

死守一世寂寞 2020-11-22 04:14

I'm designing an algorithm to do the following: given an array A[1..n], find all inversion pairs, i.e. all pairs (i, j) with i < j and A[i] > A[j].
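
A minimal sketch that pins down the definition: the hypothetical helper inversion_pairs below checks every pair, using 0-based indices (the question uses 1-based A[1..n]). An array can have up to n(n-1)/2 inversions (a reverse-sorted array has exactly that many), so any method that lists every pair needs O(n²) time in the worst case; counting them, as in the answer below, can be done in O(n log n).

```python
def inversion_pairs(a):
    """Return every pair (i, j) with i < j and a[i] > a[j], using 0-based indices."""
    n = len(a)
    # Brute force: compare each element with every element to its right.
    return [(i, j) for i in range(n) for j in range(i + 1, n) if a[i] > a[j]]

pairs = inversion_pairs([5, 1, 2, 4, 9, 3])
print(len(pairs), pairs)  # 6 [(0, 1), (0, 2), (0, 3), (0, 5), (3, 5), (4, 5)]
```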

30 answers
  •  一个人的身影
    2020-11-22 04:38

    The number of inversions can be found by analyzing the merge step of merge sort:

    [Figure: merge process]

    The merge in the code below fills the result from the largest element down. In its final merge, [1, 2, 5] with [3, 4, 9], copying an element from the second array (the 9) creates no inversion: it is not smaller than anything left in the first array. Copying an element from the first array (the 5) means it is greater than every element still left in the second array, so it forms an inversion with each of them (2 inversions, with the 3 and the 4). So a small modification of merge sort counts the inversions in O(n log n) time.
    To get the count, uncomment the two # lines (global count and count += len(l2)) in the Python merge sort below.

    def merge(l1,l2):
        l = []
        # global count
        while l1 and l2:
            if l1[-1] <= l2[-1]:
                l.append(l2.pop())  # tail of l2 is largest: no new inversion
            else:
                l.append(l1.pop())  # tail of l1 exceeds everything left in l2
                # count += len(l2)
        l.reverse()
        return l1 + l2 + l
    
    def sort(l): 
        t = len(l) // 2
        return merge(sort(l[:t]), sort(l[t:])) if t > 0 else l
    
    count = 0
    print(sort([5, 1, 2, 4, 9, 3]), count)
    # [1, 2, 3, 4, 5, 9] 6  (count stays 0 unless the two lines are uncommented)
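
    The version above keeps the total in a global variable. Here is a sketch of the same back-to-front merge that returns the count instead, so the function can be called repeatedly without resetting any state; merge_count and sort_count are hypothetical names, not part of the original answer:

```python
def merge_count(l1, l2):
    """Merge two sorted lists back to front; return (merged, cross inversions)."""
    l, count = [], 0
    while l1 and l2:
        if l1[-1] <= l2[-1]:
            l.append(l2.pop())   # largest remaining is in l2: no inversion
        else:
            l.append(l1.pop())   # larger than every element left in l2
            count += len(l2)
    l.reverse()
    # Leftovers are sorted and no larger than anything already in l.
    return l1 + l2 + l, count

def sort_count(l):
    """Return (sorted copy of l, number of inversions in l)."""
    if len(l) < 2:
        return list(l), 0
    t = len(l) // 2
    left, c1 = sort_count(l[:t])
    right, c2 = sort_count(l[t:])
    merged, c3 = merge_count(left, right)
    return merged, c1 + c2 + c3

print(sort_count([5, 1, 2, 4, 9, 3]))  # ([1, 2, 3, 4, 5, 9], 6)
```

    Each call returns the sorted list together with the inversion count of its input; the input list itself is not modified, because only slices and fresh lists are popped from.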
    

    EDIT 1

    The same count can be obtained with a stable version of quicksort (the partition keeps the original order inside small and big), which can be slightly faster:

    def part(l):
        pivot=l[-1]
        small,big = [],[]
        count = big_count = 0
        for x in l:
            if x <= pivot:
                small.append(x)
                count += big_count  # each bigger element seen so far precedes x
            else:
                big.append(x)
                big_count += 1
        return count,small,big
    
    def quick_count(l):
        if len(l)<2 : return 0
        count,small,big = part(l)
        small.pop()  # remove the pivot, always the last item of small
        return count + quick_count(small) + quick_count(big)
    

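    Choosing the last element as the pivot keeps the count correct: every element of big precedes the pivot and is counted against it before small.pop() removes it. The execution time was about 40% better than with the merge version above.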
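
    One caveat with quick_count: because the pivot is always the last element, already sorted input makes every partition maximally lopsided, so the running time becomes O(n²) and the recursion depth reaches n (CPython's default recursion limit is 1000, so quick_count(list(range(5000))) raises RecursionError). The sketch below swaps in a different technique from the answer's: a random pivot with a stable three-way partition (less, equal, greater) that counts the cross-group inversions during the scan, giving expected O(n log n) time on any input. quick_count_random is a hypothetical name.

```python
import random

def quick_count_random(l):
    """Count inversions of l using a random pivot and a stable three-way partition."""
    if len(l) < 2:
        return 0
    pivot = random.choice(l)
    less, greater = [], []
    count = n_greater = n_equal = 0
    for x in l:
        if x < pivot:
            less.append(x)
            count += n_greater + n_equal  # earlier elements >= pivot are all > x
        elif x == pivot:
            n_equal += 1
            count += n_greater            # earlier elements > pivot are all > x
        else:
            greater.append(x)
            n_greater += 1                # later and larger than less/equal: no inversion
    # Equal elements never invert with each other; inversions inside
    # less and greater are counted recursively (their order is preserved).
    return count + quick_count_random(less) + quick_count_random(greater)

print(quick_count_random([5, 1, 2, 4, 9, 3]))  # 6
print(quick_count_random(list(range(5000))))   # 0
```

    The random pivot only changes the running time: the returned count is the same on every run.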

    EDIT 2

    For performance in Python, here is a NumPy + Numba version:

    First the NumPy part, which uses argsort (O(n log n)):

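    import numpy as np
    
    def count_inversions(a):
        # a must be a 1-D NumPy array
        n = a.size
        counts = np.arange(n) & -np.arange(n)  # the BIT: positions 1..n-1 all hold 1
        ags = a.argsort(kind='mergesort')      # stable: equal values keep index order
        return BIT(ags, counts, n)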
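
    Why np.arange(n) & -np.arange(n) is a valid starting BIT: in a Fenwick tree, node i stores the sum of positions i - (i & -i) + 1 through i, which is i & -i positions. When every position 1..n-1 holds 1 (meaning "not seen yet"), node i therefore holds exactly i & -i, and node 0 is unused. A small pure-Python check of that identity against a tree built with ordinary +1 point updates (both helper names are hypothetical):

```python
def lowbit_tree(n):
    """Fenwick tree of an all-ones array, as built by np.arange(n) & -np.arange(n)."""
    return [i & -i for i in range(n)]

def ones_tree_by_updates(n):
    """Same tree built the slow way: one +1 point update per position 1..n-1."""
    tree = [0] * n
    for p in range(1, n):
        i = p
        while i < n:          # standard Fenwick update: climb to each covering node
            tree[i] += 1
            i += i & -i
    return tree

print(lowbit_tree(8))                             # [0, 1, 2, 1, 4, 1, 2, 1]
print(lowbit_tree(8) == ones_tree_by_updates(8))  # True
```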
    

    And the Numba part, for the efficient BIT (Fenwick tree) approach:

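    import numba
    
    @numba.njit
    def BIT(ags, counts, n):
        res = 0
        for x in ags:              # indices in increasing order of value
            i = x                  # count unvisited (larger) elements left of x
            while i:
                res += counts[i]
                i -= i & -i
            i = x + 1              # mark x as visited (BIT position x+1 = index x)
            while i < n:
                counts[i] -= 1
                i += i & -i
        return res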
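
    If NumPy and Numba are not available, the same algorithm runs in plain Python, only much more slowly. In this sketch (count_inversions_bit is a hypothetical name), sorted(range(n), key=a.__getitem__) plays the role of argsort; Python's sort is guaranteed stable, like kind='mergesort', so equal values are not counted as inversions. BIT position x + 1 stands for array index x:

```python
def count_inversions_bit(a):
    """Pure-Python sketch of the argsort + Fenwick-tree count above."""
    n = len(a)
    counts = [i & -i for i in range(n)]          # Fenwick tree: positions 1..n-1 hold 1
    order = sorted(range(n), key=a.__getitem__)  # stable, like kind='mergesort'
    res = 0
    for x in order:                              # smallest values first
        i = x                                    # prefix sum over positions 1..x,
        while i:                                 # i.e. unvisited indices 0..x-1
            res += counts[i]
            i -= i & -i
        i = x + 1                                # mark index x as visited
        while i < n:
            counts[i] -= 1
            i += i & -i
    return res

print(count_inversions_bit([5, 1, 2, 4, 9, 3]))  # 6
```

    Because indices are visited from the smallest value upwards, each prefix query counts the elements left of x that are still unvisited, i.e. larger than a[x], which is exactly the number of inversions ending at x.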
    
