Vectorized pythonic way to get count of elements greater than current element

Submitted by 半腔热情 on 2020-01-01 11:35:12

Question


I'd like to compare every entry in array b with the rest of its column and find what fraction of the entries in that column are larger than it. My code seems embarrassingly slow, and I suspect that is because it uses for loops rather than vector operations.

Can we 'vectorize' the following code?

import numpy as np

n = 1000
m = 200
b = np.random.rand(n,m)
p = np.zeros((n,m))

for i in range(n):  # rows
    for j in range(m):  # cols
        r = b[i, j]
        p[i, j] = (b[:, j] > r).sum() / n

After some more thought, I think sorting each column would improve overall runtime by making the later comparisons much faster.
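
As a rough sketch of that sorting idea (not part of the original code; the function name frac_greater_sorted is made up for illustration), each column can be sorted once and np.searchsorted used to count the entries that are not larger than each element. The remainder of the column is then strictly larger:

```python
import numpy as np

def frac_greater_sorted(b):
    """Fraction of each column that is strictly greater than each entry."""
    n, m = b.shape
    sorted_cols = np.sort(b, axis=0)  # every column in ascending order
    p = np.empty((n, m))
    for j in range(m):  # one vectorized binary search per column
        # side="right" returns how many entries of column j are <= b[i, j]
        n_le = np.searchsorted(sorted_cols[:, j], b[:, j], side="right")
        p[:, j] = (n - n_le) / n
    return p
```

This still loops over the m columns, but the inner loop over n rows is gone, and repeated values within a column are handled the same way as in the strict > comparison.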

After some more searching I believe I want column-wise percentileofscore (http://lagrange.univ-lyon1.fr/docs/scipy/0.17.1/generated/scipy.stats.percentileofscore.html)


Answer 1:


On closer study, it turns out that the argsort() indices along each column are all we need to get, for every element, the count of entries in its column that are greater than it.

Approach #1

With that in mind, one solution is simply to apply argsort twice to get the counts -

p = len(b)-1-b.argsort(0).argsort(0)
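
To see why the double argsort works, here is a small worked example on a 3x2 array whose values I picked only for illustration. The first argsort gives the sorting order of each column, the second turns that order into each element's ascending rank, and subtracting the rank from n-1 leaves the number of larger entries:

```python
import numpy as np

b = np.array([[0.3, 0.9],
              [0.7, 0.1],
              [0.5, 0.4]])

order = b.argsort(0)         # row indices that sort each column ascending
ranks = order.argsort(0)     # ascending rank of each entry (0 = smallest)
counts = len(b) - 1 - ranks  # entries in the same column that are larger

print(counts)
# [[2 0]
#  [0 2]
#  [1 1]]
```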

Approach #2

We could optimize it further, given that the argsort indices along each column are unique numbers. So, the second argsort step can be replaced by array assignment with advanced indexing, like so -

def count_occ(b):
    n,m = b.shape     
    out = np.zeros((n,m),dtype=int)
    idx = b.argsort(0)
    out[idx, np.arange(m)] = n-1-np.arange(n)[:,None]
    return out

Finally, for both approaches, divide the result by n as in the question.
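
A quick sanity check, sketched under the assumption that the values within each column are distinct. The helper count_greater_loop is a hypothetical reference implementation of the question's loop, not something from the posts above. Because argsort hands out distinct ranks even to equal values, the argsort-based counts can differ from the strict > counts when a column contains duplicates, which the last lines illustrate:

```python
import numpy as np

def count_occ(b):
    # Approach #2: scatter the descending counts into place via the sort order
    n, m = b.shape
    out = np.zeros((n, m), dtype=int)
    idx = b.argsort(0)
    out[idx, np.arange(m)] = n - 1 - np.arange(n)[:, None]
    return out

def count_greater_loop(b):
    # Hypothetical reference: strict "greater than" count per entry
    n, m = b.shape
    out = np.zeros((n, m), dtype=int)
    for i in range(n):
        for j in range(m):
            out[i, j] = (b[:, j] > b[i, j]).sum()
    return out

rng = np.random.default_rng(0)
b = rng.permutation(200).reshape(50, 4).astype(float)  # all values distinct
p = count_occ(b) / len(b)                              # fractions, as in the question
agrees = np.array_equal(count_occ(b), count_greater_loop(b))

tied = np.array([[5.0], [5.0], [0.0]])                 # column with a repeated value
tied_agrees = np.array_equal(count_occ(tied), count_greater_loop(tied))
```

Here agrees is True and tied_agrees is False, so for data with repeated values a tie-aware method would be needed.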


Benchmarking

Timing all the approaches posted thus far -

In [174]: np.random.seed(0)
     ...: n = 1000
     ...: m = 200
     ...: b = np.random.rand(n,m)

In [175]: %timeit (len(b)-1-b.argsort(0).argsort(0))/float(n)
100 loops, best of 3: 17.6 ms per loop

In [176]: %timeit count_occ(b)/float(n)
100 loops, best of 3: 12.1 ms per loop

# @Brad Solomon's soln
In [177]: %timeit np.sum(b > b[:, None], axis=-2) / float(n)
1 loop, best of 3: 349 ms per loop

# @marco romelli's loopy soln
In [178]: %timeit loopy_soln(b)/float(n)
10 loops, best of 3: 139 ms per loop



Answer 2:


I think the following solution is much faster:

import numpy as np

n = 1000
m = 200
b = np.random.rand(n,m)
p = np.zeros((n,m))

for i in range(m):
    col = []
    for j in range(n):
        col.append((j, b[j,i]))
    a = sorted(col, key=lambda x: x[1], reverse=True)
    for j in range(n):
        p[a[j][0], i] = j

It works column by column and is based on sorting, so at a rough guess the complexity is O(n m log n). Note that p holds raw counts here; divide by n as in the question to get fractions.
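
For reuse or timing, the same idea can be wrapped in a function. This is only a sketch: the name loopy_soln is an assumption on my part, and the body is a slightly condensed variant of the loop above rather than the exact code that was benchmarked. It assumes the values within a column are distinct:

```python
import numpy as np

def loopy_soln(b):
    n, m = b.shape
    p = np.zeros((n, m))
    for i in range(m):
        # Row indices of column i, ordered from the largest value to the smallest
        order = sorted(range(n), key=lambda j: b[j, i], reverse=True)
        for rank, row in enumerate(order):
            # `rank` entries come before this one, i.e. are larger (no ties assumed)
            p[row, i] = rank
    return p
```

Calling loopy_soln(b) / b.shape[0] then gives the fractions asked for in the question.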

EDIT

Benchmark results

Original code: 1.46 s ± 8.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

This answer: 178 ms ± 295 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)




Answer 3:


This costs extra memory, since the comparison builds an intermediate boolean array of shape (n, n, m), but one way to get there is to expand b so that a single broadcasted comparison replaces the (Python) loops entirely:

# for n = 10; m = 2; np.random.seed(444)
>>> np.sum(b > b[:, None], axis=-2) / n
array([[0.7, 0.1],
       [0.3, 0.6],
       [0. , 0.2],
       [0.4, 0.8],
       # ...

Full code:

import numpy as np

np.random.seed(444)

def loop(b):
    # Could also use for (i, j), val in np.ndenumerate(b)
    n, m = b.shape  # take the dimensions from b instead of relying on globals
    p = np.zeros_like(b)
    for i in range(n):
        for j in range(m):
            r = b[i, j]
            p[i, j] = (b[:, j] > r).sum()
    return p / n

def noloop(b):
    n = b.shape[0]
    return np.sum(b > b[:, None], axis=-2) / n

n = 10
m = 2
b = np.random.rand(n, m)

assert np.allclose(loop(b), noloop(b))
# True

n = 1000
m = 200
b = np.random.rand(n, m)

%timeit loop(b)
# 1.59 s ± 50.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit -r 7 -n 1 noloop(b)
# 443 ms ± 18.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

assert np.allclose(loop(b), noloop(b))
# True
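
To make the broadcasting step and its memory cost more concrete, here is a small sketch that walks through the intermediate shapes. The tiny input array is made up for illustration, and the byte estimate simply multiplies out the benchmark dimensions n = 1000, m = 200:

```python
import numpy as np

n, m = 4, 2
b = np.arange(n * m, dtype=float).reshape(n, m)

expanded = b[:, None]       # shape (n, 1, m)
mask = b > expanded         # shape (n, n, m); mask[i, k, j] is b[k, j] > b[i, j]
counts = mask.sum(axis=-2)  # sum over k, back to shape (n, m)

# Size of the boolean intermediate for n = 1000, m = 200
bytes_needed = 1000 * 1000 * 200 * mask.itemsize

print(mask.shape)    # (4, 4, 2)
print(bytes_needed)  # 200000000
```

So the comparison does n * n * m element-wise tests and holds roughly 200 MB of booleans at the benchmark size, which goes some way to explaining why it is slower than the sort-based approaches despite having no Python loop.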


Source: https://stackoverflow.com/questions/49880927/vectorized-pythonic-way-to-get-count-of-elements-greater-than-current-element
