JIT in JCuda, loading multiple ptx modules

依然范特西╮ submitted on 2019-11-28 12:40:05
talonmies

I had never written a single line of Java code until 30 minutes ago, let alone used JCuda, but an almost literal line-by-line translation of the native C++ code I gave you earlier seems to work perfectly:

import static jcuda.driver.JCudaDriver.*;
import jcuda.*;
import jcuda.driver.*;

public class JCudaRuntimeTest
{
    public static void main(String[] args)
    {
        // Throw a CudaException on any driver error instead of returning error codes
        JCudaDriver.setExceptionsEnabled(true);

        // Initialise the driver API and create a context on device 0
        cuInit(0);
        CUdevice device = new CUdevice();
        cuDeviceGet(device, 0);
        CUcontext context = new CUcontext();
        cuCtxCreate(context, 0, device);

        // Create a JIT link state (no extra JIT options)
        CUlinkState linkState = new CUlinkState();
        JITOptions jitOptions = new JITOptions();
        cuLinkCreate(jitOptions, linkState);

        String ptxFileName2 = "test_function.ptx";
        String ptxFileName1 = "test_kernel.ptx";

        // Add both PTX files to the same link
        cuLinkAddFile(linkState, CUjitInputType.CU_JIT_INPUT_PTX, ptxFileName2, jitOptions);
        cuLinkAddFile(linkState, CUjitInputType.CU_JIT_INPUT_PTX, ptxFileName1, jitOptions);

        // Finish linking: image receives a native pointer to the linked CUBIN,
        // sz[0] receives its size in bytes
        long[] sz = new long[1];
        Pointer image = new Pointer();
        cuLinkComplete(linkState, image, sz);
        System.out.println("Pointer: " + image);
        System.out.println("CUBIN size: " + sz[0]);

        // Load the linked image as a module. The image is owned by the link
        // state, so it must be loaded before cuLinkDestroy is called
        CUmodule module = new CUmodule();
        cuModuleLoadDataEx(module, image, 0, new int[0], Pointer.to(new int[0]));
        cuLinkDestroy(linkState);

        // Look up the kernel by its C++-mangled name
        CUfunction functionKernel = new CUfunction();
        String kernelname = "_Z6kernelPfS_S_S_";
        cuModuleGetFunction(functionKernel, module, kernelname);
        System.out.println("Function: " + functionKernel);

        // Release the module and the context
        cuModuleUnload(module);
        cuCtxDestroy(context);
    }
}

which works like this:

> nvcc -ptx -arch=sm_21 test_function.cu
test_function.cu

> nvcc -ptx -arch=sm_21 test_kernel.cu
test_kernel.cu

> javac -cp ".;jcuda-0.7.0a.jar" JCudaRuntimeTest.java
> java -cp ".;jcuda-0.7.0a.jar" JCudaRuntimeTest
Pointer: Pointer[nativePointer=0xa5a13a8,byteOffset=0]
CUBIN size: 5924
Function: CUfunction[nativePointer=0xa588160]

The key here seems to be to use cuModuleLoadDataEx, noting that cuLinkComplete returns two things through its output arguments: a native pointer to the linked CUBIN image (wrapped in a Pointer) and the size of that image in bytes (in a long[]). As in the C++ code, the pointer is simply passed directly to the module data load call.
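One detail in the listing that may look odd is the kernel name: it is looked up as `_Z6kernelPfS_S_S_` rather than `kernel` because a kernel compiled as C++ without `extern "C"` gets a mangled symbol name in the PTX. As a rough illustration only, the sketch below composes that `_Z...` (Itanium C++ ABI style) name for a function whose parameters are all `float*`. The class and method names are hypothetical and not part of JCuda, and the helper deliberately covers just that one parameter pattern; the authoritative name is always the `.entry` symbol in the generated `.ptx` file.

```java
// Hypothetical helper (not part of JCuda): builds the Itanium-ABI style
// mangled name for a free function whose parameters are all float*.
public class MangledName {

    public static String mangleAllFloatPtr(String name, int paramCount) {
        if (paramCount < 0) {
            throw new IllegalArgumentException("paramCount must be >= 0");
        }
        // "_Z" prefix, then <length><identifier> for the function name
        StringBuilder sb = new StringBuilder("_Z");
        sb.append(name.length()).append(name);
        if (paramCount == 0) {
            // An empty parameter list is encoded as 'v' (void)
            return sb.append('v').toString();
        }
        // The first float* is spelled out as "Pf" (pointer to float) ...
        sb.append("Pf");
        // ... and each repeat refers back to that first substitution as "S_"
        for (int i = 1; i < paramCount; i++) {
            sb.append("S_");
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // kernel(float*, float*, float*, float*)
        System.out.println(mangleAllFloatPtr("kernel", 4)); // prints _Z6kernelPfS_S_S_
    }
}
```

Declaring the kernel as `extern "C" __global__` instead would avoid the mangling, so it could be looked up simply as `"kernel"`.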

As a final comment, it would have been much simpler if you had posted a proper repro case that could have been hacked on directly, rather than making me learn the rudiments of JCuda and Java before I could build a useful repro case and get it working. The JCuda documentation is basic but complete, and with the working C++ example already provided, it took only a couple of minutes of reading to see how to do this.
