Using Dictionaries with numba njit function

Submitted by 痞子三分冷 on 2019-12-07 14:45:22

Question


How can I speed up a function with Numba when its inputs and return values are dictionaries?

I'm familiar with using numba for functions that accept numbers and return arrays, like this:

import numba
import numpy as np

@numba.jit('float64[:](int32,int32)', nopython=True)
def f(a, b):
    return np.zeros(a + b)  # returns a 1-D array (hypothetical body; the original omits it)

Now I have a function that accepts and returns dictionaries. How can I apply Numba here?

    def collocation(aeolus_data, val_data):
        ...

        return sample_aeolus, sample_valdata

Answer 1:


Support for dictionaries was added in Numba 0.43.0 through the typed dictionary numba.typed.Dict. It is still quite limited, though: lists and sets are not supported as keys or values. See the updated documentation listed in the references below for details. Here is an example:

import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict

# First create a dictionary using Dict.empty()
# Specify the data types for both key and value pairs

# Dict with string keys and float64 array values
dict_param1 = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64[:],
)

# Dict with string keys and float64 values
dict_param2 = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64,
)

# Type expressions such as types.float64[:] are not supported inside jit
# functions, so create the array type here at module level.
float_array = types.float64[:]

@njit
def add_values(d_param1, d_param2):
    # Create a dictionary for the results: string keys, float64 array values
    result_dict = Dict.empty(
        key_type=types.unicode_type,
        value_type=float_array,
    )

    for key in d_param1.keys():
        # Add the scalar from d_param2 to every element of the matching array
        result_dict[key] = d_param1[key] + d_param2[key]

    return result_dict

dict_param1["hello"] = np.asarray([1.5, 2.5, 3.5], dtype='f8')
dict_param1["world"] = np.asarray([10.5, 20.5, 30.5], dtype='f8')

dict_param2["hello"] = 1.5
dict_param2["world"] = 10

final_dict = add_values(dict_param1, dict_param2)

print(final_dict)
# Output : {hello: [3. 4. 5.], world: [20.5 30.5 40.5]}


References:
- https://github.com/numba/numba/issues/3644
- https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#dict



Source: https://stackoverflow.com/questions/55078628/using-dictionaries-with-numba-njit-function
