{
"cells": [
{
"cell_type": "markdown",
"id": "578f2730",
"metadata": {},
"source": [
"# Adding a kernel to the CUTLASS API\n",
"The CUTLASS API is designed to make it easy for users to add their own kernel\n",
"so that it can be discovered and run under the uniform API. We welcome contributions\n",
"toward the API by \"bringing your own kernel.\"\n",
"\n",
"This example shows how to add a CuTe DSL kernel to the CUTLASS API.\n",
"\n",
"## Bring your own implementation\n",
"Individuals wishing to add a CuTe DSL kernel to the CUTLASS API likely already\n",
"have the kernel written in CuTe DSL, but have not yet implemented the API's needed\n",
"interface. Within the API, we separate these components into the \"implementation\" --\n",
"the kernel written in CuTe DSL -- and the \"interface\" -- the definition of methods\n",
"a kernel needs to be used within the CUTLASS API.\n",
"\n",
"For example, consider the following implementation of a simple FP64 GEMM kernel implementation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a64b0be",
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"\n",
"import cuda.bindings.driver as cuda\n",
"\n",
"import cutlass\n",
"import cutlass.cute as cute\n",
"\n",
"\n",
"class F64GemmKernelImplementation:\n",
" def __init__(self, cta_tile_shape_mn: tuple[int, int]):\n",
" self.cta_tile_shape_mn = cta_tile_shape_mn\n",
"\n",
" @cute.jit\n",
" def __call__(\n",
" self, a: cute.Tensor, b: cute.Tensor, out: cute.Tensor, stream: cuda.CUstream\n",
" ):\n",
" l, m, n = out.shape\n",
" m_tiles = (m + self.cta_tile_shape_mn[0] - 1) // self.cta_tile_shape_mn[0]\n",
" n_tiles = (n + self.cta_tile_shape_mn[1] - 1) // self.cta_tile_shape_mn[1]\n",
"\n",
" grid = (m_tiles, n_tiles, l)\n",
" block = [self.cta_tile_shape_mn[0], self.cta_tile_shape_mn[1], 1]\n",
" self.kernel(a, b, out).launch(grid=grid, block=block, stream=stream)\n",
"\n",
" @cute.kernel\n",
" def kernel(self, a: cute.Tensor, b: cute.Tensor, out: cute.Tensor):\n",
" l, m, n = out.shape\n",
" k = a.shape[-1]\n",
" m_tile, n_tile, l_idx = cute.arch.block_idx()\n",
" tidx, tidy, _ = cute.arch.thread_idx()\n",
"\n",
" m_idx = m_tile * self.cta_tile_shape_mn[0] + tidx\n",
" n_idx = n_tile * self.cta_tile_shape_mn[1] + tidy\n",
"\n",
" if m_idx < m and n_idx < n:\n",
" out[l_idx, m_idx, n_idx] = cutlass.Float64(0)\n",
" for k_idx in range(k):\n",
" out[l_idx, m_idx, n_idx] += (\n",
" a[l_idx, m_idx, k_idx] * b[l_idx, k_idx, n_idx]\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "36a08d4b",
"metadata": {},
"source": [
"The implementation is configurable via a `cta_tile_shape_mn` argument, which\n",
"controls the size of blocks and tiles in the M and N modes. A simple `cute.jit` function\n",
"computes the grid and block size for the input problem based on `cta_tile_shape_mn`,\n",
"and launches the kernel. The `cute.kernel` itself simply has each thread compute a single\n",
"output element of the matrix by taking a dot product.\n",
"\n",
"This implementation is not performant, but is kept simple for illustrative purposes."
]
},
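{
"cell_type": "markdown",
"id": "9b3f1c2a",
"metadata": {},
"source": [
"Before wiring this implementation into the CUTLASS API, we can exercise it on its own.\n",
"The cell below is a minimal sketch, not part of the API: it assumes that PyTorch tensors\n",
"can be handed to the `cute.jit` entry point via `cutlass.cute.runtime.from_dlpack` and\n",
"that the current PyTorch CUDA stream handle can be wrapped in a `cuda.CUstream`; the\n",
"exact conversion utilities may differ across CuTe DSL versions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d7e8f90",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of invoking the implementation directly (see assumptions above).\n",
"import torch\n",
"\n",
"from cutlass.cute.runtime import from_dlpack\n",
"\n",
"a_t = torch.randn(1, 64, 32, device=\"cuda\", dtype=torch.float64)\n",
"b_t = torch.randn(1, 32, 64, device=\"cuda\", dtype=torch.float64)\n",
"out_t = torch.zeros(1, 64, 64, device=\"cuda\", dtype=torch.float64)\n",
"\n",
"impl = F64GemmKernelImplementation(cta_tile_shape_mn=(16, 16))\n",
"torch_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)\n",
"\n",
"# Calling the @cute.jit entry point directly JIT-compiles and launches the kernel.\n",
"impl(from_dlpack(a_t), from_dlpack(b_t), from_dlpack(out_t), torch_stream)\n",
"torch.testing.assert_close(out_t, a_t @ b_t)"
]
},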
{
"cell_type": "markdown",
"id": "a5d0e661",
"metadata": {},
"source": [
"## Defining interface methods\n",
"As it currently stands, this GEMM kernel implementation cannot be used via the\n",
"CUTLASS API because it does not implement interface methods. Specifically, kernels\n",
"within the CUTLASS API must inherit from and implement the `cutlass_api.Kernel`\n",
"abstract class. This class has methods needed for many common operations\n",
"performed when compiling and executing DSL kernels.\n",
"\n",
"Certain providers (i.e., DSLs), such as CuTe DSL, provide an additional layer atop the\n",
"`cutlass_api.Kernel` class to add utilities for kernels being written\n",
"via that provider. For example, the CuTe DSL provider in the CUTLASS API\n",
"defines `cutlass_api.providers.cutedsl.kernel.CuteDslKernel`, which adds utilities surrounding\n",
"`cute.compile()` to add compile-time arguments needed for using TVM-FFI when\n",
"it is enabled.\n",
"\n",
"We will next walk through the steps in defining interface methods for this\n",
"implementation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2da869",
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"\n",
"import cutlass_api\n",
"from cutlass_api.arguments import GemmArguments\n",
"from cutlass_api.metadata import KernelMetadata\n",
"from cutlass_api.status import Status"
]
},
{
"cell_type": "markdown",
"id": "86ae75cc",
"metadata": {},
"source": [
"We begin by defining a class to represent the kernel's interface.\n",
"As mentioned above, since this is a CuTe DSL kernel, our interface must\n",
"inherit from and implement `cutlass_api.providers.cutedsl.kernel.CuteDslKernel`.\n",
"\n",
"The class must additionally be registered with the CuTe DSL provider\n",
"via the `@CuTeDSLProvider.register` decorator so that the class\n",
"can be considered when discovering kernels."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a86d138",
"metadata": {},
"outputs": [],
"source": [
"@cutlass_api.providers.cutedsl.CuTeDSLProvider.register\n",
"class F64GemmKernel(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):\n",
" # Empty versions of interface methods. These will be implemented later, interspersed\n",
" # with notebook markdown. Normally, one would define them inline with the class definition.\n",
" def __init__(self, metadata: KernelMetadata): pass\n",
"\n",
" def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None): pass\n",
"\n",
" def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact: pass\n",
"\n",
" @staticmethod\n",
" def generate_kernels(metadata_filter, epilogue_args=None, cc=None) -> list[\"F64GemmKernel\"]: pass\n",
"\n",
" def _supports(self, args: GemmArguments) -> Status: pass\n",
"\n",
" def get_workspace_size(self, args: GemmArguments) -> int: pass"
]
},
{
"cell_type": "markdown",
"id": "327e9e7c",
"metadata": {},
"source": [
"The `__init__` method of the class takes in a `KernelMetadata` object\n",
"from which it extracts the `cta_tile_shape_mn`. This is used to construct\n",
"the kernel implementation object. We will discuss later how the `KernelMetadata`\n",
"object passed in here is constructed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "785d1882",
"metadata": {},
"outputs": [],
"source": [
"def __init__(self, metadata: KernelMetadata):\n",
" self.metadata = metadata\n",
" cta_tile_shape_mn = metadata.design.tile_shape[:2]\n",
" self.impl = F64GemmKernelImplementation(cta_tile_shape_mn)"
]
},
{
"cell_type": "markdown",
"id": "500a0030",
"metadata": {},
"source": [
"### Defining interfaces for compilation and execution\n",
"The interfaces needed for compilation and execution are simple.\n",
"\n",
"The `compile` method simply constructs a placeholder stream object\n",
"and passes that and relevant arguments to `self.cute_compile`. This\n",
"is a utility defined in the `CuteDSLKernel` abstract class that\n",
"passes in compilation flags needed for certain options to `cute.compile`\n",
"(e.g., TVM-FFI). The result is wrapped as a `CompiledArtifact`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63b4a129",
"metadata": {},
"outputs": [],
"source": [
"def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact:\n",
" stream = cutlass.cute.runtime.make_fake_stream()\n",
" compiled_gemm = self.cute_compile(self.impl, args.A, args.B, args.out, stream)\n",
" return cutlass_api.artifact.CompiledArtifact(compiled_gemm, self)"
]
},
{
"cell_type": "markdown",
"id": "023127fd",
"metadata": {},
"source": [
"Users define the `_run` method rather than the top-level `run` method\n",
"(no leading underscore) that is used in interacting with kernels. `_run` (1) extracts from `args`\n",
"the arguments needed to run the JIT function, and (2) calls the JIT function\n",
"passed in via `artifact` with these arguments."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ae7c009",
"metadata": {},
"outputs": [],
"source": [
"def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None):\n",
" stream = cutlass_api.utils.to_cuda_stream(stream)\n",
" compiled_gemm = artifact.compiled_obj\n",
" self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)"
]
},
{
"cell_type": "markdown",
"id": "4052e5a0",
"metadata": {},
"source": [
"Finally, since this kernel does not require any device workspace,\n",
"we give it a simple `get_workspace_size` method that always returns 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "968906ea",
"metadata": {},
"outputs": [],
"source": [
"def get_workspace_size(self, args: GemmArguments) -> int:\n",
" return 0"
]
},
{
"cell_type": "markdown",
"id": "e245a319",
"metadata": {},
"source": [
"### Defining interfaces for kernel generation\n",
"We have implemented the interfaces needed for constructing the kernel\n",
"interface, compiling it, and running it. We now must implement methods for\n",
"generating the possible configurations of this kernel that the kernel\n",
"class itself supports. This will be used in kernel discovery (e.g., via\n",
"`cutlass_api.get_kernels()`).\n",
"\n",
"To do so, we write the `generate_kernels` method. This takes in a\n",
"binary function `metadata_filter`, epilogue arguments `epilogue_args`,\n",
"and a compute capability `cc`. It returns a list of all instances\n",
"of the kernel interface that support the `epilogue_args`, are compatible\n",
"with the given `cc`, and which pass the `metadata_filter`.\n",
"\n",
"The `Kernel` class is responsible for defining what valid possible configurations (instances) of it can exist.\n",
"In this example, the valid configurations involve a cross-product of row/column-major strides and two preset tile shapes.\n",
"We create a nested loop over these knobs and create a `KernelMetadata` corresponding to each unique configuration.\n",
"\n",
"The `generate_kernels` method must additionally filter the generated kernels by passing it through a `metadata_filter`.\n",
"This is a user-provided custom filter to filter generated metadata combinations. More information on `metadata_filter` is provided in other examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47dc2f20",
"metadata": {},
"outputs": [],
"source": [
"@staticmethod\n",
"def generate_kernels(\n",
" metadata_filter: Callable[[KernelMetadata], bool],\n",
" epilogue_args: cutlass_api.arguments.EpilogueArguments = None,\n",
" cc: int = None,\n",
") -> list[\"F64GemmKernel\"]:\n",
"\n",
" # The tile shapes this kernel supports/exposes\n",
" supported_tile_shapes = [(32, 32, 1), (16, 16, 1)]\n",
"\n",
" if epilogue_args is not None:\n",
" return []\n",
"\n",
" row_major_stride = (0, 0, 1)\n",
" col_major_stride = (0, 1, 0)\n",
" stride_combos = list(itertools.product([row_major_stride, col_major_stride], repeat=3))\n",
" alignment = 1\n",
"\n",
" def stride_name(stride): \n",
" return \"T\" if stride == row_major_stride else \"N\"\n",
"\n",
" kernels = []\n",
" for tile_shape in supported_tile_shapes:\n",
" design_metadata = cutlass_api.metadata.BLASDesignMetadata(tile_shape, (1, 1, 1))\n",
" for stride_A, stride_B, stride_out in stride_combos:\n",
" # Create TensorAttributes for A, B, and out tensors\n",
" a_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_A, alignment)\n",
" b_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_B, alignment)\n",
" out_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_out, alignment)\n",
" layout_str = cutlass_api.utils.strides_to_layout_string(stride_A, stride_B, stride_out)\n",
"\n",
" name = f\"F64GemmKernel_tile{tile_shape[0]}x{tile_shape[1]}_{layout_str}\"\n",
"\n",
" metadata = KernelMetadata(\n",
" kernel_name=name,\n",
" kernel_class=F64GemmKernel,\n",
" operands=cutlass_api.metadata.GemmOperandsMetadata(\n",
" a_attrs, b_attrs, out_attrs, accumulator_type=cutlass.Float64\n",
" ),\n",
" design=design_metadata,\n",
" min_cc=0,\n",
" )\n",
"\n",
" if metadata_filter(metadata):\n",
" kernels.append(F64GemmKernel(metadata))\n",
"\n",
" return kernels"
]
},
{
"cell_type": "markdown",
"id": "c7cdbc66",
"metadata": {},
"source": [
"We also add a method for indicating whether a kernel instance in question\n",
"supports a set of arguments. The top-level `Kernel.supports` method will\n",
"already verify that the `args` passed in match the metadata with which\n",
"this `Kernel` instance was constructed. Here, we define additional\n",
"checks specific to this kernel, such as that the kernel expects\n",
"all operands to be of rank 3:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54067d47",
"metadata": {},
"outputs": [],
"source": [
"def _supports(self, args: GemmArguments) -> Status:\n",
" if not (\n",
" len(args.A.shape) == 3 and # A should be (L, M, K)\n",
" len(args.B.shape) == 3 and # B should be (L, K, N)\n",
" len(args.out.shape) == 3 # out should be (L, M, N)\n",
" ):\n",
" return Status.fail(\"All operands must be rank 3.\")\n",
" return Status.success()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "edaf2cba",
"metadata": {},
"outputs": [],
"source": [
"# Assign methods to the class because we interspersed notebook markdown\n",
"# with the class definition. This is not needed in a real implementation.\n",
"F64GemmKernel.__init__ = __init__\n",
"F64GemmKernel.compile = compile\n",
"F64GemmKernel._run = _run\n",
"F64GemmKernel._supports = _supports\n",
"F64GemmKernel.generate_kernels = generate_kernels\n",
"F64GemmKernel.get_workspace_size = get_workspace_size"
]
},
{
"cell_type": "markdown",
"id": "c8fc84e9",
"metadata": {},
"source": [
"## Discovering instances of the kernel and using them\n",
"The CUTLASS API is now prepared to discover instances of this\n",
"kernel interface just as was done in previous examples.\n",
"\n",
"We add a small modification of using a `metadata_filter`\n",
"to ensure that all returned kernels are instances of the\n",
"`F64GemmKernel` class we just implemented. This is needed\n",
"only for example/testing purposes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cec5431d",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(2025)\n",
"\n",
"L, M, N, K = 1, 256, 1024, 128\n",
"A = torch.randn(L, M, K, device=\"cuda\", dtype=torch.float64)\n",
"B = torch.randn(L, K, N, device=\"cuda\", dtype=torch.float64)\n",
"out = torch.empty(L, M, N, device=\"cuda\", dtype=torch.float64)\n",
"\n",
"args = GemmArguments(A, B, out, accumulator_type=torch.float64)\n",
"\n",
"def is_f64gemm_kernel(metadata):\n",
" return metadata.kernel_class == F64GemmKernel\n",
"\n",
"kernels = cutlass_api.get_kernels(args, metadata_filter=is_f64gemm_kernel)"
]
},
{
"cell_type": "markdown",
"id": "50e81a7d",
"metadata": {},
"source": [
"We can print off the names of the first few kernels to see that\n",
"they come from our recently-added kernel."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cdb92b5e",
"metadata": {},
"outputs": [],
"source": [
"print(kernels[0].metadata.kernel_name)\n",
"print(kernels[1].metadata.kernel_name)"
]
},
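{
"cell_type": "markdown",
"id": "1f2e3d4c",
"metadata": {},
"source": [
"Before running, we can also exercise the `supports` interface discussed above.\n",
"The snippet below is a small sketch that assumes the public `Kernel.supports`\n",
"method mirrors `_supports` and returns a `Status` indicating whether the\n",
"arguments are accepted."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a6b7c8d",
"metadata": {},
"outputs": [],
"source": [
"# Check that the first discovered kernel accepts our rank-3 operands\n",
"# (assuming supports() returns a Status as _supports does).\n",
"print(kernels[0].supports(args))"
]
},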
{
"cell_type": "markdown",
"id": "697ee3c3",
"metadata": {},
"source": [
"We can evaluate and test the correctness of an instance of our kernel:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5486244",
"metadata": {},
"outputs": [],
"source": [
"kernels[0].run(args)\n",
"torch.testing.assert_close(out, A @ B)"
]
},
{
"cell_type": "markdown",
"id": "8de96f7e",
"metadata": {},
"source": [
"We can also test the limits of our kernel's design space by providing a\n",
"metadata filter that expects a CTA tile size M of 256, which is not exposed\n",
"in the `generate_kernels` method of our recently-added kernel. We expect\n",
"no kernels of type `F64GemmKernel` to be returned."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "917c74e3",
"metadata": {},
"outputs": [],
"source": [
"def my_filter(metadata):\n",
" return (\n",
" is_f64gemm_kernel(metadata) and\n",
" isinstance(metadata.design, cutlass_api.metadata.BLASDesignMetadata) and\n",
" metadata.design.tile_shape[0] == 256\n",
" )\n",
"kernels_ctam256 = cutlass_api.get_kernels(args, metadata_filter=my_filter)\n",
"\n",
"# No kernels should be found\n",
"assert len(kernels_ctam256) == 0"
]
},
{
"cell_type": "markdown",
"id": "caa80a7d",
"metadata": {},
"source": [
"## A note on contributing kernels to directory structure\n",
"This example showed how to define a kernel inline and add it to the\n",
"API for example purposes. This kernel doesn't necessarily need to live\n",
"within the API's source code.\n",
"\n",
"We welcome contributions of kernels that do live within the CUTLASS\n",
"API's repository as well.\n",
"\n",
"Kernels in the repository are organized based on the \"provider\" in which they are\n",
"authored (i.e., the DSL). All kernels corresponding to a given\n",
"provider live a directory corresponding to that provider under\n",
"`cutlass_api/providers`. For example, CuTe DSL kernels live\n",
"under `cutlass_api/providers/cutedsl`.\n",
"\n",
"Each provider can organize kernels differently. For CuTe DSL,\n",
"kernels are further split based on their logical operation,\n",
"with GEMM kernels under the `cutlass_api/providers/cutedsl/gemm`\n",
"directory.\n",
"\n",
"We recommend separating the implementation of the kernel from\n",
"its interface not just by using separate classes, as done in\n",
"this example, but also by separating the implementation and\n",
"interface into separate files. This makes it easier to update\n",
"each without affecting the other.\n",
"\n",
"For example, CuTe DSL GEMM kernels have the following organization:\n",
"```text\n",
"cutlass_api/\n",
" providers/\n",
" cutedsl/\n",
" gemm/\n",
" sm100_static_persistent.py\n",
" implementations/\n",
" sm100_static_persistent_impl.py\n",
"```"
]
}
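,
{
"cell_type": "markdown",
"id": "7e8f9a0b",
"metadata": {},
"source": [
"As a hypothetical sketch of this split (file and symbol names below are illustrative,\n",
"not actual repository contents), the interface file for the kernel in this example\n",
"would import its implementation from the sibling `implementations/` directory:\n",
"```python\n",
"# cutlass_api/providers/cutedsl/gemm/f64_gemm.py (hypothetical file)\n",
"from cutlass_api.metadata import KernelMetadata\n",
"from cutlass_api.providers.cutedsl import CuTeDSLProvider\n",
"from cutlass_api.providers.cutedsl.kernel import CuteDslKernel\n",
"\n",
"from .implementations.f64_gemm_impl import F64GemmKernelImplementation\n",
"\n",
"\n",
"@CuTeDSLProvider.register\n",
"class F64GemmKernel(CuteDslKernel):\n",
"    def __init__(self, metadata: KernelMetadata):\n",
"        self.metadata = metadata\n",
"        self.impl = F64GemmKernelImplementation(metadata.design.tile_shape[:2])\n",
"```"
]
}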
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}