mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-05 14:11:18 +00:00
610 lines
21 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "578f2730",
   "metadata": {},
   "source": [
    "# Adding a kernel to the CUTLASS API\n",
    "The CUTLASS API is designed to make it easy for users to add their own kernel\n",
    "so that it can be discovered and run under the uniform API. We welcome contributions\n",
    "toward the API by \"bringing your own kernel.\"\n",
    "\n",
    "This example shows how to add a CuTe DSL kernel to the CUTLASS API.\n",
    "\n",
    "## Bring your own implementation\n",
    "Individuals wishing to add a CuTe DSL kernel to the CUTLASS API likely already\n",
    "have the kernel written in CuTe DSL, but have not yet implemented the interface\n",
    "the API requires. Within the API, we separate these components into the \"implementation\" --\n",
    "the kernel written in CuTe DSL -- and the \"interface\" -- the definition of methods\n",
    "a kernel needs to be used within the CUTLASS API.\n",
    "\n",
    "For example, consider the following implementation of a simple FP64 GEMM kernel:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5a64b0be",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections.abc import Callable\n",
    "\n",
    "import cuda.bindings.driver as cuda\n",
    "\n",
    "import cutlass\n",
    "import cutlass.cute as cute\n",
    "\n",
    "\n",
    "class F64GemmKernelImplementation:\n",
    "    def __init__(self, cta_tile_shape_mn: tuple[int, int]):\n",
    "        self.cta_tile_shape_mn = cta_tile_shape_mn\n",
    "\n",
    "    @cute.jit\n",
    "    def __call__(\n",
    "        self, a: cute.Tensor, b: cute.Tensor, out: cute.Tensor, stream: cuda.CUstream\n",
    "    ):\n",
    "        l, m, n = out.shape\n",
    "        m_tiles = (m + self.cta_tile_shape_mn[0] - 1) // self.cta_tile_shape_mn[0]\n",
    "        n_tiles = (n + self.cta_tile_shape_mn[1] - 1) // self.cta_tile_shape_mn[1]\n",
    "\n",
    "        grid = (m_tiles, n_tiles, l)\n",
    "        block = [self.cta_tile_shape_mn[0], self.cta_tile_shape_mn[1], 1]\n",
    "        self.kernel(a, b, out).launch(grid=grid, block=block, stream=stream)\n",
    "\n",
    "    @cute.kernel\n",
    "    def kernel(self, a: cute.Tensor, b: cute.Tensor, out: cute.Tensor):\n",
    "        l, m, n = out.shape\n",
    "        k = a.shape[-1]\n",
    "        m_tile, n_tile, l_idx = cute.arch.block_idx()\n",
    "        tidx, tidy, _ = cute.arch.thread_idx()\n",
    "\n",
    "        m_idx = m_tile * self.cta_tile_shape_mn[0] + tidx\n",
    "        n_idx = n_tile * self.cta_tile_shape_mn[1] + tidy\n",
    "\n",
    "        if m_idx < m and n_idx < n:\n",
    "            out[l_idx, m_idx, n_idx] = cutlass.Float64(0)\n",
    "            for k_idx in range(k):\n",
    "                out[l_idx, m_idx, n_idx] += (\n",
    "                    a[l_idx, m_idx, k_idx] * b[l_idx, k_idx, n_idx]\n",
    "                )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36a08d4b",
   "metadata": {},
   "source": [
    "The implementation is configurable via a `cta_tile_shape_mn` argument, which\n",
    "controls the size of blocks and tiles in the M and N modes. A simple `cute.jit` function\n",
    "computes the grid and block size for the input problem based on `cta_tile_shape_mn`,\n",
    "and launches the kernel. The `cute.kernel` itself simply has each thread compute a single\n",
    "output element of the matrix by taking a dot product.\n",
    "\n",
    "This implementation is not performant, but is kept simple for illustrative purposes."
   ]
  },
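  {
   "cell_type": "markdown",
   "id": "grid-example",
   "metadata": {},
   "source": [
    "As a purely illustrative worked example of the ceiling division above (the problem sizes here are assumptions, not part of the kernel), a `(32, 32)` CTA tile on a `256 x 1024` output gives:\n",
    "```python\n",
    "m, n = 256, 1024\n",
    "tile_m, tile_n = 32, 32\n",
    "m_tiles = (m + tile_m - 1) // tile_m  # 8\n",
    "n_tiles = (n + tile_n - 1) // tile_n  # 32\n",
    "```\n",
    "so the launch covers the output with an `(8, 32, L)` grid of `32 x 32` thread blocks, one thread per output element."
   ]
  },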
  {
   "cell_type": "markdown",
   "id": "a5d0e661",
   "metadata": {},
   "source": [
    "## Defining interface methods\n",
    "As it currently stands, this GEMM kernel implementation cannot be used via the\n",
    "CUTLASS API because it does not implement interface methods. Specifically, kernels\n",
    "within the CUTLASS API must inherit from and implement the `cutlass_api.Kernel`\n",
    "abstract class. This class declares the methods needed for many common operations\n",
    "performed when compiling and executing DSL kernels.\n",
    "\n",
    "Certain providers (i.e., DSLs), such as CuTe DSL, provide an additional layer atop the\n",
    "`cutlass_api.Kernel` class to add utilities for kernels being written\n",
    "via that provider. For example, the CuTe DSL provider in the CUTLASS API\n",
    "defines `cutlass_api.providers.cutedsl.kernel.CuteDslKernel`, which adds utilities surrounding\n",
    "`cute.compile()` to supply the compile-time arguments needed for using TVM-FFI when\n",
    "it is enabled.\n",
    "\n",
    "We will next walk through the steps in defining interface methods for this\n",
    "implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1a2da869",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "import cutlass_api\n",
    "from cutlass_api.arguments import GemmArguments\n",
    "from cutlass_api.metadata import KernelMetadata\n",
    "from cutlass_api.status import Status"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86ae75cc",
   "metadata": {},
   "source": [
    "We begin by defining a class to represent the kernel's interface.\n",
    "As mentioned above, since this is a CuTe DSL kernel, our interface must\n",
    "inherit from and implement `cutlass_api.providers.cutedsl.kernel.CuteDslKernel`.\n",
    "\n",
    "The class must additionally be registered with the CuTe DSL provider\n",
    "via the `@CuTeDSLProvider.register` decorator so that the class\n",
    "can be considered when discovering kernels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a86d138",
   "metadata": {},
   "outputs": [],
   "source": [
    "@cutlass_api.providers.cutedsl.CuTeDSLProvider.register\n",
    "class F64GemmKernel(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):\n",
    "    # Empty versions of interface methods. These will be implemented later, interspersed\n",
    "    # with notebook markdown. Normally, one would define them inline with the class definition.\n",
    "    def __init__(self, metadata: KernelMetadata):\n",
    "        pass\n",
    "\n",
    "    def _run(\n",
    "        self,\n",
    "        args: GemmArguments,\n",
    "        artifact: cutlass_api.artifact.CompiledArtifact,\n",
    "        stream,\n",
    "        workspace=None,\n",
    "    ):\n",
    "        pass\n",
    "\n",
    "    def compile(\n",
    "        self, args: GemmArguments, cc: int = None\n",
    "    ) -> cutlass_api.artifact.CompiledArtifact:\n",
    "        pass\n",
    "\n",
    "    @staticmethod\n",
    "    def generate_kernels(\n",
    "        metadata_filter, epilogue_args=None, cc=None\n",
    "    ) -> list[\"F64GemmKernel\"]:\n",
    "        pass\n",
    "\n",
    "    def _supports(self, args: GemmArguments) -> Status:\n",
    "        pass\n",
    "\n",
    "    def get_workspace_size(self, args: GemmArguments) -> int:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "327e9e7c",
   "metadata": {},
   "source": [
    "The `__init__` method of the class takes in a `KernelMetadata` object\n",
    "from which it extracts the `cta_tile_shape_mn`. This is used to construct\n",
    "the kernel implementation object. We will discuss later how the `KernelMetadata`\n",
    "object passed in here is constructed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "785d1882",
   "metadata": {},
   "outputs": [],
   "source": [
    "def __init__(self, metadata: KernelMetadata):\n",
    "    # Using Python-2-style super() because we're defining this method outside of the class definition.\n",
    "    super(F64GemmKernel, self).__init__(metadata)\n",
    "    cta_tile_shape_mn = metadata.design.tile_shape[:2]\n",
    "    self.impl = F64GemmKernelImplementation(cta_tile_shape_mn)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "500a0030",
   "metadata": {},
   "source": [
    "### Defining interfaces for compilation and execution\n",
    "The interfaces needed for compilation and execution are simple.\n",
    "\n",
    "The `compile` method simply constructs a placeholder stream object\n",
    "and passes that and relevant arguments to `self.cute_compile`. This\n",
    "is a utility defined in the `CuteDslKernel` abstract class that\n",
    "passes in compilation flags needed for certain options to `cute.compile`\n",
    "(e.g., TVM-FFI). The result is wrapped as a `CompiledArtifact`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "63b4a129",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compile(\n",
    "    self, args: GemmArguments, cc: int = None\n",
    ") -> cutlass_api.artifact.CompiledArtifact:\n",
    "    stream = cutlass.cute.runtime.make_fake_stream()\n",
    "    compiled_gemm = self.cute_compile(self.impl, args.A, args.B, args.out, stream)\n",
    "    return cutlass_api.artifact.CompiledArtifact(compiled_gemm, self)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "023127fd",
   "metadata": {},
   "source": [
    "Kernel authors define the `_run` method rather than the top-level `run` method\n",
    "(no leading underscore) that callers use to interact with kernels. `_run` (1) extracts from `args`\n",
    "the arguments needed to run the JIT function, and (2) calls the JIT function\n",
    "passed in via `artifact` with these arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2ae7c009",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _run(\n",
    "    self,\n",
    "    args: GemmArguments,\n",
    "    artifact: cutlass_api.artifact.CompiledArtifact,\n",
    "    stream,\n",
    "    workspace=None,\n",
    "):\n",
    "    stream = cutlass_api.utils.to_cuda_stream(stream)\n",
    "    compiled_gemm = artifact.compiled_obj\n",
    "    self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4052e5a0",
   "metadata": {},
   "source": [
    "Finally, since this kernel does not require any device workspace,\n",
    "we give it a simple `get_workspace_size` method that always returns 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "968906ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_workspace_size(self, args: GemmArguments) -> int:\n",
    "    return 0"
   ]
  },
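  {
   "cell_type": "markdown",
   "id": "workspace-example",
   "metadata": {},
   "source": [
    "A kernel that did need scratch memory would instead compute a size from `args`. As a hypothetical sketch (not part of this kernel), a split-K variant buffering one FP64 partial sum per output element might return:\n",
    "```python\n",
    "def get_workspace_size(self, args: GemmArguments) -> int:\n",
    "    l, m, n = args.out.shape\n",
    "    return l * m * n * 8  # 8 bytes per Float64 partial\n",
    "```\n",
    "A buffer of this size can then be provided to `_run` via its `workspace` argument."
   ]
  },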
  {
   "cell_type": "markdown",
   "id": "e245a319",
   "metadata": {},
   "source": [
    "### Defining interfaces for kernel generation\n",
    "We have implemented the methods needed for constructing the kernel\n",
    "interface, compiling it, and running it. We now must implement methods for\n",
    "generating the possible configurations of this kernel that the kernel\n",
    "class itself supports. This will be used in kernel discovery (e.g., via\n",
    "`cutlass_api.get_kernels()`).\n",
    "\n",
    "To do so, we write the `generate_kernels` method. This takes in a\n",
    "predicate `metadata_filter`, epilogue arguments `epilogue_args`,\n",
    "and a compute capability `cc`. It returns a list of all instances\n",
    "of the kernel interface that support the `epilogue_args`, are compatible\n",
    "with the given `cc`, and which pass the `metadata_filter`.\n",
    "\n",
    "The `Kernel` class is responsible for defining what valid possible configurations (instances) of it can exist.\n",
    "In this example, the valid configurations involve a cross-product of row/column-major strides and two preset tile shapes.\n",
    "We loop over these knobs and create a `KernelMetadata` corresponding to each unique configuration.\n",
    "\n",
    "The `generate_kernels` method must additionally filter the generated kernels by passing each one through `metadata_filter`,\n",
    "a user-provided predicate over generated metadata combinations. More information on `metadata_filter` is provided in other examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "47dc2f20",
   "metadata": {},
   "outputs": [],
   "source": [
    "@staticmethod\n",
    "def generate_kernels(\n",
    "    metadata_filter: Callable[[KernelMetadata], bool],\n",
    "    epilogue_args: cutlass_api.arguments.EpilogueArguments = None,\n",
    "    cc: int = None,\n",
    ") -> list[\"F64GemmKernel\"]:\n",
    "    # The tile shapes this kernel supports/exposes\n",
    "    supported_tile_shapes = [(32, 32, 1), (16, 16, 1)]\n",
    "\n",
    "    if epilogue_args is not None:\n",
    "        return []\n",
    "\n",
    "    row_major_stride = (0, 0, 1)\n",
    "    col_major_stride = (0, 1, 0)\n",
    "    stride_combos = list(\n",
    "        itertools.product([row_major_stride, col_major_stride], repeat=3)\n",
    "    )\n",
    "    divisibility = 1\n",
    "\n",
    "    kernels = []\n",
    "    for tile_shape in supported_tile_shapes:\n",
    "        design_metadata = cutlass_api.metadata.BLASDesignMetadata(tile_shape, (1, 1, 1))\n",
    "        for stride_A, stride_B, stride_out in stride_combos:\n",
    "            # Create TensorAttributes for A, B, and out tensors\n",
    "            a_attrs = cutlass_api.metadata.TensorAttributes(\n",
    "                cutlass.Float64, stride_A, divisibility\n",
    "            )\n",
    "            b_attrs = cutlass_api.metadata.TensorAttributes(\n",
    "                cutlass.Float64, stride_B, divisibility\n",
    "            )\n",
    "            out_attrs = cutlass_api.metadata.TensorAttributes(\n",
    "                cutlass.Float64, stride_out, divisibility\n",
    "            )\n",
    "            layout_str = cutlass_api.utils.strides_to_layout_string(\n",
    "                stride_A, stride_B, stride_out\n",
    "            )\n",
    "\n",
    "            name = f\"F64GemmKernel_tile{tile_shape[0]}x{tile_shape[1]}_{layout_str}\"\n",
    "\n",
    "            metadata = KernelMetadata(\n",
    "                kernel_name=name,\n",
    "                kernel_class=F64GemmKernel,\n",
    "                operands=cutlass_api.metadata.GemmOperandsMetadata(\n",
    "                    a_attrs, b_attrs, out_attrs, accumulator_type=cutlass.Float64\n",
    "                ),\n",
    "                design=design_metadata,\n",
    "                min_cc=0,\n",
    "            )\n",
    "\n",
    "            if metadata_filter(metadata):\n",
    "                kernels.append(F64GemmKernel(metadata))\n",
    "\n",
    "    return kernels"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7cdbc66",
   "metadata": {},
   "source": [
    "We also add a method indicating whether the kernel instance in question\n",
    "supports a set of arguments. The top-level `Kernel.supports` method will\n",
    "already verify that the `args` passed in match the metadata with which\n",
    "this `Kernel` instance was constructed. Here, we define additional\n",
    "checks specific to this kernel, such as that the kernel expects\n",
    "all operands to be of rank 3:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "54067d47",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _supports(self, args: GemmArguments) -> Status:\n",
    "    if not (\n",
    "        len(args.A.shape) == 3  # A should be (L, M, K)\n",
    "        and len(args.B.shape) == 3  # B should be (L, K, N)\n",
    "        and len(args.out.shape) == 3  # out should be (L, M, N)\n",
    "    ):\n",
    "        return Status.fail(\"All operands must be rank 3.\")\n",
    "    return Status.success()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "edaf2cba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assign methods to the class because we interspersed notebook markdown\n",
    "# with the class definition. This is not needed in a real implementation.\n",
    "F64GemmKernel.__init__ = __init__\n",
    "F64GemmKernel.compile = compile\n",
    "F64GemmKernel._run = _run\n",
    "F64GemmKernel._supports = _supports\n",
    "F64GemmKernel.generate_kernels = generate_kernels\n",
    "F64GemmKernel.get_workspace_size = get_workspace_size"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8fc84e9",
   "metadata": {},
   "source": [
    "## Discovering instances of the kernel and using them\n",
    "The CUTLASS API is now prepared to discover instances of this\n",
    "kernel interface just as was done in previous examples.\n",
    "\n",
    "We add a small modification of using a `metadata_filter`\n",
    "to ensure that all returned kernels are instances of the\n",
    "`F64GemmKernel` class we just implemented. This is needed\n",
    "only for example/testing purposes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cec5431d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(2025)\n",
    "\n",
    "L, M, N, K = 1, 256, 1024, 128\n",
    "A = torch.randn(L, M, K, device=\"cuda\", dtype=torch.float64)\n",
    "B = torch.randn(L, K, N, device=\"cuda\", dtype=torch.float64)\n",
    "out = torch.empty(L, M, N, device=\"cuda\", dtype=torch.float64)\n",
    "\n",
    "args = GemmArguments(A, B, out, accumulator_type=torch.float64)\n",
    "\n",
    "\n",
    "def is_f64gemm_kernel(metadata):\n",
    "    return metadata.kernel_class == F64GemmKernel\n",
    "\n",
    "\n",
    "kernels = cutlass_api.get_kernels(args, metadata_filter=is_f64gemm_kernel)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50e81a7d",
   "metadata": {},
   "source": [
    "We can print the names of the first few kernels to see that\n",
    "they come from our newly added kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "cdb92b5e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F64GemmKernel_tile32x32_ttt\n",
      "F64GemmKernel_tile16x16_ttt\n"
     ]
    }
   ],
   "source": [
    "print(kernels[0].metadata.kernel_name)\n",
    "print(kernels[1].metadata.kernel_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "697ee3c3",
   "metadata": {},
   "source": [
    "We can evaluate and test the correctness of an instance of our kernel:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f5486244",
   "metadata": {},
   "outputs": [],
   "source": [
    "kernels[0].run(args)\n",
    "torch.testing.assert_close(out, A @ B)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8de96f7e",
   "metadata": {},
   "source": [
    "We can also test the limits of our kernel's design space by providing a\n",
    "metadata filter that expects a CTA tile size M of 256, which is not exposed\n",
    "in the `generate_kernels` method of our newly added kernel. We expect\n",
    "no kernels of type `F64GemmKernel` to be returned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "917c74e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_filter(metadata):\n",
    "    return (\n",
    "        is_f64gemm_kernel(metadata)\n",
    "        and isinstance(metadata.design, cutlass_api.metadata.BLASDesignMetadata)\n",
    "        and metadata.design.tile_shape[0] == 256\n",
    "    )\n",
    "\n",
    "\n",
    "kernels_ctam256 = cutlass_api.get_kernels(args, metadata_filter=my_filter)\n",
    "\n",
    "# No kernels should be found\n",
    "assert len(kernels_ctam256) == 0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "caa80a7d",
   "metadata": {},
   "source": [
    "## A note on contributing kernels to the repository\n",
    "This example showed how to define a kernel inline and add it to the\n",
    "API for example purposes. This kernel doesn't necessarily need to live\n",
    "within the API's source code.\n",
    "\n",
    "We welcome contributions of kernels that do live within the CUTLASS\n",
    "API's repository as well.\n",
    "\n",
    "Kernels in the repository are organized based on the \"provider\" in which they are\n",
    "authored (i.e., the DSL). All kernels corresponding to a given\n",
    "provider live in a directory corresponding to that provider under\n",
    "`cutlass_api/providers`. For example, CuTe DSL kernels live\n",
    "under `cutlass_api/providers/cutedsl`.\n",
    "\n",
    "Each provider can organize kernels differently. For CuTe DSL,\n",
    "kernels are further split based on their logical operation,\n",
    "with GEMM kernels under the `cutlass_api/providers/cutedsl/gemm`\n",
    "directory.\n",
    "\n",
    "We recommend separating the implementation of the kernel from\n",
    "its interface not just by using separate classes, as done in\n",
    "this example, but also by separating the implementation and\n",
    "interface into separate files. This makes it easier to update\n",
    "each without affecting the other.\n",
    "\n",
    "For example, CuTe DSL GEMM kernels have the following organization:\n",
    "```text\n",
    "cutlass_api/\n",
    "  providers/\n",
    "    cutedsl/\n",
    "      gemm/\n",
    "        sm100_static_persistent.py\n",
    "        implementations/\n",
    "          sm100_static_persistent_impl.py\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}