How do I pass in calculated values to a list sort using numba.jit in python?

Submitted by 本秂侑毒 on 2021-02-08 16:57:24

Question


I am trying to sort a list using a custom key within a numba-jit function in Python. Simple custom keys work; for example, I know that I can sort by absolute value with something like this:

import numba

@numba.jit(nopython=True)
def myfunc():
    mylist = [-4, 6, 2, 0, -1]
    mylist.sort(key=lambda x: abs(x))
    return mylist  # [0, -1, 2, -4, 6]

However, in the following more complicated example, I get an error that I do not understand.

import numba
import numpy as np


@numba.jit(nopython=True)
def dist_from_mean(val, mu):
    return abs(val - mu)

@numba.jit(nopython=True)
def func():
    l = [1,7,3,9,10,-4,-2,0]
    avg_val = np.array(l).mean()
    l.sort(key=lambda x: dist_from_mean(x, mu=avg_val))
    return l

The error that it is reporting is the following:

Traceback (most recent call last):
  File "testitout.py", line 18, in <module>
    ret = func()
  File "/.../python3.6/site-packages/numba/core/dispatcher.py", line 415, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/.../python3.6/site-packages/numba/core/dispatcher.py", line 358, in error_rewrite
    reraise(type(e), e, None)
  File "/.../python3.6/site-packages/numba/core/utils.py", line 80, in reraise
    raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: convert make_function into JIT functions)
Cannot capture the non-constant value associated with variable 'avg_val' in a function that will escape.

File "testitout.py", line 14:
def func():
    <source elided>
    l.sort(key=lambda x: dist_from_mean(x, mu=avg_val))
                                                ^

Do you know what is happening here?


Answer 1:


Do you know what is happening here?

By passing nopython=True you disable object mode, so Numba cannot fall back to handling values as generic Python objects (refer to: https://numba.pydata.org/numba-doc/latest/glossary.html#term-object-mode). (The reference is actually from another post I coincidentally wrote today: How call a `@guvectorize` inside a `@guvectorize` in numba?)

@numba.jit(nopython=True)
def func():
    l = [1,7,3,9,10,-4,-2,0]
    avg_val = np.array(l).mean()
    l.sort(key=lambda x: dist_from_mean(x, mu=avg_val))
    return l

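Anyhow, the lambda is "too" complex for a numba jit function, at least when it's passed as an argument (compare https://github.com/numba/numba/issues/4481). Specifically, your lambda is a closure over avg_val, a value that is only known at run time, and because it is handed to l.sort() the function "escapes"; Numba can only compile such an escaping function if the variables it captures are constants. Your first example works because lambda x: abs(x) captures nothing. More generally, with nopython mode activated you can only use a limited subset of Python and NumPy; the supported NumPy features are listed here: https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html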

That's why it throws the following error:

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: convert make_function into JIT functions)
Cannot capture the non-constant value associated with variable 'avg_val' in a function that will escape.

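Moreover, you're calling one jit-accelerated function from within another with nopython=True. That by itself is supported in Numba (a nopython function can call another @numba.jit(nopython=True) function directly), so it is not the source of this error; the captured avg_val is.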

I'd highly suggest taking a look at the following tutorial: http://numba.pydata.org/numba-doc/latest/user/5minguide.html#will-numba-work-for-my-code; it should help you out with similar problems!


Further reading and sources:

  • https://github.com/numba/numba/issues/5120
  • http://numba.pydata.org/numba-doc/latest/user/5minguide.html#will-numba-work-for-my-code
  • TypingError: Failed in nopython mode pipeline (step: nopython frontend)
  • https://github.com/numba/numba/issues/4481
  • https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html
  • How call a `@guvectorize` inside a `@guvectorize` in numba?


Source: https://stackoverflow.com/questions/63794635/how-do-i-pass-in-calculated-values-to-a-list-sort-using-numba-jit-in-python
