How to pass array pointer to Numba function?

长情又很酷 2020-12-21 21:01

I'd like to create a Numba-compiled function that takes a pointer or the memory address of an array as an argument and does calculations on it, e.g., modifies the underlying ...

1 answer
  • 2020-12-21 21:27

    Thanks to the great @stuartarchibald, I now have a working solution:

    import ctypes
    import numba as nb
    import numpy as np
    
    arr = np.arange(5).astype(np.double)  # create arbitrary numpy array
    print(arr)
    
    @nb.extending.intrinsic
    def address_as_void_pointer(typingctx, src):
        """ returns a void pointer from a given memory address """
        from numba.core import types, cgutils
        sig = types.voidptr(src)
    
        def codegen(cgctx, builder, sig, args):
            return builder.inttoptr(args[0], cgutils.voidptr_t)
        return sig, codegen
    
    addr = arr.ctypes.data
    
    @nb.njit
    def modify_data():
        """ modifies the array in place through its memory address """
        # addr and arr are module-level globals here, which Numba treats as compile-time constants
        data = nb.carray(address_as_void_pointer(addr), arr.shape, dtype=arr.dtype)
        data += 2
    
    modify_data()
    print(arr)
    

    The key is the new address_as_void_pointer function that turns a memory address (given as an int) into a pointer that is usable by numba.carray.
