# Development Guide for JIT Kernels

## Environment Setup

We strongly recommend using `clangd` as the language server for JIT kernel development.
On Ubuntu/Debian, you can install clangd from [apt.llvm.org](https://apt.llvm.org/).
If you use VS Code, install the `clangd` extension for IDE integration.

All JIT-related files are located in `python/sglang/jit_kernel`.
Unlike `sgl-kernel`, which compiles CUDA/C++ binaries ahead of time (AOT), just-in-time (JIT) kernels are compiled at runtime.
As a result, a static `compile_commands.json` cannot be generated.
To enable code completion with `clangd`, run `python -m sglang.jit_kernel` to generate a `.clangd` configuration file in your current directory.
After generating the file, restart the clangd language server; it should then recognize all JIT kernel files.

## Code Structure

### C++ Implementation

C++ source code is located in `python/sglang/jit_kernel/csrc`.
Reusable functions should be placed in `python/sglang/jit_kernel/include`.

We use [tvm-ffi](https://github.com/apache/tvm-ffi) for efficient foreign language bindings.
Refer to the [documentation](https://tvm.apache.org/ffi/) for advanced usage, such as exporting C++ objects.
In most cases, `tvm::ffi::TensorView` is sufficient for passing PyTorch tensors from Python.

### Python Interface

Python interfaces are defined in `python/sglang/jit_kernel`.
The `load_jit` utility function in `python/sglang/jit_kernel/utils.py` loads and returns the compiled module.
To export a C++ function (e.g., `cpp_func`), pass `cuda_wrappers=[("func", "cpp_func")]` to `load_jit`.
The function can then be called from Python as `module.func`.

To cache compiled modules, use `sglang.jit_kernel.utils.cache_once` rather than `functools.lru_cache`, because `functools.lru_cache` is not compatible with `torch.compile`.

### C++ Utilities

The following C++ utilities are available:

#### Integer Range

Similar to PyTorch's `c10::irange`, `host::irange` represents a half-open integer range.

```cpp
#include <sgl_kernel/utils.h>

void test() {
  for (auto i : host::irange(100)) {  // [0, 100)
    // do something
  }
  for (auto i : host::irange(0, 100)) {  // [0, 100)
    // do something
  }
}
```

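The following is a minimal, self-contained sketch of how an `irange`-style helper plugs into range-based `for` loops. It uses the hypothetical names `IntRange` and `int_range` and is not the `host::irange` implementation. It only assumes the half-open `[begin, end)` semantics shown above, and it treats an inverted range (`end < begin`) as empty.

```cpp
#include <cstdint>

// Hypothetical stand-in for an irange-style helper; sgl_kernel's host::irange
// may be implemented differently. Iterates over the half-open range [begin, end).
template <typename T>
class IntRange {
 public:
  class Iterator {
   public:
    constexpr explicit Iterator(T value) : value_(value) {}
    constexpr T operator*() const { return value_; }
    constexpr Iterator& operator++() {
      ++value_;
      return *this;
    }
    constexpr bool operator!=(const Iterator& other) const { return value_ != other.value_; }

   private:
    T value_;
  };

  // An inverted range (end < begin) is clamped to an empty range.
  constexpr IntRange(T begin, T end) : begin_(begin), end_(end < begin ? begin : end) {}
  constexpr Iterator begin() const { return Iterator(begin_); }
  constexpr Iterator end() const { return Iterator(end_); }

 private:
  T begin_;
  T end_;
};

template <typename T>
constexpr IntRange<T> int_range(T end) {  // [0, end)
  return IntRange<T>(T{0}, end);
}

template <typename T>
constexpr IntRange<T> int_range(T begin, T end) {  // [begin, end)
  return IntRange<T>(begin, end);
}

// Usage: 0 + 1 + ... + (n - 1), evaluated through a range-based for loop.
constexpr int64_t sum_below(int64_t n) {
  int64_t total = 0;
  for (auto i : int_range(n)) total += i;
  return total;
}
```

Because the iterator only stores an integer, the loop compiles down to an ordinary counted loop. Making the members `constexpr` also allows the range to be used in compile-time code.
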
#### Runtime Checking

`RuntimeCheck` validates a condition at runtime. It accepts optional extra arguments for error reporting.
If the check fails, these arguments are included in the error output to aid debugging.
`RuntimeDeviceCheck` verifies the status of the last kernel launch, or checks a `cudaError_t` that you pass to it.

```cpp
#include <sgl_kernel/utils.h>    // host::RuntimeCheck
#include <sgl_kernel/utils.cuh>  // host::RuntimeDeviceCheck

void test() {
  // Passes here; if the condition were false, `1 + 1`, " != " and `2` would be reported.
  host::RuntimeCheck(1 + 1 == 2, 1 + 1, " != ", 2);
  // check the status of the last kernel launch
  host::RuntimeDeviceCheck();
  // check the provided `cudaError_t`
  host::RuntimeDeviceCheck(cudaGetLastError());
}
```

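This guide does not specify the exact failure behavior of `host::RuntimeCheck` (exception type, message format, source location). The following is therefore only a sketch of the general technique, using a hypothetical `my_runtime_check`. It is a variadic template that formats its trailing arguments with a C++17 fold expression and throws `std::runtime_error` when the condition is false.

```cpp
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical RuntimeCheck-style helper (not the sgl_kernel implementation).
// The extra arguments are streamed into the message only when the check fails.
template <typename... Args>
void my_runtime_check(bool condition, const Args&... args) {
  if (condition) return;
  std::ostringstream message;
  message << "Runtime check failed";
  if constexpr (sizeof...(Args) > 0) {
    message << ": ";
    (message << ... << args);  // C++17 fold: message << arg0 << arg1 << ...
  }
  throw std::runtime_error(message.str());
}
```

The arguments are still evaluated at the call site. Only the string formatting is deferred to the failure path, which keeps passing checks cheap.
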
#### Tensor Checking

`TensorMatcher` provides a readable way to validate tensors and extract their shape information.

```cpp
#include <sgl_kernel/tensor.h>

#include <dlpack/dlpack.h>
#include <tvm/ffi/container/tensor.h>

#include <cstdint>

void test(const tvm::ffi::TensorView k_cache, const tvm::ffi::TensorView v_cache) {
  using namespace host;

  auto D = SymbolicSize{"D"};  // cache dimension
  auto N = SymbolicSize{"N"};  // row stride, shared by k_cache and v_cache
  auto dtype = SymbolicDType{};
  auto device = SymbolicDevice{};

  TensorMatcher({-1, D})  //
      .with_strides({N, 1})
      .with_dtype<int32_t, int64_t>(dtype)
      .with_device<kDLCUDA, kDLCPU>(device)
      .verify(k_cache)
      .verify(v_cache);
}
```

Configure the `TensorMatcher` with the expected stride, dtype, and device properties before verification:

- If `with_strides` is omitted, the tensor is expected to be contiguous.
- Template arguments in `with_dtype` restrict the allowed data types.
- Template arguments in `with_device` restrict the allowed devices.
- Values passed to `with_xxx` methods are checked for equality.
- Passing `-1` for a size or stride matches any value.

A symbolic variable (`SymbolicSize`, `SymbolicDType`, or `SymbolicDevice`) must resolve to the same value across all verifications.
Call `.unwrap()` on it to retrieve the matched value after verification.

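The binding rule can be illustrated with a deliberately simplified, host-only model of symbolic sizes. The names `SymSize`, `Dim`, `kAny`, `verify_shape`, and `shared_dim` are hypothetical and are not the `sgl_kernel` API. The real `TensorMatcher` also checks strides, dtypes, and devices. In this model, a symbol binds to the first value it sees, every later verification must agree with it, and `-1` matches anything.

```cpp
#include <cstddef>
#include <cstdint>

constexpr int64_t kAny = -1;  // wildcard, like -1 in TensorMatcher({-1, D})

// Simplified model of a symbolic size (hypothetical; NOT sgl_kernel's SymbolicSize).
class SymSize {
 public:
  // Binds on first use; afterwards reports whether `actual` equals the bound value.
  constexpr bool match(int64_t actual) {
    if (!bound_) {
      value_ = actual;
      bound_ = true;
      return true;
    }
    return value_ == actual;
  }
  constexpr bool bound() const { return bound_; }
  constexpr int64_t unwrap() const { return value_; }  // only meaningful if bound()

 private:
  int64_t value_ = 0;
  bool bound_ = false;
};

// One pattern entry: a symbol (if sym != nullptr), otherwise a fixed size or kAny.
struct Dim {
  int64_t fixed;
  SymSize* sym;
};

template <std::size_t Rank>
constexpr bool verify_shape(const Dim (&pattern)[Rank], const int64_t (&shape)[Rank]) {
  for (std::size_t i = 0; i < Rank; ++i) {
    if (pattern[i].sym != nullptr) {
      if (!pattern[i].sym->match(shape[i])) return false;
    } else if (pattern[i].fixed != kAny && pattern[i].fixed != shape[i]) {
      return false;
    }
  }
  return true;
}

// Usage: k and v may have different row counts (kAny), but must share D.
constexpr int64_t shared_dim(int64_t v_cols) {
  SymSize D{};
  const Dim pattern[2] = {{kAny, nullptr}, {0, &D}};
  const int64_t k_shape[2] = {16, 128};
  const int64_t v_shape[2] = {32, v_cols};
  if (!verify_shape(pattern, k_shape) || !verify_shape(pattern, v_shape)) {
    return -1;  // mismatch
  }
  return D.unwrap();
}
```

In this model, `k_shape` binds `D` to 128. A `v_shape` of `{32, 128}` is then accepted because the wildcard row count may differ. A `v_shape` of `{32, 64}` is rejected.
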
> Note: `TensorMatcher` is a temporary expression and should not be stored in a variable.

> Tip: Add an empty `//` comment after the first line of a `TensorMatcher` chain, as in the example above, so that clang-format keeps one method call per line.

#### Kernel Launching

`LaunchKernel` launches a kernel given a grid size, a block size, and either a `DLDevice` or an explicit `cudaStream_t`. The dynamic shared memory size can be passed as a fourth argument.
`LaunchKernel::resolve_device` retrieves the current CUDA stream from PyTorch for a given device, which is useful when you need the `cudaStream_t` explicitly.

```cpp
#include <sgl_kernel/utils.cuh>

#include <dlpack/dlpack.h>

__global__ void kernel() {}

void test() {
  const auto num_blocks = 1;
  const auto num_threads = 32;
  const auto dynamic_smem = 0;

  DLDevice dev{kDLCUDA, 0};  // CUDA device 0; in practice, use the device of your input tensors
  host::LaunchKernel(num_blocks, num_threads, dev)(kernel);

  cudaStream_t stream = host::LaunchKernel::resolve_device(dev);
  host::LaunchKernel(num_blocks, num_threads, stream, dynamic_smem)(kernel);
}
```

## Adding New Kernels

This section walks through a complete, end-to-end example of adding a new JIT kernel.
We use a simple `add_constant` kernel as a running example; it adds a constant integer value to every element of an input tensor.

Conceptually, the Python interface looks like this:

```python
import torch


def add_constant(src: torch.Tensor, c: int) -> torch.Tensor:
    return src + c
```

### Step 1: Write the C++ Kernel

Write your CUDA kernel in [jit_kernel/csrc/add_constant.cuh](../../python/sglang/jit_kernel/csrc/add_constant.cuh). For demonstration purposes, we pass the constant value as a template parameter, so each distinct constant is compiled into its own kernel instantiation.

```cpp
#include <sgl_kernel/tensor.h>   // For TensorMatcher, SymbolicSize, SymbolicDevice
#include <sgl_kernel/utils.cuh>  // For LaunchKernel
#include <sgl_kernel/utils.h>    // For div_ceil, RuntimeCheck

#include <dlpack/dlpack.h>
#include <tvm/ffi/container/tensor.h>

#include <cstddef>
#include <cstdint>

namespace {

template <int32_t kConstant>
__global__ void add_constant_kernel(int32_t* dst, const int32_t* src, size_t length) {
  // Widen before multiplying so the index cannot overflow 32 bits on large inputs.
  size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < length) {
    dst[idx] = src[idx] + kConstant;
  }
}

constexpr size_t kBlockSize = 256;

// You can also use a struct with a static method as an alternative
template <int32_t kConstant>
void add_constant(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) {
  using namespace host;

  // 1. Validate input tensors
  SymbolicSize N = {"num_elements"};
  SymbolicDevice device_;
  TensorMatcher({N})                  // 1D tensor, must be contiguous
      .with_dtype<int32_t>()          // must be int32
      .with_device<kDLCUDA>(device_)  // must be on a CUDA device
      .verify(dst)                    // check tensor dst
      .verify(src);                   // check tensor src

  // 2. Extract required parameters, prepare for kernel launch
  const size_t num_elements = N.unwrap();
  const size_t grid_size = div_ceil(num_elements, kBlockSize);
  const DLDevice device = device_.unwrap();
  // some extra runtime checks using host::RuntimeCheck
  RuntimeCheck(num_elements > 0, "We only support non-empty tensors, got num_elements = ", num_elements);

  // 3. Launch the kernel. The error code is checked automatically.
  LaunchKernel(grid_size, kBlockSize, device /*, dynamic_smem*/)(
      // kernel function
      add_constant_kernel<kConstant>,
      // kernel arguments
      static_cast<int32_t*>(dst.data_ptr()),
      static_cast<const int32_t*>(src.data_ptr()),
      num_elements);
}

}  // namespace
```

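When writing tests, it can help to have a CPU reference that mirrors the kernel's indexing. The following sketch is not part of sglang: `div_ceil_ref`, `kRefBlockSize`, `add_constant_reference`, and `reference_adds_constant` are hypothetical names, and `host::div_ceil` is only assumed to compute the same rounded-up quotient. The sketch replays the launch grid on the host. It also shows why the `idx < length` guard is needed: `grid_size` is rounded up to whole blocks, so the last block can contain threads past the end of the tensor.

```cpp
#include <cstddef>
#include <cstdint>

// Rounding-up division, as used for grid_size (hypothetical local copy for illustration).
constexpr std::size_t div_ceil_ref(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

constexpr std::size_t kRefBlockSize = 256;  // mirrors kBlockSize in add_constant.cuh

// CPU emulation of add_constant_kernel<kConstant>: visit every (block, thread)
// pair of the launch grid and apply the same bounds check as the kernel.
template <int32_t kConstant>
constexpr void add_constant_reference(int32_t* dst, const int32_t* src, std::size_t length) {
  const std::size_t grid_size = div_ceil_ref(length, kRefBlockSize);
  for (std::size_t block = 0; block < grid_size; ++block) {
    for (std::size_t thread = 0; thread < kRefBlockSize; ++thread) {
      const std::size_t idx = block * kRefBlockSize + thread;
      if (idx < length) {  // threads past the end of the last block do nothing
        dst[idx] = src[idx] + kConstant;
      }
    }
  }
}

// Usage: a 300-element input needs 2 blocks (512 threads); 212 of them stay idle.
constexpr bool reference_adds_constant() {
  int32_t src[300] = {};
  int32_t dst[300] = {};
  for (std::size_t i = 0; i < 300; ++i) src[i] = static_cast<int32_t>(i);
  add_constant_reference<7>(dst, src, 300);
  for (std::size_t i = 0; i < 300; ++i) {
    if (dst[i] != src[i] + 7) return false;
  }
  return true;
}
```
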
### Step 2: Create the Python Interface

Next, expose the kernel through a Python wrapper.
Create a new file at [jit_kernel/add_constant.py](../../python/sglang/jit_kernel/add_constant.py) and define the required interfaces.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args

if TYPE_CHECKING:
    from tvm_ffi.module import Module


@cache_once
def _jit_add_constant_module(constant: int) -> Module:
    args = make_cpp_args(constant)  # pass all the template arguments
    return load_jit(
        "add_constant",
        *args,
        cuda_files=["add_constant.cuh"],
        cuda_wrappers=[("add_constant", f"add_constant<{args}>")],
    )


def add_constant(src: torch.Tensor, constant: int) -> torch.Tensor:
    if not src.is_cuda:
        raise RuntimeError("src must be a CUDA tensor")
    if src.dtype != torch.int32:
        raise RuntimeError(f"Unsupported dtype {src.dtype}. Supported: int32")
    dst = torch.empty_like(src)
    module = _jit_add_constant_module(constant)
    module.add_constant(dst, src)
    return dst
```

Keep the Python wrapper thin, but still validate basic invariants such as device and dtype before dispatch. In the current JIT/FFI path, invalid tensors are not always rejected safely before launch.

### Step 3: Use Your Kernel

Finally, import the kernel and call it like a regular Python function:

```python
from sglang.jit_kernel.add_constant import add_constant
```

For a complete, runnable example, refer to [test_add_constant.py](../../python/sglang/jit_kernel/tests/test_add_constant.py).

## C++ Include Library Reference

The JIT kernel framework provides a set of reusable C++ headers in `python/sglang/jit_kernel/include/sgl_kernel/`.
Each header is designed to be lightweight and self-contained.
The tables below summarize each header and its key APIs.

### Core Utilities

| Header | Namespace | Purpose |
|--------|-----------|---------|
| `utils.h` | `host` | Host-side essentials: `RuntimeCheck`, `Panic`, `div_ceil`, `irange` |
| `utils.cuh` | `device` / `host` | Type aliases (`fp16_t`, `bf16_t`, ...), `SGL_DEVICE` macro, PDL (programmatic dependent launch) helpers, `LaunchKernel`, `RuntimeDeviceCheck` |
| `source_location.h` | (global) | Portable `std::source_location` wrapper for error reporting |
| `runtime.cuh` | `host::runtime` | CUDA runtime queries: `get_blocks_per_sm`, `get_sm_count`, `get_cc_major`, `get_runtime_version`, `get_available_dynamic_smem_per_block` |

### Tensor Validation

| Header | Namespace | Purpose |
|--------|-----------|---------|
| `tensor.h` | `host` | `TensorMatcher`, `SymbolicSize`, `SymbolicDType`, `SymbolicDevice` |

### Math & Type System

| Header | Namespace | Purpose |
|--------|-----------|---------|
| `math.cuh` | `device::math` | `max`, `min`, `abs`, `sqrt`, `rsqrt`, `exp`, `sin`, `cos`, constants |
| `type.cuh` | (global) / `device` | `dtype_trait<T>`, `packed_t<T>`, `device::cast<To>(from)` |

### Memory Access

| Header | Namespace | Purpose |
|--------|-----------|---------|
| `vec.cuh` | `device` | `AlignedVector<T, N>`: vectorized load/store (up to 128-bit; 256-bit requires Blackwell GPUs) |
| `tile.cuh` | `device::tile` | `Memory<T>`: cooperative tiled memory I/O (thread/warp/CTA) |

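A 128-bit access moves 16 bytes at once, which is 4 `float`s or 8 16-bit values. The following sketch illustrates the idea behind an aligned vector type using the hypothetical names `MyAlignedVector`, `kElemsPer128Bit`, and `is_vector_aligned`, not the header's actual definitions. Aligning the wrapper to its full size lets the compiler emit a single wide load or store instead of several scalar ones, provided the underlying pointer is aligned as well.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for a vectorized-access type like AlignedVector<T, N>.
// sizeof(T) * N must be a power of two for alignas to accept it.
template <typename T, std::size_t N>
struct alignas(sizeof(T) * N) MyAlignedVector {
  T data[N];
};

// Number of elements that fit in one 128-bit (16-byte) access.
template <typename T>
constexpr std::size_t kElemsPer128Bit = 16 / sizeof(T);

// A raw pointer may only be reinterpreted as MyAlignedVector<T, N>* if it is
// aligned to the vector's alignment.
template <typename T, std::size_t N>
bool is_vector_aligned(const T* ptr) {
  return reinterpret_cast<std::uintptr_t>(ptr) % alignof(MyAlignedVector<T, N>) == 0;
}
```

Tensors whose base pointer or length does not fit this alignment typically need a scalar fallback for the unaligned head or tail.
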
### Parallel Primitives

| Header | Namespace | Purpose |
|--------|-----------|---------|
| `warp.cuh` | `device::warp` | `reduce_sum`, `reduce_max` via `__shfl_xor_sync` |
| `cta.cuh` | `device::cta` | `reduce_max` across warps via shared memory |
| `atomic.cuh` | `device::atomic` | `max`: atomic float max (CUDA + ROCm fallback) |

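The XOR-shuffle (butterfly) pattern behind warp reductions can be emulated on the host to show why it works. In each of the log2(32) = 5 steps, every lane adds the value held by lane `lane ^ offset`, so after the last step every lane holds the full sum. This is a CPU model with the hypothetical names `warp_reduce_sum_emulated` and `iota_lanes`, not the device code in `warp.cuh`.

```cpp
#include <array>
#include <cstddef>

constexpr std::size_t kWarpSize = 32;

// Host-side model of a butterfly (XOR-shuffle) warp reduction. On the GPU, each
// step would be `v += __shfl_xor_sync(0xffffffff, v, offset);` in every lane.
template <typename T>
constexpr std::array<T, kWarpSize> warp_reduce_sum_emulated(std::array<T, kWarpSize> lanes) {
  for (std::size_t offset = kWarpSize / 2; offset > 0; offset /= 2) {
    std::array<T, kWarpSize> next{};
    for (std::size_t lane = 0; lane < kWarpSize; ++lane) {
      // All lanes exchange at once, so read only from the previous step's values.
      next[lane] = lanes[lane] + lanes[lane ^ offset];
    }
    for (std::size_t lane = 0; lane < kWarpSize; ++lane) lanes[lane] = next[lane];
  }
  return lanes;  // every lane now holds the sum over all 32 lanes
}

// Usage: lane i contributes i, so every lane should end with 0 + 1 + ... + 31 = 496.
constexpr std::array<int, kWarpSize> iota_lanes() {
  std::array<int, kWarpSize> lanes{};
  for (std::size_t i = 0; i < kWarpSize; ++i) lanes[i] = static_cast<int>(i);
  return lanes;
}
```

Replacing `+` with a max operation gives the corresponding max reduction. Because every lane ends up with the result, no extra broadcast is needed afterwards.
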
### Reusable Kernel Templates

| Header | Namespace | Purpose |
|--------|-----------|---------|
| `impl/norm.cuh` | `host::norm` / `device::norm` | RMSNorm building blocks (warp & CTA paths, `StorageType`) |