Efficiently multiply elements of each row together

Submitted by 岁酱吖の on 2019-12-28 07:04:21

Question


Given an ndarray of shape (n, 3) with n around 1000, how can I multiply together all elements of each row, fast? The (inelegant) second solution below runs in about 0.3 milliseconds. Can it be improved?

import time
import numpy as np

# dummy data: 999 random values reshaped into 333 rows of 3
n = 999
a = np.random.uniform(low=0, high=10, size=n).reshape(n // 3, 3)

# two solutions
def prod1(array):
    # generic: one np.prod call per row
    return [np.prod(row) for row in array]

def prod2(array):
    # explicit product of the three columns, row by row
    return [row[0]*row[1]*row[2] for row in array]

# benchmark
start = time.time()
prod1(a)
print(time.time() - start)
# 0.0015

start = time.time()
prod2(a)
print(time.time() - start)
# 0.0003

Answer 1:


Improving performance further

First, a general rule of thumb: you are working with numerical arrays, so use arrays, not lists. Lists may look somewhat like general arrays, but they are completely different under the hood and are not suitable for most numerical calculations.

If you write simple code using NumPy arrays, you can gain performance simply by jitting it, as shown below. If you use lists, you more or less have to rewrite your code.

import numpy as np
import numba as nb

@nb.njit(fastmath=True)
def prod(array):
  assert array.shape[1]==3 #Enable SIMD-Vectorization (adding some performance)
  res=np.empty(array.shape[0],dtype=array.dtype)
  for i in range(array.shape[0]):
    res[i]=array[i,0]*array[i,1]*array[i,2]

  return res

Using np.prod(a, axis=1) isn't a bad idea, but its performance isn't great here: for an array of only about 1000x3, the function call overhead is quite significant. That overhead can be avoided entirely by calling the jitted prod function from within another jitted function.
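As a minimal sketch of that last point, the jitted prod can be called from a second jitted function so that the inner call never goes through the Python interpreter. The caller sum_of_row_products is a hypothetical name chosen for illustration, and the try/except fallback is only there so the sketch still runs (uncompiled) where Numba is not installed.

```python
import numpy as np

try:
    import numba as nb
    njit = nb.njit
except ImportError:
    # Assumption for this sketch only: without Numba, run as plain Python.
    def njit(**kwargs):
        return lambda func: func

@njit(fastmath=True)
def prod(array):
    assert array.shape[1] == 3
    res = np.empty(array.shape[0], dtype=array.dtype)
    for i in range(array.shape[0]):
        res[i] = array[i, 0] * array[i, 1] * array[i, 2]
    return res

@njit(fastmath=True)
def sum_of_row_products(array):
    # Hypothetical caller: prod() is invoked from compiled code here,
    # so there is no Python-level call overhead for the inner call.
    return prod(array).sum()

a = np.arange(1.0, 13.0).reshape(4, 3)
total = sum_of_row_products(a)
```

Whether this pays off depends on how much of the surrounding computation can live inside jitted code; a single call from Python still pays the dispatch cost once.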

Benchmarks

# The first call to the jitted function incurs about 200 ms of compilation overhead.
# With @nb.njit(fastmath=True, cache=True) the compilation result is cached for every successive call.
n=999
prod1   = 795  µs
prod2   = 187  µs
np.prod = 7.42 µs
prod    = 0.85 µs

n=9990
prod1   = 7863 µs
prod2   = 1810 µs
np.prod = 50.5 µs
prod    = 2.96 µs



Answer 2:


np.prod accepts an axis argument:

np.prod(a, axis=1)

With axis=1, the product is taken along the columns, yielding one value per row.
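To make the axis semantics concrete, here is a small worked example on a 2x3 array. The explicit three-column product is shown only as a cross-check and is not part of the original answer.

```python
import numpy as np

a = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])

# Reduce along axis 1 (the columns): one product per row.
row_products = np.prod(a, axis=1)   # values 6.0 and 120.0

# Cross-check, valid only for exactly three columns.
col_products = a[:, 0] * a[:, 1] * a[:, 2]

print(row_products.shape)           # (2,)
```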

Sanity check

assert np.array_equal(np.prod(a, axis=1), prod1(a))

Performance

17.6 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

(1000x speedup)
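A timing of this kind can be reproduced outside IPython with the standard timeit module. The following is a rough sketch, assuming the 333x3 random array from the question; the loop and repeat counts are arbitrary choices, and absolute numbers will differ from machine to machine.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
a = rng.uniform(0, 10, size=(333, 3))

def prod1(array):
    return [np.prod(row) for row in array]

loops = 200
# Take the best of three repeats to reduce noise from other processes.
t_list = min(timeit.repeat(lambda: prod1(a), number=loops, repeat=3)) / loops
t_axis = min(timeit.repeat(lambda: np.prod(a, axis=1), number=loops, repeat=3)) / loops

print(f"prod1:   {t_list * 1e6:.1f} µs per call")
print(f"np.prod: {t_axis * 1e6:.1f} µs per call")
```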



Source: https://stackoverflow.com/questions/49290059/efficiently-multiply-elements-of-each-row-together
