- The user writes the computation.
- Use transforms to mark the function as running on the GPU.
- Optimize the module.
- Split a function into multiple sub-functions if needed, so that each sub-function is GPU computable.
- Run CodeGenGpu_Dev on the module to compile the device kernels with NVRTC.
- Run CodeGenLLVM_NVHost on the module to compile the host function with the LLVM JIT.
- Link the host code with the CodeGenGpu_Dev code to get an executable.
The basic example: elementwise_mul with dynamic shape M
Var M("M"); // dynamic dimension
Expr N(200);
Target target;
Placeholder<float> A("A", {M, N});
Placeholder<float> B("B", {M, N});
auto C = Compute(
    {M, N}, [&](Var i, Var j) { return A(i, j) * B(i, j); }, "C");
auto [M_outer, M_inner] = C->stage()->Split(0, 32); // M/32, 32
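To make the effect of the split concrete, here is a minimal host-side sketch of the loop nest that splitting axis 0 by 32 describes. It is plain C++, not CINN output. The function name `ElementwiseMulSplit`, the row-major layout, and the tail guard for an M that is not a multiple of 32 are assumptions made for illustration; the source does not say how CINN handles the tail.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side emulation of an M x N elementwise multiply whose first axis
// has been split by a factor of 32. M is only known at run time.
// A, B and C are assumed to be row-major buffers of M * N floats.
void ElementwiseMulSplit(const std::vector<float>& A, const std::vector<float>& B,
                         std::vector<float>& C, int M, int N) {
  const int kFactor = 32;  // split factor, as in Split(0, 32)
  assert(M >= 0 && N >= 0);
  assert(A.size() == static_cast<std::size_t>(M) * N);
  assert(B.size() == A.size() && C.size() == A.size());

  const int m_outer_extent = (M + kFactor - 1) / kFactor;  // ceil(M / 32)
  for (int m_outer = 0; m_outer < m_outer_extent; ++m_outer) {
    for (int m_inner = 0; m_inner < kFactor; ++m_inner) {
      const int i = m_outer * kFactor + m_inner;  // original row index
      if (i >= M) break;  // assumed guard for the partially filled last tile
      for (int j = 0; j < N; ++j) {
        C[i * N + j] = A[i * N + j] * B[i * N + j];
      }
    }
  }
}
```

On a GPU, the outer loop would typically map to blocks and the inner loop to threads, which is why the inner extent is a fixed constant while the outer extent depends on M.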
In CINN, a LoweredFunc might contain multiple root for-loops, for example:
function some_func(...) {
  for (int i = 0; i < 100; i++) k(i);
  for (int i = 0; i < 200; i++) k(i);
}
This function cannot be converted into a single CUDA kernel, so we need to split it into two kernels:
// two kernels
__global__ function kernel0(...) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // global thread index
  if (i < 100) k(i);
}
__global__ function kernel1(...) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // global thread index
  if (i < 200) k(i);
}
// host
function some_func(...) {
  kernel0<<<100, ...>>>(...);
  kernel1<<<200, ...>>>(...);
}
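The mapping from a root for-loop to a kernel launch can be sketched on the host. The code below is an illustration in plain C++, not CINN code: `LaunchDims`, `CoverExtent` and `EmulateLaunch` are hypothetical names, and the one-thread-per-iteration mapping with a block size chosen by the caller is an assumption. It uses the usual CUDA index computation, `blockIdx.x * blockDim.x + threadIdx.x`, plus a bounds guard.

```cpp
#include <cassert>
#include <vector>

// Hypothetical launch shape for a 1-D loop: one thread per iteration.
struct LaunchDims {
  int grid;   // number of blocks
  int block;  // threads per block
};

// Chooses enough blocks of `block_size` threads to cover `extent` iterations.
LaunchDims CoverExtent(int extent, int block_size) {
  assert(extent >= 0 && block_size > 0);
  return LaunchDims{(extent + block_size - 1) / block_size, block_size};
}

// Emulates a kernel launch on the host and returns the loop indices that
// pass the bounds guard, in block-major order.
std::vector<int> EmulateLaunch(LaunchDims dims, int extent) {
  std::vector<int> visited;
  for (int block_idx = 0; block_idx < dims.grid; ++block_idx) {
    for (int thread_idx = 0; thread_idx < dims.block; ++thread_idx) {
      const int i = block_idx * dims.block + thread_idx;  // global index
      if (i < extent) visited.push_back(i);  // guard for the last, partial block
    }
  }
  return visited;
}
```

With a block size of 32, the loop of extent 100 needs 4 blocks and the loop of extent 200 needs 7. Each root loop gets its own launch shape, which is one reason the two loops end up as separate kernels.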
CodeGenGpu_Dev uses NVRTC and the CUDA driver to compile the module into several runtime functions.
So that a LoweredFunc can launch a kernel, cinn_runtime needs to provide several helper functions:
void CudaRunKernel(const char* kernel_name, dim3 grid_dims, dim3 block_dims, void** args, CUstream stream);
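The name-based dispatch behind a helper like CudaRunKernel can be illustrated with a host-only mock. This is a sketch under stated assumptions, not the cinn_runtime implementation: `Dim3`, `MockKernel`, `KernelRegistry`, `MockRunKernel` and `MulKernel` are all hypothetical, and the stream parameter is left out. A real helper would presumably look up the kernel with the CUDA driver API (`cuModuleGetFunction`) and launch it with `cuLaunchKernel`, whose parameter array holds one pointer per kernel argument. The mock follows that same convention.

```cpp
#include <cassert>
#include <map>
#include <string>

// Host-only stand-in for CUDA's dim3.
struct Dim3 {
  unsigned x = 1, y = 1, z = 1;
};

// A mock "kernel" is an ordinary host function that receives the launch
// shape and the argument array.
using MockKernel = void (*)(Dim3 grid, Dim3 block, void** args);

// Registry that would be filled when the compiled module is loaded.
std::map<std::string, MockKernel>& KernelRegistry() {
  static std::map<std::string, MockKernel> registry;
  return registry;
}

// Mock counterpart of CudaRunKernel: resolve the kernel by name, then run it.
// Returns false if no kernel with that name is registered.
bool MockRunKernel(const char* kernel_name, Dim3 grid, Dim3 block, void** args) {
  assert(kernel_name != nullptr);
  auto it = KernelRegistry().find(kernel_name);
  if (it == KernelRegistry().end()) return false;
  it->second(grid, block, args);
  return true;
}

// Example kernel: C[i] = A[i] * B[i], emulating one thread per element.
void MulKernel(Dim3 grid, Dim3 block, void** args) {
  // Each entry of args points at one argument, so dereference to get its value.
  const float* A = *static_cast<const float**>(args[0]);
  const float* B = *static_cast<const float**>(args[1]);
  float* C = *static_cast<float**>(args[2]);
  const unsigned n = static_cast<unsigned>(*static_cast<int*>(args[3]));
  for (unsigned b = 0; b < grid.x; ++b) {
    for (unsigned t = 0; t < block.x; ++t) {
      const unsigned i = b * block.x + t;  // global index
      if (i < n) C[i] = A[i] * B[i];       // bounds guard
    }
  }
}
```

A caller would register the kernel once, for example `KernelRegistry()["mul"] = &MulKernel;`, build `void* args[] = {&A, &B, &C, &n};`, and then call `MockRunKernel("mul", grid, block, args)`. Passing addresses of the arguments rather than the values is what lets one `void**` carry arguments of different types.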
The host function some_func looks like this in C:
void some_func(cinn_pod_value_t* args, int num_args) {
  cinn_buffer_t* _A = args[0];
  cinn_buffer_t* _B = args[1];
  cinn_buffer_t* _C = args[2];
  const float* A = _A->host_memory;
  const float* B = _B->host_memory;
  float* C = _C->host_memory;
  dim3 grid_dims1(20, 1, 1);
  dim3 block_dims1(20, 1, 1);
  dim3 grid_dims2(20, 1, 1);
  dim3 block_dims2(20, 1, 1);
  // Assumed convention (as in cuLaunchKernel): one pointer per kernel argument.
  void* kernel_args[] = {&A, &B, &C};
  CudaRunKernel("kernel1", grid_dims1, block_dims1, kernel_args, NULL);
  CudaRunKernel("kernel2", grid_dims2, block_dims2, kernel_args, NULL);
}
In the code above, some_func runs on the host and is called by the JIT.