Count the number of non zero values in a numpy array in Numba

Submitted by China☆狼群 on 2019-12-10 22:08:06

Question


Very simple. I am trying to count the number of non-zero values in a NumPy array inside a function jit-compiled with Numba (njit()). None of the following approaches I've tried are allowed by Numba:

  1. a[a != 0].size
  2. np.count_nonzero(a)
  3. len(a[a != 0])
  4. len(a) - len(a[a == 0])

I don't want to use a for loop if there is a faster, more Pythonic and more elegant way.

For the commenter who wanted to see a full code example:

import numpy as np
from numba import njit

@njit()
def n_nonzero(a):
    return a[a != 0].size

Answer 1:


You may also consider, well, counting the nonzero values:

import numba as nb

@nb.njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s
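One caveat: `for i in a` iterates over the first axis, so for a 2-D array each `i` is a whole row rather than a scalar, and `count_loop` as written only works for 1-D input. A minimal sketch of an N-D variant, assuming iteration over `a.flat` inside njit (an array attribute Numba supports); `count_loop_nd` is a hypothetical name, and the `try`/`except` fallback that runs it as plain Python when Numba is missing is an addition for illustration:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback when Numba is absent
    def njit(*args, **kwargs):
        # Supports both @njit and @njit(...); returns the function unchanged.
        if args and callable(args[0]):
            return args[0]
        return lambda func: func


@njit
def count_loop_nd(a):
    # a.flat visits every element regardless of a.ndim,
    # so this also works for 2-D and higher-dimensional arrays.
    s = 0
    for x in a.flat:
        if x != 0:
            s += 1
    return s


b = np.array([[0, 1, 2],
              [3, 0, 0]])
print(count_loop_nd(b))  # 3
```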

I know it seems wrong, but bear with me:

import numpy as np
import numba as nb

@nb.njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

@nb.njit()
def count_len_nonzero(a):
    return len(np.nonzero(a)[0])

@nb.njit()
def count_sum_neq_zero(a):
    return (a != 0).sum()

np.random.seed(100)
a = np.random.randint(0, 3, 1000000000, dtype=np.uint8)
c = np.count_nonzero(a)
assert count_len_nonzero(a) == c
assert count_sum_neq_zero(a) == c
assert count_loop(a) == c

%timeit count_len_nonzero(a)
# 5.94 s ± 141 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_sum_neq_zero(a)
# 848 ms ± 80.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_loop(a)
# 189 ms ± 4.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

It is in fact faster than np.count_nonzero, which can get quite slow for some reason:

%timeit np.count_nonzero(a)
# 4.36 s ± 69.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
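The answer does not say why the explicit loop wins. One plausible contributor (an assumption; the timings above do not isolate it) is memory traffic: `(a != 0).sum()` first materializes a boolean mask with one byte per element, and `len(np.nonzero(a)[0])` first materializes an index array with one `np.intp` per non-zero element, whereas the loop only keeps a running integer. This plain-NumPy sketch just measures those temporaries:

```python
import numpy as np

a = np.random.randint(0, 3, 1_000_000, dtype=np.uint8)

# (a != 0).sum() builds a boolean mask first: one byte per element.
mask = a != 0
print(mask.nbytes)  # 1000000

# len(np.nonzero(a)[0]) builds an index array first: one np.intp
# (8 bytes on 64-bit platforms) per non-zero element.
idx = np.nonzero(a)[0]
print(idx.nbytes == idx.size * np.dtype(np.intp).itemsize)  # True

# The explicit loop in count_loop allocates no temporary array;
# it only keeps a running integer.
```

For the 10**9-element array above, that works out to roughly 1 GB for the mask and, with about two thirds of the values non-zero, roughly 5 GB of 8-byte indices.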



Answer 2:


If you need it really fast for large arrays, you could even use Numba's prange to compute the count in parallel (for small arrays this will be slower because of the parallel-processing overhead).

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_nonzero_count(arr):
    flattened = arr.ravel()
    sum_ = 0
    for i in prange(flattened.size):
        sum_ += flattened[i] != 0
    return sum_
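A usage sketch for `parallel_nonzero_count`. Numba treats the `sum_ += ...` update inside `prange` as a reduction, so each thread accumulates its own partial count and the partials are combined after the loop. Because the function calls `ravel()`, it also accepts multi-dimensional arrays. The array shape, the random seed, and the plain-Python fallback (where `prange` is simply `range`) are assumptions for illustration:

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # assumption: plain-Python fallback when Numba is absent
    prange = range

    def njit(*args, **kwargs):
        # Supports both @njit and @njit(...); returns the function unchanged.
        if args and callable(args[0]):
            return args[0]
        return lambda func: func


@njit(parallel=True)
def parallel_nonzero_count(arr):
    flattened = arr.ravel()
    sum_ = 0
    for i in prange(flattened.size):
        # Recognized by Numba as a parallel reduction on sum_:
        # per-thread partial sums are added together after the loop.
        sum_ += flattened[i] != 0
    return sum_


rng = np.random.default_rng(0)
arr2d = rng.random((300, 300))
arr2d[arr2d < 0.3] = 0.0
print(parallel_nonzero_count(arr2d) == np.count_nonzero(arr2d))  # True
```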

Note that when you use Numba you normally want to write out your loops explicitly, because that is what Numba is very good at optimizing.

I actually timed it against the other solutions mentioned here (using my Python module simple_benchmark):

Code to reproduce:

import numpy as np
from numba import njit, prange

@njit
def n_nonzero(a):
    return a[a != 0].size

@njit
def count_non_zero(np_arr):
    return len(np.nonzero(np_arr)[0])

@njit()
def methodB(a):
    return (a != 0).sum()

@njit(parallel=True)
def parallel_nonzero_count(arr):
    flattened = arr.ravel()
    sum_ = 0
    for i in prange(flattened.size):
        sum_ += flattened[i] != 0
    return sum_

@njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

from simple_benchmark import benchmark

args = {}
for exp in range(2, 20):
    size = 2**exp
    arr = np.random.random(size)
    arr[arr < 0.3] = 0.0
    args[size] = arr

b = benchmark(
    funcs=(n_nonzero, count_non_zero, methodB, np.count_nonzero, parallel_nonzero_count, count_loop),
    arguments=args,
    argument_name='array size',
    warmups=(n_nonzero, count_non_zero, methodB, np.count_nonzero, parallel_nonzero_count, count_loop)
)
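If simple_benchmark is not installed, a rough stand-in can be built from the standard library's `timeit`. The important detail, mirrored from the `warmups` argument above, is to call each jitted function once before timing it, so the one-off compilation cost is not measured. `time_counters` is a hypothetical helper, the sizes are arbitrary, and the printed numbers will depend on your machine:

```python
import timeit
from functools import partial

import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback when Numba is absent
    def njit(*args, **kwargs):
        # Supports both @njit and @njit(...); returns the function unchanged.
        if args and callable(args[0]):
            return args[0]
        return lambda func: func


@njit
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s


@njit
def count_sum_neq_zero(a):
    return (a != 0).sum()


def time_counters(funcs, sizes, repeat=3):
    """Hypothetical helper: best-of-`repeat` seconds per call for each (name, size)."""
    results = {}
    for size in sizes:
        arr = np.random.random(size)
        arr[arr < 0.3] = 0.0
        expected = np.count_nonzero(arr)
        for name, func in funcs.items():
            # Warm-up call: triggers JIT compilation (and checks the result)
            # so that compile time is excluded from the measurement.
            assert func(arr) == expected
            times = timeit.repeat(partial(func, arr), number=1, repeat=repeat)
            results[(name, size)] = min(times)
    return results


funcs = {
    "count_loop": count_loop,
    "count_sum_neq_zero": count_sum_neq_zero,
    "np.count_nonzero": np.count_nonzero,
}
timings = time_counters(funcs, sizes=[2**10, 2**14])
for (name, size), seconds in sorted(timings.items()):
    print(f"{name:20s} {size:>7d} {seconds:.2e} s")
```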



Answer 3:


You can use np.nonzero and take the length of its result:

import numpy as np
from numba import njit

@njit
def count_non_zero(np_arr):
    return len(np.nonzero(np_arr)[0])

count_non_zero(np.array([0,1,0,1]))
# 2
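Why `[0]` is enough: `np.nonzero` returns a tuple with one index array per dimension, and every one of those arrays has one entry per non-zero element, so the length of the first array is the count for arrays of any dimensionality. A small sketch, assuming Numba's `np.nonzero` accepts 2-D input as NumPy's does; the example matrix and the plain-Python fallback are additions for illustration:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback when Numba is absent
    def njit(*args, **kwargs):
        # Supports both @njit and @njit(...); returns the function unchanged.
        if args and callable(args[0]):
            return args[0]
        return lambda func: func


@njit
def count_non_zero(np_arr):
    # One index array per dimension, each with one entry per non-zero
    # element, so the length of the first one is the count.
    return len(np.nonzero(np_arr)[0])


m = np.array([[0, 5, 0],
              [7, 0, 9]])
rows, cols = np.nonzero(m)  # plain NumPy, outside njit
print(rows)                 # [0 1 1]
print(cols)                 # [1 0 2]
print(count_non_zero(m))    # 3
```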



Answer 4:


Not sure if I have made a mistake here, but methodB below seems about 6x faster than methodA:

# Assumes numpy (as np) and numba's njit are already imported.
# Make something worth checking
a = np.random.randint(0, 3, 1000000000, dtype=np.uint8)

In [41]: @njit()
    ...: def methodA(a):
    ...:     return len(np.nonzero(a)[0])

# Call and check result
In [42]: methodA(a)
Out[42]: 666644445

In [43]: %timeit methodA(a)
4.65 s ± 28.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [44]: @njit()
    ...: def methodB(a):
    ...:     return (a != 0).sum()

# Call and check result
In [45]: methodB(a)
Out[45]: 666644445

In [46]: %timeit methodB(a)
724 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
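Whichever variant you choose, it is cheap to check it against `np.count_nonzero` (called outside njit) on edge cases. For floats, `x != 0` counts NaN as non-zero and treats `-0.0` as zero, which is also how `np.count_nonzero` behaves. A sketch using `methodB` from above; the test cases and the plain-Python fallback are illustrative additions:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback when Numba is absent
    def njit(*args, **kwargs):
        # Supports both @njit and @njit(...); returns the function unchanged.
        if args and callable(args[0]):
            return args[0]
        return lambda func: func


@njit
def methodB(a):
    return (a != 0).sum()


cases = [
    np.array([], dtype=np.float64),         # empty array
    np.zeros(5),                            # all zeros
    np.array([-2, 0, 3], dtype=np.int64),   # negative values count as non-zero
    np.array([np.nan, -0.0, 0.0, 1.0]),     # NaN != 0 is True; -0.0 == 0
]
for c in cases:
    assert methodB(c) == np.count_nonzero(c)

print([int(methodB(c)) for c in cases])  # [0, 0, 2, 2]
```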


Source: https://stackoverflow.com/questions/54830176/count-the-number-of-non-zero-values-in-a-numpy-array-in-numba
