JIT in JCuda, loading multiple ptx modules

攒了一身酷 2020-12-11 07:33

I said in this question that I had some problems loading ptx modules in JCuda, and following @talonmies's idea, I implemented a JCuda version of his solution to load multiple ptx

1 answer
  • 2020-12-11 07:55

    I had never written a single line of Java code until 30 minutes ago, let alone used JCUDA before, but an almost literal line-by-line translation of the native C++ code I gave you here seems to work perfectly:

    import static jcuda.driver.JCudaDriver.*;
    import java.io.*;
    import jcuda.*;
    import jcuda.driver.*;
    
    public class JCudaRuntimeTest
    {
        public static void main(String args[])
        {
            JCudaDriver.setExceptionsEnabled(true);
    
            cuInit(0);
            CUdevice device = new CUdevice();
            cuDeviceGet(device, 0);
            CUcontext context = new CUcontext();
            cuCtxCreate(context, 0, device);
    
            // Create a JIT linker session with default options
            CUlinkState linkState = new CUlinkState();
            JITOptions jitOptions = new JITOptions();
            cuLinkCreate(jitOptions, linkState);
    
            String ptxFileName2 = "test_function.ptx";
            String ptxFileName1 = "test_kernel.ptx";
    
            // Add both PTX files to the same link session
            cuLinkAddFile(linkState, CUjitInputType.CU_JIT_INPUT_PTX, ptxFileName2, jitOptions);
            cuLinkAddFile(linkState, CUjitInputType.CU_JIT_INPUT_PTX, ptxFileName1, jitOptions);
    
            // Finish linking: image receives a pointer to the linked CUBIN, sz[0] its size in bytes
            long[] sz = new long[1];
            Pointer image = new Pointer();
            cuLinkComplete(linkState, image, sz);
            System.out.println("Pointer: " + image);
            System.out.println("CUBIN size: " + sz[0]);
    
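            // Load the module from the linked image before destroying the link state,
            // because the image memory is owned by the link state
            CUmodule module = new CUmodule();
            cuModuleLoadDataEx(module, image, 0, new int[0], Pointer.to(new int[0]));
            cuLinkDestroy(linkState);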
    
            CUfunction functionKernel = new CUfunction();
            String kernelname = "_Z6kernelPfS_S_S_";
            cuModuleGetFunction(functionKernel, module, kernelname);
            System.out.println("Function: " + functionKernel);
        }
    }
    

    which works like this:

    > nvcc -ptx -arch=sm_21 test_function.cu
    test_function.cu
    
    > nvcc -ptx -arch=sm_21 test_kernel.cu
    test_kernel.cu
    
    > javac -cp ".;jcuda-0.7.0a.jar" JCudaRuntimeTest.java
    > java -cp ".;jcuda-0.7.0a.jar" JCudaRuntimeTest
    Pointer: Pointer[nativePointer=0xa5a13a8,byteOffset=0]
    CUBIN size: 5924
    Function: CUfunction[nativePointer=0xa588160]
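
    The kernel name passed to cuModuleGetFunction in the code above is the C++-mangled symbol, which has to match the .entry name in the generated PTX. As a rough sketch for finding those names, the following scans PTX text for .entry declarations using only the Java standard library. The class name PtxEntryNames and the method entryNames are hypothetical helpers, not part of JCuda, and the regular expression is a simplification that assumes ordinary identifiers and does not skip comments.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical helper (not part of JCuda): lists kernel entry points found in PTX text.
class PtxEntryNames
{
    // Matches declarations such as ".visible .entry _Z6kernelPfS_S_S_("
    // Simplified: assumes plain identifiers and does not strip PTX comments.
    private static final Pattern ENTRY =
        Pattern.compile("\\.entry\\s+([A-Za-z_$%][A-Za-z0-9_$]*)");

    // Returns the names of all .entry symbols in the given PTX source, in order.
    static List<String> entryNames(String ptx)
    {
        List<String> names = new ArrayList<>();
        Matcher m = ENTRY.matcher(ptx);
        while (m.find())
        {
            names.add(m.group(1));
        }
        return names;
    }

    public static void main(String[] args) throws IOException
    {
        // Expects the path to a PTX file, e.g. test_kernel.ptx
        String ptx = new String(Files.readAllBytes(Paths.get(args[0])));
        for (String name : entryNames(ptx))
        {
            System.out.println(name);
        }
    }
}
```

    Any name it prints can be used as the kernelname string. Device functions declared with .func are deliberately not listed, since they cannot be fetched with cuModuleGetFunction.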
    

    The key here seems to be to use cuModuleLoadDataEx, noting that cuLinkComplete returns two things: a system pointer to the linked CUBIN, and the size of the image, which comes back in a long[]. As per the C++ code, the pointer is just passed directly to the module data load.
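
    The long[] is there because Java has no pointer-to-value arguments, so a C out-parameter such as a size_t* is mirrored by a one-element array that the callee writes into. A minimal illustration of that idiom, with no JCuda dependency, is below. The class OutParamDemo and the method fillSize are hypothetical stand-ins for a native call and are not JCuda functions.

```java
// Hypothetical stand-in (not a JCuda API) showing the one-element-array out-parameter idiom.
class OutParamDemo
{
    // Writes the image size into sizeOut[0] and returns a status code,
    // the same shape as a C function taking a size_t* out-parameter.
    static int fillSize(byte[] image, long[] sizeOut)
    {
        sizeOut[0] = image.length;
        return 0; // 0 plays the role of a success status here
    }

    public static void main(String[] args)
    {
        long[] sz = new long[1];                   // caller allocates the slot
        int status = fillSize(new byte[5924], sz); // callee fills sz[0]
        System.out.println("status=" + status + ", size=" + sz[0]);
    }
}
```

    The caller reads the result from sz[0] after the call, exactly as the test program does after cuLinkComplete.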

    As a final comment, it would have been much simpler and easier if you had posted a proper repro case that could have been directly hacked on, rather than making me learn the rudiments of JCUDA and Java before I could create a useful repro case and get it to work. The documentation for JCUDA is basic, but complete, and against the working C++ example already provided, it only took a couple of minutes of reading to see how to do this.
