Initial commit

This commit is contained in:
jkosaian
2025-12-16 10:00:46 -08:00
parent d4e16f5d4e
commit ead2fbfe13
81 changed files with 19407 additions and 0 deletions

View File

@@ -0,0 +1,558 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "3dd45ef2",
"metadata": {},
"source": [
"# Basic GEMM using CUTLASS Python API"
]
},
{
"cell_type": "markdown",
"id": "4709aa60",
"metadata": {},
"source": [
"The CUTLASS API provides a consistent, uniform interface for discovering, compiling, and running GPU kernels from various DSL sources.\n",
"\n",
"This notebook walks through a minimal GEMM (Generalized Matrix-Matrix Multiplication) example, and introduces the core concepts of the API."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f878d960-d175-4d84-b978-88afbd318850",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"import cutlass\n",
"\n",
"import cutlass_api\n",
"\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
" print(\n",
" f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\"\n",
" )\n",
" import sys\n",
"\n",
" sys.exit(0)"
]
},
{
"cell_type": "markdown",
"id": "db91dab6",
"metadata": {},
"source": [
"## Running your first kernel"
]
},
{
"cell_type": "markdown",
"id": "7b7b87b0",
"metadata": {},
"source": [
"### Setting up arguments\n",
"\n",
"CUTLASS API has first-class support for PyTorch tensors. We start by creating torch tensors that will be operands to a matrix multiplication."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f550c4ea",
"metadata": {},
"outputs": [],
"source": [
"M, N, K, L = 128, 256, 64, 2\n",
"ab_type = torch.float16\n",
"out_type = torch.float32\n",
"acc_type = torch.float32\n",
"\n",
"A = torch.randint(-1, 2, (L, M, K), device=\"cuda\", dtype=ab_type)\n",
"B = torch.randint(-1, 2, (L, K, N), device=\"cuda\", dtype=ab_type)\n",
"out = torch.empty((L, M, N), device=\"cuda\", dtype=out_type)\n",
"\n",
"reference = (A @ B).to(out.dtype)"
]
},
{
"cell_type": "markdown",
"id": "6b6cb805",
"metadata": {},
"source": [
"We then create a `GemmArguments` object. This object specifies:\n",
"1. what logical operation do we want to perform (a GEMM)\n",
"2. on which operands we want to perform that operation (`A`, `B`, `out` as declared above)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b57690df",
"metadata": {},
"outputs": [],
"source": [
"args = cutlass_api.arguments.GemmArguments(A=A, B=B, out=out, accumulator_type=acc_type)"
]
},
{
"cell_type": "markdown",
"id": "67e5ddcf",
"metadata": {},
"source": [
"### Kernel discovery\n",
"\n",
"We now need to find kernels that can perform the operation we expressed in `args`.\n",
"\n",
"The simplest way to do so is to use `get_kernels(args)`. It searches a set of kernels pre-registered in the library, and returns the subset of those kernels which can successfully run our given `args`.\n",
"\n",
"Any of these kernels will be functionally equivalent -- they may have different design or performance characteristics. We arbitrarily pick the first of the returned kernels to execute here"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9872ad66",
"metadata": {},
"outputs": [],
"source": [
"kernels = cutlass_api.get_kernels(args)\n",
"assert kernels, \"No kernels found for the given arguments!\"\n",
"\n",
"kernel = kernels[0]"
]
},
{
"cell_type": "markdown",
"id": "4c17693e",
"metadata": {},
"source": [
"#### Run the kernel\n",
"\n",
"Running the kernel is as simple as `kernel.run(args)`.\n",
"\n",
"This implicitly JIT-compiles the kernel, and launches it on the GPU device using our given arguments."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "baf4588a",
"metadata": {},
"outputs": [],
"source": [
"kernel.run(args)\n",
"\n",
"torch.testing.assert_close(out, reference)"
]
},
{
"cell_type": "markdown",
"id": "4d7ad85b",
"metadata": {},
"source": [
"One can also explicitly compile the kernel and pass this in to `kernel.run` to avoid\n",
"JIT compilation on future invocations. Additional details related to this will be\n",
"described below."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06f9f844",
"metadata": {},
"outputs": [],
"source": [
"artifact = kernel.compile(args)\n",
"kernel.run(args, compiled_artifact=artifact)\n",
"torch.testing.assert_close(out, reference)"
]
},
{
"cell_type": "markdown",
"id": "630e9e4b",
"metadata": {},
"source": [
"---\n",
"\n",
"---\n",
"\n",
"### Understanding the core interfaces"
]
},
{
"cell_type": "markdown",
"id": "2d8b8e94",
"metadata": {},
"source": [
"#### 1. `RuntimeArguments` / `GemmArguments`\n",
"\n",
"`RuntimeArguments` describe the operation a user wants to perform, and all the runtime operands or other runtime parameters needed for it. \n",
"This includes primary runtime operands to the operation, as well as any custom epilogue fusions and runtime performance knobs.\n",
"\n",
"We provide builtin subtypes of `RuntimeArguments` for common operations (e.g. GEMM, Elementwise ops; more later).\n",
"\n",
"For instance, `GemmArguments` is a type of `RuntimeArguments`:\n",
"\n",
"```python\n",
"@dataclass\n",
"class GemmArguments(RuntimeArguments):\n",
" A: TensorLike\n",
" B: TensorLike\n",
" out: TensorLike\n",
" accumulator_type: NumericLike\n",
"```\n",
"\n",
"`GemmArguments` conveys:\n",
"* We want to perform a dense GEMM operation (`out = A @ B`)\n",
"* We want to perform it for operands in `A, B, out`, with intermediate results stored as `accumulator_type`\n",
"* We can optionally set a custom epilogue that is fused on top of the base GEMM. Some kernels also support some runtime performance controls which can be specified here. These will be discussed in detail in other tutorials.\n",
"\n",
"It is a kernel-agnostic way to specify the desired functionality.\n",
"\n",
"`RuntimeArguments` can be constructed from any `TensorLike` object. This includes `torch.Tensor`, `cute.Tensor`, or any other DLPack-compatible tensors."
]
},
{
"cell_type": "markdown",
"id": "e7eda0dd",
"metadata": {},
"source": [
"#### 2. Kernel Discovery\n",
"\n",
"There are several kernels available in CUTLASS DSLs that are registered with, and discoverable via, the CUTLASS API.\n",
"\n",
"This includes kernels for various operations (GEMM, Elementwise operations, ...), which implement various algorithms & architecture features. Within the same implementation, there are several instances or configurations of it with different combinations of operand types, layouts, tile sizes, etc.\n",
"\n",
"In the previous step, we used `GemmArguments` to specify our desired GEMM in a kernel-agnostic way. Now we find kernels that can fulfill that functionality. A subset of the available kernels will perform GEMM, and a subset of _those_ will support the properties of specific operands we are currently using."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3b737131",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A total of 107616 kernel instances are available.\n",
"Of these, 350 support the given arguments.\n",
"Picked kernel with name: cutedsl.PersistentDenseGemmKernel_sm100_ttt_AFloat16_BFloat16_outFloat32_accFloat32_2cta_cluster2x1x1_tile128x32x256_tma_store\n"
]
}
],
"source": [
"# get_kernels() fetches all kernels when called without args\n",
"all_kernels = cutlass_api.get_kernels()\n",
"print(f\"A total of {len(all_kernels)} kernel instances are available.\")\n",
"\n",
"# we can limit the search to kernels supporting given args\n",
"kernels = cutlass_api.get_kernels(args)\n",
"print(f\"Of these, {len(kernels)} support the given arguments.\")\n",
"\n",
"kernel = kernels[0]\n",
"print(f\"Picked kernel with name: {kernel.metadata.kernel_name}\")"
]
},
{
"cell_type": "markdown",
"id": "252a4d38",
"metadata": {},
"source": [
"#### 3. `Kernel` execution"
]
},
{
"cell_type": "markdown",
"id": "574d004b",
"metadata": {},
"source": [
"Once we have selected a kernel, we are now ready to execute it. We previously showed the simplest way to do this is `kernel.run(args)`.\n",
"\n",
"This method does the following:\n",
"* verify that the kernel supports the given `args`\n",
"* JIT-compile the kernel\n",
"* launch the compiled kernel function\n",
"\n",
"Users can do these steps individually for more control:"
]
},
{
"cell_type": "markdown",
"id": "e8945aa6",
"metadata": {},
"source": [
"* `kernel.supports(args)` checks if the kernel supports the given `args`\n",
" * this is relevant if the kernel was not picked just for these `args`"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "159fd610",
"metadata": {},
"outputs": [],
"source": [
"supported = kernel.supports(args)\n",
"assert supported"
]
},
{
"cell_type": "markdown",
"id": "948689a8",
"metadata": {},
"source": [
"If the arguments are not supported by this kernel, `supports` returns a `Status` object explaining the error."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2cfc9ea7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Operand `A` is unsupported: Expected element type Float16, got BFloat16\n"
]
}
],
"source": [
"unsupported_args = cutlass_api.arguments.GemmArguments(\n",
" A=A.to(torch.bfloat16), B=B, out=out, accumulator_type=acc_type\n",
")\n",
"if not (status := kernel.supports(unsupported_args)):\n",
" print(status.error)\n",
"\n",
"assert not status"
]
},
{
"cell_type": "markdown",
"id": "c2db8f20",
"metadata": {},
"source": [
"* `kernel.compile(args)` compiles the kernel, and returns a `CompiledArtifact`\n",
"\n",
"This compiled artifact is a lightweight wrapper over the result of compiling a kernel (e.g., via `cute.compile()`).\n",
"\n",
"For just-in-time compilation, we can use the compiled artifact straightaway.\n",
"In the future, we will support optionally serializing it for ahead-of-time compilation and deserialized in a different context.\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "02e79eb8",
"metadata": {},
"outputs": [],
"source": [
"compiled_artifact = kernel.compile(args)"
]
},
{
"cell_type": "markdown",
"id": "4dfb8d51",
"metadata": {},
"source": [
"* `kernel.run(args)` launches the compiled kernel function. This example uses:\n",
" * the precompiled artifact\n",
" * a custom stream to launch to\n",
" * bypasses the supports check already performed above (`assume_supported_args=True`)."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "02398bf0",
"metadata": {},
"outputs": [],
"source": [
"# zero the output to avoid testing stale output\n",
"out.zero_()\n",
"\n",
"kernel.run(\n",
" args,\n",
" compiled_artifact,\n",
" stream=torch.cuda.Stream(),\n",
" assume_supported_args=True,\n",
")\n",
"torch.testing.assert_close(out, reference)"
]
},
{
"cell_type": "markdown",
"id": "f67eeb8f",
"metadata": {},
"source": [
"Some kernels may also require a device \"workspace\". This is an additional buffer needed by some kernels for book-keeping, temporary results, etc.\n",
"Its size can be queried using `kernel.get_workspace_size(args)`. Most kernels will have a workspace size of 0.\n",
"If a kernel does have a non-zero workspace size, an additional buffer of at least that size must be provided. Without it, the kernel behavior is undefined."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "5b2e2d56",
"metadata": {},
"outputs": [],
"source": [
"workspace_size = kernel.get_workspace_size(args)\n",
"workspace = torch.empty(workspace_size, device=\"cuda\", dtype=torch.int8)\n",
"\n",
"out.zero_()\n",
"kernel.run(args, compiled_artifact, stream=torch.cuda.Stream(), workspace=workspace)\n",
"torch.testing.assert_close(out, reference)"
]
},
{
"cell_type": "markdown",
"id": "baffaf12",
"metadata": {},
"source": [
"### Advanced: Filtering on Metadata"
]
},
{
"cell_type": "markdown",
"id": "86dd521a",
"metadata": {},
"source": [
"Using `RuntimeArguments` to search for supporting kernels is a convenient way to discover kernels: users directly specify their desired functionality, and `get_kernels()` finds the supporting kernels.\n",
"It covers all logical operands of any operation, as well as (in later examples) epilogue fusions, and performance controls.\n",
"\n",
"However, there may be cases where users want more advanced ways to query kernels. These could be:\n",
"* when the desired properties may not be expressed in runtime controls\n",
" * the simplest scenario may be if you're searching searching for a kernel with a specific name, a specific class, etc.\n",
" * searching for kernel's static properties such as tile size, cluster size, etc.\n",
"* when the `RuntimeArguments` are not available or you want to generate & pre-compile a broader set of kernels\n",
"\n",
"For such cases, we provide a more advanced filtering based on `KernelMetadata`"
]
},
{
"cell_type": "markdown",
"id": "ef54ae50",
"metadata": {},
"source": [
"`KernelMetadata` captures a wide variety of properties of a `Kernel`.\n",
"\n",
"These are properties of a kernel's functional support (like operand types, layouts, alignments), as well as architectural/design choices & performance characteristics (like tilze size, scheduling characteristics).\n",
"\n",
"Different kernels may use different sub-classes of `metadata.operands`, `metadata.design`, `metadata.epilogue` for flexibility, which can also identify their characteristics.\n",
"\n",
"```python\n",
"@dataclass\n",
"class KernelMetadata:\n",
" kernel_name: str\n",
" kernel_class: type[\"Kernel\"]\n",
" min_cc: int\n",
" operands: OperandsMetadata\n",
" design: DesignMetadata | None = None\n",
" epilogue: EpilogueMetadata | None = None\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "f9943888",
"metadata": {},
"source": [
"Every unique kernel instance can be distinguished by its metadata.\n",
"It can be used in filtering for kernels in addition to the `RuntimeArguments`, by providing a custom `metadata_filter`.\n",
"\n",
"Here, we get all kernels that support `args`, and have `metadata.design` of type `Sm100DesignMetadata`.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "8717ac89",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 350 kernels which support args & have Sm100DesignMetadata\n"
]
}
],
"source": [
"kernels = cutlass_api.get_kernels(\n",
" args,\n",
" metadata_filter=lambda metadata: isinstance(\n",
" metadata.design, cutlass_api.metadata.Sm100DesignMetadata\n",
" ),\n",
")\n",
"print(f\"Found {len(kernels)} kernels which support args & have Sm100DesignMetadata\")"
]
},
{
"cell_type": "markdown",
"id": "1d1f9124",
"metadata": {},
"source": [
"We can construct more advanced filters by leveraging duck-typing.\n",
"Additionally, we can get all the kernels that match our filter, rather than supporting a fully-defined set of arguments.\n",
"This could be useful to pre-generate large set of kernels not targeted to any one problem."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "a76ec20f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 9400 matching kernels\n"
]
}
],
"source": [
"def a_more_complex_filter(metadata: cutlass_api.metadata.KernelMetadata) -> bool:\n",
" \"\"\"\n",
" Find all GEMM kernels that support Float16 A and 2-CTA MMA\n",
" \"\"\"\n",
" # Only look at GEMM kernels\n",
" if not isinstance(metadata.operands, cutlass_api.metadata.GemmOperandsMetadata):\n",
" return False\n",
" # Only look at kernels with A-type F16\n",
" if metadata.operands.A.dtype != cutlass.Float16:\n",
" return False\n",
" # Only look at kernels with tile_shape[0] == 128\n",
" if getattr(metadata.design, \"tile_shape\", [None])[0] != 128:\n",
" return False\n",
" return True\n",
"\n",
"\n",
"# Look ma, no args! Fetch all kernels that match the filter,\n",
"# instead of supporting a complete set of args\n",
"kernels = cutlass_api.get_kernels(\n",
" args=None,\n",
" metadata_filter=a_more_complex_filter,\n",
")\n",
"print(f\"Found {len(kernels)} matching kernels\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,518 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a31330d3",
"metadata": {},
"source": [
"# Custom epilogue fusions for GEMMs\n",
"\n",
"Note: this notebook requires a GPU with compute capability 100 or 103:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb450878",
"metadata": {},
"outputs": [],
"source": [
"import cutlass_api\n",
"\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
" print(f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\")\n",
" import sys\n",
"\n",
" sys.exit(0)"
]
},
{
"cell_type": "markdown",
"id": "154e9d59",
"metadata": {},
"source": [
"The CUTLASS API provides flexible epilogue fusion support by allowing for the specification of an epilogue via high-level tensor operations that one would like to compose with an operation.\n",
"\n",
"For those familiar with the legacy CUTLASS Python API's [epilogue visitor tree frontend](https://github.com/NVIDIA/cutlass/blob/a2439551c765c5393aebe557ee75d3a0412d2211/examples/python/deprecated/04_epilogue_visitor.ipynb), much of the interface is shared.\n",
"\n",
"The CUTLASS API enables one to express an epilogue using a function operating at the `torch.Tensor`-level, and has tooling to automatically add this to kernels supporting the provided function. \n",
"\n",
"For example, in PyTorch one might write the following to compute a GEMM + epilogue:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6d77d53",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(2025)\n",
"\n",
"L, M, N, K = 1, 1024, 1024, 1024\n",
"A = torch.randn(L, M, K, device=\"cuda\", dtype=torch.float16)\n",
"B = torch.randn(L, K, N, device=\"cuda\", dtype=torch.float16)\n",
"C = torch.randn(L, M, N, device=\"cuda\", dtype=torch.float16)\n",
"\n",
"def my_epilogue(accum, C, alpha, beta, extra_scalar):\n",
" Aux = (alpha * accum) + (beta * C)\n",
" D = extra_scalar * Aux\n",
" return D, Aux\n",
"\n",
"alpha, beta, extra_scalar = 1.0, 2.0, 0.5\n",
"D, Aux = my_epilogue(A @ B, C, alpha, beta, extra_scalar)\n"
]
},
{
"cell_type": "markdown",
"id": "66ee4dd1",
"metadata": {},
"source": [
"The CUTLASS API allows the same epilogue function `my_epilogue` to be used in GEMMs provided by the API.\n",
"\n",
"To do so, one defines `EpilogueArguments` consisting of the epilogue function to compute (or a string representation of it) along with arguments corresponding to each input and output of the function (except for `accum`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f079d9d6",
"metadata": {},
"outputs": [],
"source": [
"import cutlass_api\n",
"from cutlass_api.arguments import GemmArguments, EpilogueArguments\n",
"\n",
"# Allocate buffers for D and Aux\n",
"D_, Aux_ = [torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16) for _ in range(2)]\n",
"\n",
"epi_args = EpilogueArguments(my_epilogue, C=C, alpha=alpha, beta=beta, extra_scalar=extra_scalar, D=D_, Aux=Aux_)\n"
]
},
{
"cell_type": "markdown",
"id": "97ef8e8a",
"metadata": {},
"source": [
"These arguments can be added to `GemmArguments` and passed in to `get_kernels()` for use when retrieving compatible kernels:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60215c4e",
"metadata": {},
"outputs": [],
"source": [
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args)\n",
"assert len(kernels) > 0\n"
]
},
{
"cell_type": "markdown",
"id": "b0a7f9a2",
"metadata": {},
"source": [
"Each of the kernels returned by `get_kernels` can be compiled and executed just the same with these new arguments, as it was in examples without\n",
"epilogue fusion. For example, using the first kernel:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "150f3296",
"metadata": {},
"outputs": [],
"source": [
"kernels[0].run(args)\n",
"\n",
"torch.testing.assert_close(D, D_)\n",
"torch.testing.assert_close(Aux, Aux_)\n"
]
},
{
"cell_type": "markdown",
"id": "f1a826e3",
"metadata": {},
"source": [
"## How the epilogue fusion API works\n",
"To support specifying an epilogue via a Python function, a kernel needs some mechanism to:\n",
"1. Detect the operations in the epilogue function\n",
"2. Determine if the kernel can support the operations\n",
"3. Emit code to perform these operations within the kernel\n",
"\n",
"Step 1 listed above does not depend on the kernel and its implementation (e.g., DSL), while steps 2 and 3 depend on the kernel and/or its implementation.\n",
"\n",
"Thus, the CUTLASS API separates these components so that step 1 takes place at the API level and steps 2 and 3 take place in the kernel. This process is visualized below. We will walk through each step in greater detail.\n",
"\n",
"```python\n",
" +------------------------------------+\n",
" | def epi(accum, alpha, beta, C): |\n",
" | D = (accum * alpha) + (beta * C) | 1. Define epilogue via a Python function\n",
" | return D |\n",
" +------------------------------------+\n",
" |\n",
" |\n",
" |\n",
" GemmArguments(..., 2. Pass epilogue function, operands, and outputs\n",
" epilogue=EpilogueArguments( to EpilogueArguments constructor,\n",
" epi, alpha=alpha, beta=beta, C=C)) and add this to the GemmArguments. Under the\n",
" | hood, this parses the Python AST of the\n",
" | epilogue function to produce a DAG of load,\n",
" | store, and compute nodes.\n",
" V\n",
" +-----------------------------------------+ \n",
" | Intermediate DAG representation |\n",
" | =============================== |\n",
" | |\n",
" | Store() |\n",
" | | |\n",
" | Add() |\n",
" | / \\ |\n",
" | / \\ |\n",
" | / \\ |\n",
" | Mul() Mul() |\n",
" | / \\ / \\ |\n",
" | AccFetch() | Load(C) \\ |\n",
" | | \\ |\n",
" | Load(alpha) Load(beta) |\n",
" | |\n",
" +-----------------------------------------+\n",
" / | \\\n",
" / | \\ 3. Individual kernel classes use the DAG representation\n",
" / | \\ to determine if the kernel class supports the DAG.\n",
" Kernel 0 Kernel 1 Kernel 2 If so, the kernel class emits DSL-level operations\n",
" epilogue epilogue epilogue needed to compute the epilogue DAG alongside the\n",
" emitter emitter emitter basic operation of the kernel (e.g., GEMM).\n",
" | | |\n",
" | | |\n",
" V V V\n",
"```\n",
"\n",
"### Defining an epilogue via a Python function\n",
"Epilogue fusion patterns are defined by users in Python functions that perform Tensor-level operations -- using a `torch.Tensor` (for example) resulting from matrix multiplication, the function must be able to compute the desired results of the epilogue.\n",
"\n",
"The structure of these functions is as follows:\n",
"```python\n",
"def custom_epi_name(accum, *args) -> Union[TensorType, tuple[TensorType]]:\n",
" \"\"\"\n",
" :param accum: result of matrix multiplication, convolution, etc. before the epilogue\n",
" :type accum: TensorType\n",
" :param args: additional arguments to be used in the epilogue (e.g., aux tensors)\n",
" :type args: list[Union[TensorType, ScalarType]]\n",
"\n",
" :returns: at least one tensor resulting from the operation of the epilogue\n",
" :rtype: Union[TensorType, tuple[TensorType]]\n",
" \"\"\"\n",
" # Do some compute\n",
" return D # and potentially other values\n",
"```\n",
"\n",
"The user defines a custom epilogue via a Python function that **must** do at least the following:\n",
"1. Take in a first positional argument named `accum` that represents the result of operation just before the epilogue is to be performed. For example, in a GEMM, `accum = A @ B`.\n",
"2. Return at least one tensor that results from computing the epilogue. Currently, the return list must contain at least one output named `D`, though this constraint may be loosened in the future.\n",
"\n",
"Each additional argument following `accum` in the function definition is expected to be either a Tensor or scalar to be loaded. Each variable in the return statement represents a Tensor or scalar to be stored. The underlying implementation of the epilogue in the kernel will determine how operands are loaded and stored.\n",
"\n",
"Compute operations are represented in static single assignment (SSA) form.\n",
"This means that each variable can be assigned exactly once.\n",
"Operations currently supported ares:\n",
"* Tensor-tensor elementwise addition, subtraction, multiplication, and division\n",
"* Scalar broadcasts via addition, subtraction, multiplication, and division\n",
"* Predefined elementwise activation functions (e.g., ReLU, sigmoid, tanh)\n",
"\n",
"Operations that are not yet supported include:\n",
"* Row/column broadcasts (planned to be added soon)\n",
"* Reductions (planned to be added soon)\n",
"* Binary minimum and maximum functions (planned to be added soon)\n",
"If attempting to use these operations will result in no kernels being found in the call to `get_kernels`.\n",
"\n",
"Violations to SSA or use of unexpected operators will be flagged with an exception when parsing the AST of the custom epilogue.\n",
"\n",
"Examples of epilogues fitting these patterns are given below. We will show full, runnable examples at the end of this notebook.\n",
"```python\n",
"def relu_aux_store(accum, alpha, C):\n",
" # Note that the function definition itself does not indicate the types and\n",
" # ranks of alpha and C. Thus, one cannot tell whether the epilogue is performing\n",
" # broadcasts or elementwise operations until actual arguments or metadata are\n",
" # provided to the epilogue. See below for details.\n",
" F = (accum * alpha) + (C * 2.0) # Constant beta of 2.0\n",
" D = relu(F)\n",
" return D, F\n",
"\n",
"def aux_normalize(accum, aux):\n",
" D = accum / aux\n",
" return D\n",
"```\n",
"\n",
"Additional information about each operand and output must be provided by the user when constructing `EpilogueArguments`, as we will discuss below. This additional information is necessary for fully defining the operations being performed -- without knowledge of whether `alpha` is a scalar or a Tensor, we cannot determine whether multiplication by `alpha` is a broadcasted or elementwise operation.\n",
"\n",
"### Constructing epilogue arguments\n",
"`EpilogueArguments` encapsulate the arguments needed to determine the functional operation of a fused epilogue.\n",
"\n",
"A user must provide in the construction of `EpilogueArguments` tensors for all operands and outputs of the epilogue. However, unlike arguments for basic operations (e.g., GEMM), the full set of operands needed to be specified for an epilogue pattern depends upon the custom epilogue defined by the user.\n",
"\n",
"Therefore, `EpilogueArguments` is defined generically as taking in an `epilogue_fn` and additional `kwargs`. Under the hood, the AST for `epilogue_fn` is parsed to determine the operands and outputs of the epilogue. The user is required to provide in `kwargs` Tensors or scalars for all operands and outputs in the provided epilogue.\n",
"\n",
"For example, with an epilogue of:\n",
"```python\n",
"def my_epi(accum, alpha, C, beta):\n",
" F = (accum * alpha) + (C * beta)\n",
" D = relu(F)\n",
" return D, F\n",
"```\n",
"A user would need to construct epilogue arguments as follows:\n",
"```python\n",
"epi_args = EpilogueArguments(my_epi, alpha=..., C=..., beta=..., D=..., F=...)\n",
"```\n",
"\n",
"After verifying that all required operands and outputs are present, the constructor to `EpilogueArguments` will perform additional passes on the AST of `epilogue_fn` using the provided inputs to generate an internal DAG representing the epilogue. This DAG structure is attached to `EpilogueArguments` for use as they are passed through a call to `get_kernels`.\n",
"\n",
"### Discovering kernels that support the epilogue pattern\n",
"\n",
"The call to `get_kernels(args)` will return any kernels that support the provided `GemmArguments`.\n",
"Since the `GemmArguments` constructed above now include `EpilogueArguments`, returned kernels must support the provided epilogue.\n",
"\n",
"Under the hood of `get_kernels()`, each `Kernel` class will determine in its `generate_kernels()` method whether it supports the provided `EpilogueArguments`.\n",
"It can do so by traversing the DAG that resulted from the construction of `EpilogueArguments` to find the operations that compose the epilogue.\n",
"Assuming that the `Kernel` can support the DAG, it must then add to the source for the kernel any operations needed to support the DAG.\n",
"An example of how this is done generically for an SM100 CuTe DSL GEMM is provided in `sm100_static_persistent_efc.py`.\n",
"\n",
"## Example epilogues\n",
"We now provide various examples of adding custom epilogues to GEMM kernels targeting SM100. A broader set of epilogue examples are available in `test_gemm_epilogue_fusion.py`.\n",
"\n",
"### Auxiliary input and output tensors"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "171ac178",
"metadata": {},
"outputs": [],
"source": [
"from cutlass_api.fusion.activation import relu\n",
"\n",
"def relu_aux_store(accum, alpha, C):\n",
" F = (accum * alpha) + (C * 2.0) # Constant beta\n",
" D = relu(F)\n",
" return D, F\n",
"\n",
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"alpha = 3.0\n",
"D = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"F = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"\n",
"epi_args = EpilogueArguments(relu_aux_store, alpha=alpha, C=C, D=D, F=F)\n",
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
"assert len(kernels) > 0\n",
"kernels[0].run(args)\n",
"\n",
"D_ref, F_ref = relu_aux_store(A @ B, alpha, C)\n",
"\n",
"torch.testing.assert_close(D, D_ref)\n",
"torch.testing.assert_close(F, F_ref)\n"
]
},
{
"cell_type": "markdown",
"id": "f947b403",
"metadata": {},
"source": [
"### Keyword functions and returning accumulator"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62c2b49b",
"metadata": {},
"outputs": [],
"source": [
"def relu_scale_return_acc(accum, alpha, beta, C, scale):\n",
" F = relu((accum * alpha) + (C * beta))\n",
" D = F * scale\n",
" return D, F, accum\n",
"\n",
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"alpha = 1.0\n",
"beta = 2.0\n",
"scale = 0.5\n",
"D = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"F = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"accum = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float32)\n",
"\n",
"epi_args = EpilogueArguments(relu_scale_return_acc, alpha=alpha, beta=beta, C=C, scale=scale, D=D, F=F, accum=accum)\n",
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
"assert len(kernels) > 0\n",
"kernels[0].run(args)\n",
"\n",
"D_ref, F_ref, accum_ref = relu_scale_return_acc(A @ B, alpha, beta, C, scale)\n",
"\n",
"torch.testing.assert_close(D, D_ref)\n",
"torch.testing.assert_close(F, F_ref)\n",
"torch.testing.assert_close(accum, accum_ref.to(accum.dtype))\n"
]
},
{
"cell_type": "markdown",
"id": "c641911f",
"metadata": {},
"source": [
"### Passing a string representation of the function\n",
"`EpilogueArguments` can additionally be constructed using a string representation of the epilogue function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5987bf44",
"metadata": {},
"outputs": [],
"source": [
"epi_str = \"def epi(accum, alpha, beta, C): F = (accum * alpha) + (C * beta); D = relu(F); return D, F\"\n",
"\n",
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"alpha = 1.0\n",
"beta = 0.5\n",
"D = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"F = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"\n",
"epi_args = EpilogueArguments(epi_str, alpha=alpha, beta=beta, C=C, D=D, F=F)\n",
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
"assert len(kernels) > 0\n",
"kernels[0].run(args)\n",
"\n",
"F_ref = (A @ B) * alpha + (C * beta)\n",
"D_ref = torch.relu(F_ref)\n",
"\n",
"torch.testing.assert_close(D, D_ref)\n",
"torch.testing.assert_close(F, F_ref)\n"
]
},
{
"cell_type": "markdown",
"id": "e26a58a2",
"metadata": {},
"source": [
"### Failure examples\n",
"The following are examples of constructing `EpilogueArguments` that are expected to fail."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e3d0c89",
"metadata": {},
"outputs": [],
"source": [
"####################################################\n",
"# Epilogues must take in an accumulator\n",
"####################################################\n",
"def fail_missing_accum(alpha, beta, C):\n",
" D = (C * beta)\n",
" return D\n",
"\n",
"try:\n",
" epi_args = EpilogueArguments(fail_missing_accum, alpha=alpha, beta=beta, C=C, D=D)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"except Exception as e:\n",
" # \"accum must be an input to the epilogue function\"\n",
" print(e)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48a359f7",
"metadata": {},
"outputs": [],
"source": [
"####################################################\n",
"# Epilogues must return an output named D\n",
"####################################################\n",
"def fail_missing_D(accum, alpha, beta, C):\n",
" F = (accum * alpha) + (C * beta)\n",
" return F\n",
"\n",
"try:\n",
" epi_args = EpilogueArguments(fail_missing_D, alpha=alpha, beta=beta, C=C, F=F)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"except Exception as e:\n",
" # \"On SM90 or higher, D is expected to be a output node with 0 users to enable smem reuse between C and D, but got []\"\n",
" print(e)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49d9ee94",
"metadata": {},
"outputs": [],
"source": [
"####################################################\n",
"# Epilogues must use single-static assignment (SSA)\n",
"####################################################\n",
"def fail_ssa(accum):\n",
" tmp = accum * 2.0\n",
" # Redefine tmp, which violates SSA form.\n",
" tmp = tmp - 1.0\n",
" D = tmp / 4.0\n",
" return D, tmp\n",
"\n",
"try:\n",
" epi_args = EpilogueArguments(fail_ssa, D=D, tmp=F)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"except Exception as e:\n",
" # \"Variable 'tmp' cannot be defined twice.\"\n",
" print(e)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "871bb727",
"metadata": {},
"outputs": [],
"source": [
"####################################################\n",
"# Must provide all operands and outputs to\n",
"# EpilogueArguments\n",
"####################################################\n",
"def my_epi(accum, alpha, beta, C):\n",
" F = (accum * alpha) + (C * beta)\n",
" D = relu(F)\n",
" return D\n",
"\n",
"try:\n",
" # Missing D\n",
" epi_args = EpilogueArguments(my_epi, alpha=alpha, beta=beta, C=C)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"except Exception as e:\n",
" # \"Argument D is not provided in the kwargs of the EpilogueArguments constructor\"\n",
" print(e)\n",
"\n",
"try:\n",
" # Missing alpha\n",
" epi_args = EpilogueArguments(my_epi, beta=beta, C=C, D=D)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"except Exception as e:\n",
" # \"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\"\n",
" print(e)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,548 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "578f2730",
"metadata": {},
"source": [
"# Adding a kernel to the CUTLASS API\n",
"The CUTLASS API is designed to make it easy for users to add their own kernel\n",
"so that it can be discovered and run under the uniform API. We welcome contributions\n",
"toward the API by \"bringing your own kernel.\"\n",
"\n",
"This example shows how to add a CuTe DSL kernel to the CUTLASS API.\n",
"\n",
"## Bring your own implementation\n",
"Individuals wishing to add a CuTe DSL kernel to the CUTLASS API likely already\n",
"have the kernel written in CuTe DSL, but have not yet implemented the API's needed\n",
"interface. Within the API, we separate these components into the \"implementation\" --\n",
"the kernel written in CuTe DSL -- and the \"interface\" -- the definition of methods\n",
"a kernel needs to be used within the CUTLASS API.\n",
"\n",
"For example, consider the following implementation of a simple FP64 GEMM kernel implementation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a64b0be",
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"\n",
"import cuda.bindings.driver as cuda\n",
"\n",
"import cutlass\n",
"import cutlass.cute as cute\n",
"\n",
"\n",
"class F64GemmKernelImplementation:\n",
" def __init__(self, cta_tile_shape_mn: tuple[int, int]):\n",
" self.cta_tile_shape_mn = cta_tile_shape_mn\n",
"\n",
" @cute.jit\n",
" def __call__(\n",
" self, a: cute.Tensor, b: cute.Tensor, out: cute.Tensor, stream: cuda.CUstream\n",
" ):\n",
" l, m, n = out.shape\n",
" m_tiles = (m + self.cta_tile_shape_mn[0] - 1) // self.cta_tile_shape_mn[0]\n",
" n_tiles = (n + self.cta_tile_shape_mn[1] - 1) // self.cta_tile_shape_mn[1]\n",
"\n",
" grid = (m_tiles, n_tiles, l)\n",
" block = [self.cta_tile_shape_mn[0], self.cta_tile_shape_mn[1], 1]\n",
" self.kernel(a, b, out).launch(grid=grid, block=block, stream=stream)\n",
"\n",
" @cute.kernel\n",
" def kernel(self, a: cute.Tensor, b: cute.Tensor, out: cute.Tensor):\n",
" l, m, n = out.shape\n",
" k = a.shape[-1]\n",
" m_tile, n_tile, l_idx = cute.arch.block_idx()\n",
" tidx, tidy, _ = cute.arch.thread_idx()\n",
"\n",
" m_idx = m_tile * self.cta_tile_shape_mn[0] + tidx\n",
" n_idx = n_tile * self.cta_tile_shape_mn[1] + tidy\n",
"\n",
" if m_idx < m and n_idx < n:\n",
" out[l_idx, m_idx, n_idx] = cutlass.Float64(0)\n",
" for k_idx in range(k):\n",
" out[l_idx, m_idx, n_idx] += (\n",
" a[l_idx, m_idx, k_idx] * b[l_idx, k_idx, n_idx]\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "36a08d4b",
"metadata": {},
"source": [
"The implementation is configurable via a `cta_tile_shape_mn` argument, which\n",
"controls the size of blocks and tiles in the M and N modes. A simple `cute.jit` function\n",
"computes the grid and block size for the input problem based on `cta_tile_shape_mn`,\n",
"and launches the kernel. The `cute.kernel` itself simply has each thread compute a single\n",
"output element of the matrix by taking a dot product.\n",
"\n",
"This implementation is not performant, but is kept simple for illustrative purposes."
]
},
{
"cell_type": "markdown",
"id": "a5d0e661",
"metadata": {},
"source": [
"## Defining interface methods\n",
"As it currently stands, this GEMM kernel implementation cannot be used via the\n",
"CUTLASS API because it does not implement interface methods. Specifically, kernels\n",
"within the CUTLASS API must inherit from and implement the `cutlass_api.Kernel`\n",
"abstract class. This class has methods needed for many common operations\n",
"performed when compiling and executing DSL kernels.\n",
"\n",
"Certain providers (i.e., DSLs), such as CuTe DSL, provide an additional layer atop the\n",
"`cutlass_api.Kernel` class to add utilities for kernels being written\n",
"via that provider. For example, the CuTe DSL provider in the CUTLASS API\n",
"defines `cutlass_api.providers.cutedsl.kernel.CuteDslKernel`, which adds utilities surrounding\n",
"`cute.compile()` to add compile-time arguments needed for using TVM-FFI when\n",
"it is enabled.\n",
"\n",
"We will next walk through the steps in defining interface methods for this\n",
"implementation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2da869",
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"\n",
"import cutlass_api\n",
"from cutlass_api.arguments import GemmArguments\n",
"from cutlass_api.metadata import KernelMetadata\n",
"from cutlass_api.status import Status"
]
},
{
"cell_type": "markdown",
"id": "86ae75cc",
"metadata": {},
"source": [
"We begin by defining a class to represent the kernel's interface.\n",
"As mentioned above, since this is a CuTe DSL kernel, our interface must\n",
"inherit from and implement `cutlass_api.providers.cutedsl.kernel.CuteDslKernel`.\n",
"\n",
"The class must additionally be registered with the CuTe DSL provider\n",
"via the `@CuTeDSLProvider.register` decorator so that the class\n",
"can be considered when discovering kernels."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a86d138",
"metadata": {},
"outputs": [],
"source": [
"@cutlass_api.providers.cutedsl.CuTeDSLProvider.register\n",
"class F64GemmKernel(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):\n",
" # Empty versions of interface methods. These will be implemented later, interspersed\n",
" # with notebook markdown. Normally, one would define them inline with the class definition.\n",
" def __init__(self, metadata: KernelMetadata): pass\n",
"\n",
" def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None): pass\n",
"\n",
" def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact: pass\n",
"\n",
" @staticmethod\n",
" def generate_kernels(metadata_filter, epilogue_args=None, cc=None) -> list[\"F64GemmKernel\"]: pass\n",
"\n",
" def _supports(self, args: GemmArguments) -> Status: pass\n",
"\n",
" def get_workspace_size(self, args: GemmArguments) -> int: pass"
]
},
{
"cell_type": "markdown",
"id": "327e9e7c",
"metadata": {},
"source": [
"The `__init__` method of the class takes in a `KernelMetadata` object\n",
"from which it extracts the `cta_tile_shape_mn`. This is used to construct\n",
"the kernel implementation object. We will discuss later how the `KernelMetadata`\n",
"object passed in here is constructed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "785d1882",
"metadata": {},
"outputs": [],
"source": [
"def __init__(self, metadata: KernelMetadata):\n",
" self.metadata = metadata\n",
" cta_tile_shape_mn = metadata.design.tile_shape[:2]\n",
" self.impl = F64GemmKernelImplementation(cta_tile_shape_mn)"
]
},
{
"cell_type": "markdown",
"id": "500a0030",
"metadata": {},
"source": [
"### Defining interfaces for compilation and execution\n",
"The interfaces needed for compilation and execution are simple.\n",
"\n",
"The `compile` method simply constructs a placeholder stream object\n",
"and passes that and relevant arguments to `self.cute_compile`. This\n",
"is a utility defined in the `CuteDSLKernel` abstract class that\n",
"passes in compilation flags needed for certain options to `cute.compile`\n",
"(e.g., TVM-FFI). The result is wrapped as a `CompiledArtifact`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63b4a129",
"metadata": {},
"outputs": [],
"source": [
"def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact:\n",
" stream = cutlass.cute.runtime.make_fake_stream()\n",
" compiled_gemm = self.cute_compile(self.impl, args.A, args.B, args.out, stream)\n",
" return cutlass_api.artifact.CompiledArtifact(compiled_gemm, self)"
]
},
{
"cell_type": "markdown",
"id": "023127fd",
"metadata": {},
"source": [
"Users define the `_run` method rather than the top-level `run` method\n",
"(no leading underscore) that is used in interacting with kernels. `_run` (1) extracts from `args`\n",
"the arguments needed to run the JIT function, and (2) calls the JIT function\n",
"passed in via `artifact` with these arguments."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ae7c009",
"metadata": {},
"outputs": [],
"source": [
"def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None):\n",
" stream = cutlass_api.utils.to_cuda_stream(stream)\n",
" compiled_gemm = artifact.compiled_obj\n",
" self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)"
]
},
{
"cell_type": "markdown",
"id": "4052e5a0",
"metadata": {},
"source": [
"Finally, since this kernel does not require any device workspace,\n",
"we give it a simple `get_workspace_size` method that always returns 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "968906ea",
"metadata": {},
"outputs": [],
"source": [
"def get_workspace_size(self, args: GemmArguments) -> int:\n",
" return 0"
]
},
{
"cell_type": "markdown",
"id": "e245a319",
"metadata": {},
"source": [
"### Defining interfaces for kernel generation\n",
"We have implemented the interfaces needed for constructing the kernel\n",
"interface, compiling it, and running it. We now must implement methods for\n",
"generating the possible configurations of this kernel that the kernel\n",
"class itself supports. This will be used in kernel discovery (e.g., via\n",
"`cutlass_api.get_kernels()`).\n",
"\n",
"To do so, we write the `generate_kernels` method. This takes in a\n",
"binary function `metadata_filter`, epilogue arguments `epilogue_args`,\n",
"and a compute capability `cc`. It returns a list of all instances\n",
"of the kernel interface that support the `epilogue_args`, are compatible\n",
"with the given `cc`, and which pass the `metadata_filter`.\n",
"\n",
"The `Kernel` class is responsible for defining what valid possible configurations (instances) of it can exist.\n",
"In this example, the valid configurations involve a cross-product of row/column-major strides and two preset tile shapes.\n",
"We create a nested loop over these knobs and create a `KernelMetadata` corresponding to each unique configuration.\n",
"\n",
"The `generate_kernels` method must additionally filter the generated kernels by passing it through a `metadata_filter`.\n",
"This is a user-provided custom filter to filter generated metadata combinations. More information on `metadata_filter` is provided in other examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47dc2f20",
"metadata": {},
"outputs": [],
"source": [
"@staticmethod\n",
"def generate_kernels(\n",
" metadata_filter: Callable[[KernelMetadata], bool],\n",
" epilogue_args: cutlass_api.arguments.EpilogueArguments = None,\n",
" cc: int = None,\n",
") -> list[\"F64GemmKernel\"]:\n",
"\n",
" # The tile shapes this kernel supports/exposes\n",
" supported_tile_shapes = [(32, 32, 1), (16, 16, 1)]\n",
"\n",
" if epilogue_args is not None:\n",
" return []\n",
"\n",
" row_major_stride = (0, 0, 1)\n",
" col_major_stride = (0, 1, 0)\n",
" stride_combos = list(itertools.product([row_major_stride, col_major_stride], repeat=3))\n",
" alignment = 1\n",
"\n",
" def stride_name(stride): \n",
" return \"T\" if stride == row_major_stride else \"N\"\n",
"\n",
" kernels = []\n",
" for tile_shape in supported_tile_shapes:\n",
" design_metadata = cutlass_api.metadata.BLASDesignMetadata(tile_shape, (1, 1, 1))\n",
" for stride_A, stride_B, stride_out in stride_combos:\n",
" # Create TensorAttributes for A, B, and out tensors\n",
" a_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_A, alignment)\n",
" b_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_B, alignment)\n",
" out_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_out, alignment)\n",
" layout_str = cutlass_api.utils.strides_to_layout_string(stride_A, stride_B, stride_out)\n",
"\n",
" name = f\"F64GemmKernel_tile{tile_shape[0]}x{tile_shape[1]}_{layout_str}\"\n",
"\n",
" metadata = KernelMetadata(\n",
" kernel_name=name,\n",
" kernel_class=F64GemmKernel,\n",
" operands=cutlass_api.metadata.GemmOperandsMetadata(\n",
" a_attrs, b_attrs, out_attrs, accumulator_type=cutlass.Float64\n",
" ),\n",
" design=design_metadata,\n",
" min_cc=0,\n",
" )\n",
"\n",
" if metadata_filter(metadata):\n",
" kernels.append(F64GemmKernel(metadata))\n",
"\n",
" return kernels"
]
},
{
"cell_type": "markdown",
"id": "c7cdbc66",
"metadata": {},
"source": [
"We also add a method for indicating whether a kernel instance in question\n",
"supports a set of arguments. The top-level `Kernel.supports` method will\n",
"already verify that the `args` passed in match the metadata with which\n",
"this `Kernel` instance was constructed. Here, we define additional\n",
"checks specific to this kernel, such as that the kernel expects\n",
"all operands to be of rank 3:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54067d47",
"metadata": {},
"outputs": [],
"source": [
"def _supports(self, args: GemmArguments) -> Status:\n",
" if not (\n",
" len(args.A.shape) == 3 and # A should be (L, M, K)\n",
" len(args.B.shape) == 3 and # B should be (L, K, N)\n",
" len(args.out.shape) == 3 # out should be (L, M, N)\n",
" ):\n",
" return Status.fail(\"All operands must be rank 3.\")\n",
" return Status.success()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "edaf2cba",
"metadata": {},
"outputs": [],
"source": [
"# Assign methods to the class because we interspersed notebook markdown\n",
"# with the class definition. This is not needed in a real implementation.\n",
"F64GemmKernel.__init__ = __init__\n",
"F64GemmKernel.compile = compile\n",
"F64GemmKernel._run = _run\n",
"F64GemmKernel._supports = _supports\n",
"F64GemmKernel.generate_kernels = generate_kernels\n",
"F64GemmKernel.get_workspace_size = get_workspace_size"
]
},
{
"cell_type": "markdown",
"id": "c8fc84e9",
"metadata": {},
"source": [
"## Discovering instances of the kernel and using them\n",
"The CUTLASS API is now prepared to discover instances of this\n",
"kernel interface just as was done in previous examples.\n",
"\n",
"We add a small modification of using a `metadata_filter`\n",
"to ensure that all returned kernels are instances of the\n",
"`F64GemmKernel` class we just implemented. This is needed\n",
"only for example/testing purposes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cec5431d",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(2025)\n",
"\n",
"L, M, N, K = 1, 256, 1024, 128\n",
"A = torch.randn(L, M, K, device=\"cuda\", dtype=torch.float64)\n",
"B = torch.randn(L, K, N, device=\"cuda\", dtype=torch.float64)\n",
"out = torch.empty(L, M, N, device=\"cuda\", dtype=torch.float64)\n",
"\n",
"args = GemmArguments(A, B, out, accumulator_type=torch.float64)\n",
"\n",
"def is_f64gemm_kernel(metadata):\n",
" return metadata.kernel_class == F64GemmKernel\n",
"\n",
"kernels = cutlass_api.get_kernels(args, metadata_filter=is_f64gemm_kernel)"
]
},
{
"cell_type": "markdown",
"id": "50e81a7d",
"metadata": {},
"source": [
"We can print off the names of the first few kernels to see that\n",
"they come from our recently-added kernel."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cdb92b5e",
"metadata": {},
"outputs": [],
"source": [
"print(kernels[0].metadata.kernel_name)\n",
"print(kernels[1].metadata.kernel_name)"
]
},
{
"cell_type": "markdown",
"id": "697ee3c3",
"metadata": {},
"source": [
"We can evaluate and test the correctness of an instance of our kernel:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5486244",
"metadata": {},
"outputs": [],
"source": [
"kernels[0].run(args)\n",
"torch.testing.assert_close(out, A @ B)"
]
},
{
"cell_type": "markdown",
"id": "8de96f7e",
"metadata": {},
"source": [
"We can also test the limits of our kernel's design space by providing a\n",
"metadata filter that expects a CTA tile size M of 256, which is not exposed\n",
"in the `generate_kernels` method of our recently-added kernel. We expect\n",
"no kernels of type `F64GemmKernel` to be returned."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "917c74e3",
"metadata": {},
"outputs": [],
"source": [
"def my_filter(metadata):\n",
" return (\n",
" is_f64gemm_kernel(metadata) and\n",
" isinstance(metadata.design, cutlass_api.metadata.BLASDesignMetadata) and\n",
" metadata.design.tile_shape[0] == 256\n",
" )\n",
"kernels_ctam256 = cutlass_api.get_kernels(args, metadata_filter=my_filter)\n",
"\n",
"# No kernels should be found\n",
"assert len(kernels_ctam256) == 0"
]
},
{
"cell_type": "markdown",
"id": "caa80a7d",
"metadata": {},
"source": [
"## A note on contributing kernels to directory structure\n",
"This example showed how to define a kernel inline and add it to the\n",
"API for example purposes. This kernel doesn't necessarily need to live\n",
"within the API's source code.\n",
"\n",
"We welcome contributions of kernels that do live within the CUTLASS\n",
"API's repository as well.\n",
"\n",
"Kernels in the repository are organized based on the \"provider\" in which they are\n",
"authored (i.e., the DSL). All kernels corresponding to a given\n",
"provider live a directory corresponding to that provider under\n",
"`cutlass_api/providers`. For example, CuTe DSL kernels live\n",
"under `cutlass_api/providers/cutedsl`.\n",
"\n",
"Each provider can organize kernels differently. For CuTe DSL,\n",
"kernels are further split based on their logical operation,\n",
"with GEMM kernels under the `cutlass_api/providers/cutedsl/gemm`\n",
"directory.\n",
"\n",
"We recommend separating the implementation of the kernel from\n",
"its interface not just by using separate classes, as done in\n",
"this example, but also by separating the implementation and\n",
"interface into separate files. This makes it easier to update\n",
"each without affecting the other.\n",
"\n",
"For example, CuTe DSL GEMM kernels have the following organization:\n",
"```text\n",
"cutlass_api/\n",
" providers/\n",
" cutedsl/\n",
" gemm/\n",
" sm100_static_persistent.py\n",
" implementations/\n",
" sm100_static_persistent_impl.py\n",
"```"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,521 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "f97e61c9",
"metadata": {},
"source": [
"# Best practices for reducing host-side latency"
]
},
{
"cell_type": "markdown",
"id": "a7a9c63c",
"metadata": {},
"source": [
"Overall performance depends on both device performance (i.e., that of the kernel) and host performance (i.e., that of the runtime).\n",
"This notebook focuses on the latter: techniques to minimize any overheads incurred from the CUTLASS API and underlying\n",
"DSL runtimes.\n",
"\n",
"This notebook does not discuss techniques for improving device-side performance. A future notebook may cover this topic."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e3ca9e40",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import torch\n",
"import cutlass_api"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "efaac09c",
"metadata": {},
"outputs": [],
"source": [
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
" print(f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\")\n",
" import sys\n",
"\n",
" sys.exit(0)"
]
},
{
"cell_type": "markdown",
"id": "40de11ce",
"metadata": {},
"source": [
"We start with boilerplate initial setup to create tensors and pick a kernel.\n",
"\n",
"For the purposes of this notebook, we use a very small GEMM size of M=N=K=128\n",
"and L=1. This small size is chosen to magnify the impact of host latency on\n",
"end-to-end performance so as to better illustrate the effect of the techniques\n",
"described below."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b8c44947",
"metadata": {},
"outputs": [],
"source": [
"warmup_iterations = 10\n",
"profiling_iterations = 100\n",
"total_iterations = warmup_iterations + profiling_iterations\n",
"\n",
"# Use a small problem size to showcase host overheads\n",
"L, M, N, K = 1, 128, 128, 128\n",
"\n",
"# We use different operands in each iteration. Though not particularly relevant for\n",
"# host latency, this is a best practice when benchmarking GPU kernels to avoid\n",
"# unrealistic caching effects.\n",
"As = [torch.randint(-1, 2, (M, K), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
"Bs = [torch.randint(-1, 2, (K, N), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
"outs = [torch.empty((M, N), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
"\n",
"# Construct arguments outside of the benchmarking loop. We will later also consider\n",
"# cases in which they are constructed inside the benchmarking loop.\n",
"args = [cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16) for i in range(total_iterations)]\n",
"\n",
"references = [(As[i] @ Bs[i]).to(outs[i].dtype) for i in range(total_iterations)]\n",
"\n",
"kernels = cutlass_api.get_kernels(args[0], cc=100)\n",
"\n",
"assert len(kernels) > 0\n",
"kernel = kernels[0]"
]
},
{
"cell_type": "markdown",
"id": "f2e7eece",
"metadata": {},
"source": [
"We next set up a basic benchmarking routine."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2472eafa",
"metadata": {},
"outputs": [],
"source": [
"def benchmark(label, code, warmup_it=warmup_iterations, profiling_it=profiling_iterations):\n",
" total_it = warmup_it + profiling_it\n",
" assert total_it <= total_iterations, f\"Benchmark-local iteration count must be less than or equal to total iterations: {total_it} > {total_iterations}\"\n",
" # warmup\n",
" rets = [None] * total_it\n",
" for i in range(warmup_it):\n",
" rets[i] = code(i)\n",
" torch.cuda.synchronize()\n",
"\n",
" start = time.time()\n",
" for i in range(profiling_it):\n",
" idx = warmup_it + i\n",
" rets[idx] = code(idx)\n",
" torch.cuda.synchronize()\n",
" end = time.time()\n",
"\n",
" avg_time = (end - start) / profiling_it\n",
" print(f\"[{label:<30}] avg of {profiling_it} iterations: {avg_time:1.3e} seconds\")\n",
" return avg_time, rets"
]
},
{
"cell_type": "markdown",
"id": "4909a76b",
"metadata": {},
"source": [
"We now describe techniques for reducing host latency:\n",
"* Compile once, run many times\n",
"* Bypassing checks for argument-kernel compatibility\n",
"* Using [CUDA Graphs](https://developer.nvidia.com/blog/cuda-graphs/)\n",
"* Using [TVM FFI](https://tvm.apache.org/ffi/)\n",
"\n",
"These techniques are complementary and should be used together when applicable\n",
"for an application."
]
},
{
"cell_type": "markdown",
"id": "06495033",
"metadata": {},
"source": [
"### Compile once, run many times\n",
"The `kernel.run` method takes in an optional `compiled_artifact` argument of type\n",
"`cutlass_api.artifact.CompiledArtifact`. When this argument is set, the kernel\n",
"will directly use the precompiled function within `compiled_argument`. When\n",
"it is not set, the call to `kernel.run` will JIT compile the kernel on each\n",
"invocation.\n",
"\n",
"Precompiling the kernel is critical to achieving good performance."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6de11f56",
"metadata": {},
"outputs": [],
"source": [
"stream = torch.cuda.current_stream()\n",
"def no_compiled_artifact(i: int):\n",
" return kernel.run(args[i], stream=stream)\n",
"\n",
"# Compile the kernel once, reuse for each iterations\n",
"compiled_artifact = kernel.compile(args[0])\n",
"\n",
"def with_compiled_artifact(i: int):\n",
" return kernel.run(args[i], stream=stream, compiled_artifact=compiled_artifact)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "350c9bd6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Without compiled artifact ] avg of 5 iterations: 1.376e+00 seconds\n",
"[With compiled artifact ] avg of 5 iterations: 1.016e-05 seconds\n"
]
}
],
"source": [
"time_no_artifact, _ = benchmark(f\"Without compiled artifact\", no_compiled_artifact, warmup_it=2, profiling_it=5)\n",
"time_w_artifact, _ = benchmark(f\"With compiled artifact\", with_compiled_artifact, warmup_it=2, profiling_it=5)"
]
},
{
"cell_type": "markdown",
"id": "5cfbc2d2",
"metadata": {},
"source": [
"### Bypassing checks for argument-kernel compatibility\n",
"By default, the call to `kernel.run` will check if the kernel supports the provided arguments.\n",
"Under the hood, this invokes `kernel.supports(args)`.\n",
"\n",
"While these checks are helpful for catching incompatible arguments, they are performed\n",
"in Python, and thus can add to host overhead.\n",
"\n",
"When confident that arguments will be compatible with a kernel, one should bypass\n",
"the `supports` check in `kernel.run` by setting the optional `assume_supported_args`\n",
"argument to `True`."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5b93dfae",
"metadata": {},
"outputs": [],
"source": [
"def with_supports_check(i: int):\n",
" return kernel.run(args[i], compiled_artifact=compiled_artifact, stream=stream, assume_supported_args=False)\n",
"\n",
"def without_supports_check(i: int):\n",
" return kernel.run(args[i], compiled_artifact=compiled_artifact, stream=stream, assume_supported_args=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b282f437",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[With supports check ] avg of 100 iterations: 1.463e-05 seconds\n",
"[Bypass supports check ] avg of 100 iterations: 6.239e-06 seconds\n",
"Speedup with skip supports: 2.34x\n"
]
}
],
"source": [
"time_w_supports, _ = benchmark(\"With supports check\", with_supports_check)\n",
"time_wo_supports, _ = benchmark(\"Bypass supports check\", without_supports_check)\n",
"print(f\"Speedup with skip supports: {time_w_supports / time_wo_supports:.2f}x\")"
]
},
{
"cell_type": "markdown",
"id": "d74cb3e7",
"metadata": {},
"source": [
"### CUDA Graphs"
]
},
{
"cell_type": "markdown",
"id": "656d5e2c",
"metadata": {},
"source": [
"CUTLASS API supports [CUDA Graphs](https://developer.nvidia.com/blog/cuda-graphs/) usage with PyTorch as usual.\n",
"\n",
"The kernel compilation must happen outside the CUDA graph. Then, we create a graph using usual PyTorch idioms to launch a kernel several times on the graph's stream."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e614509f",
"metadata": {},
"outputs": [],
"source": [
"num_launches = 20\n",
"\n",
"# Create a CUDA Graph to run our compiled kernel N times\n",
"g = torch.cuda.CUDAGraph()\n",
"with torch.cuda.graph(g):\n",
" # Run N iterations of our compiled kernel on the current stream\n",
" for i in range(num_launches):\n",
" kernel.run(\n",
" args[i],\n",
" compiled_artifact=compiled_artifact,\n",
" stream=torch.cuda.current_stream(),\n",
" assume_supported_args=True,\n",
" )\n",
"\n",
"# Zero the output so we don't refcheck stale results\n",
"_ = outs[0].zero_()"
]
},
{
"cell_type": "markdown",
"id": "8fc69c6e",
"metadata": {},
"source": [
"Once captured, we can replay the graph. This will only replay the kernel launches placed on the CUDA stream.\n",
"Any other prepratory work on the host and arguments passed in from python are cached during the capture."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d9c5d5c5",
"metadata": {},
"outputs": [],
"source": [
"# Replay captured graph and check first result\n",
"g.replay()\n",
"\n",
"torch.testing.assert_close(outs[0], references[0])"
]
},
{
"cell_type": "markdown",
"id": "388c8e02",
"metadata": {},
"source": [
"Let's compare the timing:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "45d4e739",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[20 launches without CUDA Graph] avg of 1 iterations: 4.699e-04 seconds\n",
"[20 launches with CUDA Graph ] avg of 1 iterations: 9.084e-05 seconds\n",
"Speedup with CUDA Graph: 5.17x\n"
]
}
],
"source": [
"def without_cuda_graph(x: int):\n",
" for i in range(num_launches):\n",
" kernel.run(\n",
" args[i],\n",
" compiled_artifact=compiled_artifact,\n",
" stream=torch.cuda.current_stream(),\n",
" assume_supported_args=True,\n",
" )\n",
"\n",
"def with_cuda_graph(x: int):\n",
" g.replay()\n",
"\n",
"\n",
"time_wo_cuda_graph, _ = benchmark(f\"{num_launches} launches without CUDA Graph\", without_cuda_graph, warmup_it=0, profiling_it=1)\n",
"time_w_cuda_graph, _ = benchmark(f\"{num_launches} launches with CUDA Graph\", with_cuda_graph, warmup_it=0, profiling_it=1)\n",
"\n",
"print(f\"Speedup with CUDA Graph: {time_wo_cuda_graph / time_w_cuda_graph:.2f}x\")"
]
},
{
"cell_type": "markdown",
"id": "fe5c3168",
"metadata": {},
"source": [
"### TVM FFI"
]
},
{
"cell_type": "markdown",
"id": "ee7f9fd2",
"metadata": {},
"source": [
"When applicable, CUTLASS API uses [Apache TVM FFI](https://tvm.apache.org/ffi/) under the hood for invoking compiled DSL kernels from Python.\n",
"Apache TVM FFI is an open ABI and FFI for machine learning systems.\n",
"\n",
"TVM FFI is enabled by default in CUTLASS API, and is recommended for best performance."
]
},
{
"cell_type": "markdown",
"id": "1690bbed",
"metadata": {},
"source": [
"`cutlass_api.config.GlobalOptions().use_tvm_ffi` controls whether or not TVM-FFI will be used by CUTLASS API."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "993c60ae",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"print(cutlass_api.config.GlobalOptions().use_tvm_ffi)"
]
},
{
"cell_type": "markdown",
"id": "00ed9a40",
"metadata": {},
"source": [
"If for some reason you do not wish to use it, this section demonstrates how, you can set this to False. No other change is needed. The below code compares the performance with and without TVM-FFI."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e8f56be3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[TVM-FFI ON ] Create args ] avg of 100 iterations: 8.367e-05 seconds\n",
"[[TVM-FFI ON ] Compile kernel ] avg of 5 iterations: 1.352e+00 seconds\n",
"[[TVM-FFI ON ] Run kernel ] avg of 100 iterations: 6.509e-06 seconds\n"
]
}
],
"source": [
"cutlass_api.config.GlobalOptions().use_tvm_ffi = True\n",
"\n",
"def run_iteration(i):\n",
" args = cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16)\n",
" return kernel.run(\n",
" args,\n",
" compiled_artifact=compiled_artifact,\n",
" stream=torch.cuda.current_stream(),\n",
" assume_supported_args=True,\n",
" )\n",
"\n",
"def create_arguments(i: int):\n",
" return cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16)\n",
"\n",
"args_creation_on, args = benchmark(\"[TVM-FFI ON ] Create args\", create_arguments)\n",
"compilation_on, compiled = benchmark(\"[TVM-FFI ON ] Compile kernel\", lambda i: kernel.compile(args[i]), warmup_it=2, profiling_it=5)\n",
"compiled_artifact = compiled[0]\n",
"run_on, _ = benchmark(\"[TVM-FFI ON ] Run kernel\", lambda i: kernel.run(args[i], compiled_artifact=compiled_artifact, assume_supported_args=True, stream=stream))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5a4c2db4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[TVM-FFI OFF ] Create args ] avg of 100 iterations: 1.255e-04 seconds\n",
"[[TVM-FFI OFF ] Compile kernel ] avg of 5 iterations: 1.278e+00 seconds\n",
"[[TVM-FFI OFF ] Run kernel ] avg of 100 iterations: 4.519e-05 seconds\n"
]
}
],
"source": [
"cutlass_api.config.GlobalOptions().use_tvm_ffi = False\n",
"args_creation_off, args = benchmark(\"[TVM-FFI OFF ] Create args\", create_arguments)\n",
"compilation_off, compiled = benchmark(\"[TVM-FFI OFF ] Compile kernel\", lambda i: kernel.compile(args[i]), warmup_it=2, profiling_it=5)\n",
"compiled_artifact = compiled[0]\n",
"run_off, _ = benchmark(\"[TVM-FFI OFF ] Run kernel\", lambda i: kernel.run(args[i], compiled_artifact=compiled_artifact, assume_supported_args=True, stream=stream))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "17b43718",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Speedups with TVM-FFI: \n",
"Arg creation: 1.50x\n",
"Compilation: 0.95x\n",
"Run: 6.94x\n"
]
}
],
"source": [
"print(\"Speedups with TVM-FFI: \")\n",
"print(f\"Arg creation: {args_creation_off / args_creation_on:.2f}x\")\n",
"print(f\"Compilation: {compilation_off / compilation_on:.2f}x\")\n",
"print(f\"Run: {run_off / run_on:.2f}x\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}