2026-01-06 updates

This commit is contained in:
jkosaian
2026-01-06 04:25:33 -08:00
parent dfcb55de16
commit 7c09485e25
77 changed files with 2563 additions and 444 deletions

View File

@@ -31,9 +31,9 @@
"\n",
"import cutlass_api\n",
"\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({80, 90, 100, 103})):\n",
" print(\n",
" f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\"\n",
" f\"This notebook requires a GPU with compute capability >= 80.\\n{status.error}\"\n",
" )\n",
" import sys\n",
"\n",
@@ -67,7 +67,7 @@
"source": [
"M, N, K, L = 128, 256, 64, 2\n",
"ab_type = torch.float16\n",
"out_type = torch.float32\n",
"out_type = torch.float16\n",
"acc_type = torch.float32\n",
"\n",
"A = torch.randint(-1, 2, (L, M, K), device=\"cuda\", dtype=ab_type)\n",
@@ -118,9 +118,9 @@
"metadata": {},
"outputs": [],
"source": [
"kernels = cutlass_api.get_kernels(args)\n",
"cc = cutlass_api.utils.device_cc()\n",
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
"assert kernels, \"No kernels found for the given arguments!\"\n",
"\n",
"kernel = kernels[0]"
]
},
@@ -148,28 +148,6 @@
"torch.testing.assert_close(out, reference)"
]
},
{
"cell_type": "markdown",
"id": "4d7ad85b",
"metadata": {},
"source": [
"One can also explicitly compile the kernel and pass this in to `kernel.run` to avoid\n",
"JIT compilation on future invocations. Additional details related to this will be\n",
"described below."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06f9f844",
"metadata": {},
"outputs": [],
"source": [
"artifact = kernel.compile(args)\n",
"kernel.run(args, compiled_artifact=artifact)\n",
"torch.testing.assert_close(out, reference)"
]
},
{
"cell_type": "markdown",
"id": "630e9e4b",
@@ -179,7 +157,7 @@
"\n",
"---\n",
"\n",
"### Understanding the core interfaces"
"## Understanding the core interfaces"
]
},
{
@@ -187,7 +165,7 @@
"id": "2d8b8e94",
"metadata": {},
"source": [
"#### 1. `RuntimeArguments` / `GemmArguments`\n",
"### 1. `RuntimeArguments` / `GemmArguments`\n",
"\n",
"`RuntimeArguments` describe the operation a user wants to perform, and all the runtime operands or other runtime parameters needed for it. \n",
"This includes primary runtime operands to the operation, as well as any custom epilogue fusions and runtime performance knobs.\n",
@@ -220,7 +198,7 @@
"id": "e7eda0dd",
"metadata": {},
"source": [
"#### 2. Kernel Discovery\n",
"### 2. Kernel Discovery\n",
"\n",
"There are several kernels available in CUTLASS DSLs that are registered with, and discoverable via, the CUTLASS API.\n",
"\n",
@@ -254,6 +232,11 @@
"kernels = cutlass_api.get_kernels(args)\n",
"print(f\"Of these, {len(kernels)} support the given arguments.\")\n",
"\n",
"# we can limit the search to kernels supporting given args + current device compute capability\n",
"cc = cutlass_api.utils.device_cc()\n",
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
"print(f\"Of these, {len(kernels)} support the given arguments.\")\n",
"\n",
"kernel = kernels[0]\n",
"print(f\"Picked kernel with name: {kernel.metadata.kernel_name}\")"
]
@@ -263,7 +246,7 @@
"id": "252a4d38",
"metadata": {},
"source": [
"#### 3. `Kernel` execution"
"### 3. `Kernel` execution"
]
},
{
@@ -286,7 +269,8 @@
"id": "e8945aa6",
"metadata": {},
"source": [
"* `kernel.supports(args)` checks if the kernel supports the given `args`\n",
"#### Verify that the kernel supports the given `args`\n",
"`kernel.supports(args)` checks if the kernel supports the given `args`\n",
" * this is relevant if the kernel was not picked just for these `args`"
]
},
@@ -338,7 +322,9 @@
"id": "c2db8f20",
"metadata": {},
"source": [
"* `kernel.compile(args)` compiles the kernel, and returns a `CompiledArtifact`\n",
"#### JIT compiling the kernel\n",
"\n",
"`kernel.compile(args)` compiles the kernel, and returns a `CompiledArtifact`\n",
"\n",
"This compiled artifact is a lightweight wrapper over the result of compiling a kernel (e.g., via `cute.compile()`).\n",
"\n",
@@ -361,7 +347,8 @@
"id": "4dfb8d51",
"metadata": {},
"source": [
"* `kernel.run(args)` launches the compiled kernel function. This example uses:\n",
"#### Launching the compiled kernel function\n",
"`kernel.run(args)` launches the compiled kernel function. The next example uses:\n",
" * the precompiled artifact\n",
" * a custom stream to launch to\n",
" * bypasses the supports check already performed above (`assume_supported_args=True`)."
@@ -386,6 +373,24 @@
"torch.testing.assert_close(out, reference)"
]
},
{
"cell_type": "markdown",
"id": "7956813d",
"metadata": {},
"source": [
"Passing in a precompiled kernel is critical to achieving good performance because it avoids\n",
"JIT compiling the kernel on each invocation. JIT compilation always occurs when a precompiled\n",
"kernel is not provided in the call to `kernel.run()`."
]
},
{
"cell_type": "markdown",
"id": "c228495a",
"metadata": {},
"source": [
"#### Workspace Buffers"
]
},
{
"cell_type": "markdown",
"id": "f67eeb8f",
@@ -416,7 +421,7 @@
"id": "baffaf12",
"metadata": {},
"source": [
"### Advanced: Filtering on Metadata"
"## Advanced: Filtering on Metadata"
]
},
{

View File

@@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "bb450878",
"metadata": {},
"outputs": [],
@@ -20,7 +20,9 @@
"import cutlass_api\n",
"\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
" print(f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\")\n",
" print(\n",
" f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\"\n",
" )\n",
" import sys\n",
"\n",
" sys.exit(0)"
@@ -42,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "e6d77d53",
"metadata": {},
"outputs": [],
@@ -56,13 +58,15 @@
"B = torch.randn(L, K, N, device=\"cuda\", dtype=torch.float16)\n",
"C = torch.randn(L, M, N, device=\"cuda\", dtype=torch.float16)\n",
"\n",
"\n",
"def my_epilogue(accum, C, alpha, beta, extra_scalar):\n",
" Aux = (alpha * accum) + (beta * C)\n",
" D = extra_scalar * Aux\n",
" return D, Aux\n",
"\n",
"\n",
"alpha, beta, extra_scalar = 1.0, 2.0, 0.5\n",
"D, Aux = my_epilogue(A @ B, C, alpha, beta, extra_scalar)\n"
"D, Aux = my_epilogue(A @ B, C, alpha, beta, extra_scalar)"
]
},
{
@@ -77,18 +81,28 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "f079d9d6",
"metadata": {},
"outputs": [],
"source": [
"import cutlass_api\n",
"from cutlass_api.arguments import GemmArguments, EpilogueArguments\n",
"from cutlass_api.arguments import EpilogueArguments, GemmArguments\n",
"\n",
"# Allocate buffers for D and Aux\n",
"D_, Aux_ = [torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16) for _ in range(2)]\n",
"D_, Aux_ = [\n",
" torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16) for _ in range(2)\n",
"]\n",
"\n",
"epi_args = EpilogueArguments(my_epilogue, C=C, alpha=alpha, beta=beta, extra_scalar=extra_scalar, D=D_, Aux=Aux_)\n"
"epi_args = EpilogueArguments(\n",
" my_epilogue,\n",
" C=C,\n",
" alpha=alpha,\n",
" beta=beta,\n",
" extra_scalar=extra_scalar,\n",
" D=D_,\n",
" Aux=Aux_,\n",
")"
]
},
{
@@ -101,14 +115,15 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "60215c4e",
"metadata": {},
"outputs": [],
"source": [
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args)\n",
"assert len(kernels) > 0\n"
"cc = cutlass_api.utils.device_cc()\n",
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
"assert len(kernels) > 0"
]
},
{
@@ -122,7 +137,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "150f3296",
"metadata": {},
"outputs": [],
@@ -130,7 +145,7 @@
"kernels[0].run(args)\n",
"\n",
"torch.testing.assert_close(D, D_)\n",
"torch.testing.assert_close(Aux, Aux_)\n"
"torch.testing.assert_close(Aux, Aux_)"
]
},
{
@@ -289,17 +304,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "171ac178",
"metadata": {},
"outputs": [],
"source": [
"from cutlass_api.fusion.activation import relu\n",
"\n",
"\n",
"def relu_aux_store(accum, alpha, C):\n",
" F = (accum * alpha) + (C * 2.0) # Constant beta\n",
" D = relu(F)\n",
" return D, F\n",
" F = (accum * alpha) + (C * 2.0) # Constant beta\n",
" D = relu(F)\n",
" return D, F\n",
"\n",
"\n",
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"alpha = 3.0\n",
@@ -308,14 +325,14 @@
"\n",
"epi_args = EpilogueArguments(relu_aux_store, alpha=alpha, C=C, D=D, F=F)\n",
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
"assert len(kernels) > 0\n",
"kernels[0].run(args)\n",
"\n",
"D_ref, F_ref = relu_aux_store(A @ B, alpha, C)\n",
"\n",
"torch.testing.assert_close(D, D_ref)\n",
"torch.testing.assert_close(F, F_ref)\n"
"torch.testing.assert_close(F, F_ref)"
]
},
{
@@ -328,15 +345,16 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"id": "62c2b49b",
"metadata": {},
"outputs": [],
"source": [
"def relu_scale_return_acc(accum, alpha, beta, C, scale):\n",
" F = relu((accum * alpha) + (C * beta))\n",
" D = F * scale\n",
" return D, F, accum\n",
" F = relu((accum * alpha) + (C * beta))\n",
" D = F * scale\n",
" return D, F, accum\n",
"\n",
"\n",
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"alpha = 1.0\n",
@@ -346,9 +364,18 @@
"F = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"accum = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float32)\n",
"\n",
"epi_args = EpilogueArguments(relu_scale_return_acc, alpha=alpha, beta=beta, C=C, scale=scale, D=D, F=F, accum=accum)\n",
"epi_args = EpilogueArguments(\n",
" relu_scale_return_acc,\n",
" alpha=alpha,\n",
" beta=beta,\n",
" C=C,\n",
" scale=scale,\n",
" D=D,\n",
" F=F,\n",
" accum=accum,\n",
")\n",
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
"assert len(kernels) > 0\n",
"kernels[0].run(args)\n",
"\n",
@@ -356,7 +383,7 @@
"\n",
"torch.testing.assert_close(D, D_ref)\n",
"torch.testing.assert_close(F, F_ref)\n",
"torch.testing.assert_close(accum, accum_ref.to(accum.dtype))\n"
"torch.testing.assert_close(accum, accum_ref.to(accum.dtype))"
]
},
{
@@ -370,7 +397,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"id": "5987bf44",
"metadata": {},
"outputs": [],
@@ -385,7 +412,7 @@
"\n",
"epi_args = EpilogueArguments(epi_str, alpha=alpha, beta=beta, C=C, D=D, F=F)\n",
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
"assert len(kernels) > 0\n",
"kernels[0].run(args)\n",
"\n",
@@ -393,7 +420,7 @@
"D_ref = torch.relu(F_ref)\n",
"\n",
"torch.testing.assert_close(D, D_ref)\n",
"torch.testing.assert_close(F, F_ref)\n"
"torch.testing.assert_close(F, F_ref)"
]
},
{
@@ -407,54 +434,90 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "1e3d0c89",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accum must be an input to the epilogue function\n"
]
}
],
"source": [
"####################################################\n",
"# Epilogues must take in an accumulator\n",
"####################################################\n",
"def fail_missing_accum(alpha, beta, C):\n",
" D = (C * beta)\n",
" return D\n",
" D = C * beta\n",
" return D\n",
"\n",
"\n",
"try:\n",
" epi_args = EpilogueArguments(fail_missing_accum, alpha=alpha, beta=beta, C=C, D=D)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
" epi_args = EpilogueArguments(\n",
" fail_missing_accum,\n",
" alpha=alpha,\n",
" beta=beta,\n",
" C=C,\n",
" D=D,\n",
" )\n",
" args = GemmArguments(\n",
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
" )\n",
"except Exception as e:\n",
" # \"accum must be an input to the epilogue function\"\n",
" print(e)\n"
" # \"accum must be an input to the epilogue function\"\n",
" print(e)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"id": "48a359f7",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output node D is not found in the epilogue function\n"
]
}
],
"source": [
"####################################################\n",
"# Epilogues must return an output named D\n",
"####################################################\n",
"def fail_missing_D(accum, alpha, beta, C):\n",
" F = (accum * alpha) + (C * beta)\n",
" return F\n",
" F = (accum * alpha) + (C * beta)\n",
" return F\n",
"\n",
"\n",
"try:\n",
" epi_args = EpilogueArguments(fail_missing_D, alpha=alpha, beta=beta, C=C, F=F)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
" epi_args = EpilogueArguments(fail_missing_D, alpha=alpha, beta=beta, C=C, F=F)\n",
" args = GemmArguments(\n",
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
" )\n",
"except Exception as e:\n",
" # \"On SM90 or higher, D is expected to be a output node with 0 users to enable smem reuse between C and D, but got []\"\n",
" print(e)\n"
" # \"Output node D is not found in the epilogue function\n",
" print(e)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"id": "49d9ee94",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Variable 'tmp' cannot be defined twice.\n"
]
}
],
"source": [
"####################################################\n",
"# Epilogues must use single-static assignment (SSA)\n",
@@ -466,51 +529,81 @@
" D = tmp / 4.0\n",
" return D, tmp\n",
"\n",
"\n",
"try:\n",
" epi_args = EpilogueArguments(fail_ssa, D=D, tmp=F)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
" epi_args = EpilogueArguments(fail_ssa, D=D, tmp=F)\n",
" args = GemmArguments(\n",
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
" )\n",
"except Exception as e:\n",
" # \"Variable 'tmp' cannot be defined twice.\"\n",
" print(e)\n"
" # \"Variable 'tmp' cannot be defined twice.\"\n",
" print(e)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"id": "871bb727",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Argument D is not provided in the kwargs of the EpilogueArguments constructor\n",
"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\n"
]
}
],
"source": [
"####################################################\n",
"# Must provide all operands and outputs to\n",
"# EpilogueArguments\n",
"####################################################\n",
"def my_epi(accum, alpha, beta, C):\n",
" F = (accum * alpha) + (C * beta)\n",
" D = relu(F)\n",
" return D\n",
" F = (accum * alpha) + (C * beta)\n",
" D = relu(F)\n",
" return D\n",
"\n",
"try:\n",
" # Missing D\n",
" epi_args = EpilogueArguments(my_epi, alpha=alpha, beta=beta, C=C)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
" # Missing D\n",
" epi_args = EpilogueArguments(my_epi, alpha=alpha, beta=beta, C=C)\n",
" args = GemmArguments(\n",
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
" )\n",
"except Exception as e:\n",
" # \"Argument D is not provided in the kwargs of the EpilogueArguments constructor\"\n",
" print(e)\n",
" # \"Argument D is not provided in the kwargs of the EpilogueArguments constructor\"\n",
" print(e)\n",
"\n",
"try:\n",
" # Missing alpha\n",
" epi_args = EpilogueArguments(my_epi, beta=beta, C=C, D=D)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
" # Missing alpha\n",
" epi_args = EpilogueArguments(my_epi, beta=beta, C=C, D=D)\n",
" args = GemmArguments(\n",
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
" )\n",
"except Exception as e:\n",
" # \"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\"\n",
" print(e)\n"
" # \"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\"\n",
" print(e)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,

View File

@@ -24,12 +24,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "5a64b0be",
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"from collections.abc import Callable\n",
"\n",
"import cuda.bindings.driver as cuda\n",
"\n",
@@ -110,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "1a2da869",
"metadata": {},
"outputs": [],
@@ -148,18 +148,34 @@
"class F64GemmKernel(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):\n",
" # Empty versions of interface methods. These will be implemented later, interspersed\n",
" # with notebook markdown. Normally, one would define them inline with the class definition.\n",
" def __init__(self, metadata: KernelMetadata): pass\n",
" def __init__(self, metadata: KernelMetadata):\n",
" pass\n",
"\n",
" def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None): pass\n",
" def _run(\n",
" self,\n",
" args: GemmArguments,\n",
" artifact: cutlass_api.artifact.CompiledArtifact,\n",
" stream,\n",
" workspace=None,\n",
" ):\n",
" pass\n",
"\n",
" def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact: pass\n",
" def compile(\n",
" self, args: GemmArguments, cc: int = None\n",
" ) -> cutlass_api.artifact.CompiledArtifact:\n",
" pass\n",
"\n",
" @staticmethod\n",
" def generate_kernels(metadata_filter, epilogue_args=None, cc=None) -> list[\"F64GemmKernel\"]: pass\n",
" def generate_kernels(\n",
" metadata_filter, epilogue_args=None, cc=None\n",
" ) -> list[\"F64GemmKernel\"]:\n",
" pass\n",
"\n",
" def _supports(self, args: GemmArguments) -> Status: pass\n",
" def _supports(self, args: GemmArguments) -> Status:\n",
" pass\n",
"\n",
" def get_workspace_size(self, args: GemmArguments) -> int: pass"
" def get_workspace_size(self, args: GemmArguments) -> int:\n",
" pass"
]
},
{
@@ -175,13 +191,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "785d1882",
"metadata": {},
"outputs": [],
"source": [
"def __init__(self, metadata: KernelMetadata):\n",
" self.metadata = metadata\n",
" # Using Python-2-style super() because we're defining this method outside of the class definition.\n",
" super(F64GemmKernel, self).__init__(metadata)\n",
" cta_tile_shape_mn = metadata.design.tile_shape[:2]\n",
" self.impl = F64GemmKernelImplementation(cta_tile_shape_mn)"
]
@@ -203,12 +220,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "63b4a129",
"metadata": {},
"outputs": [],
"source": [
"def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact:\n",
"def compile(\n",
" self, args: GemmArguments, cc: int = None\n",
") -> cutlass_api.artifact.CompiledArtifact:\n",
" stream = cutlass.cute.runtime.make_fake_stream()\n",
" compiled_gemm = self.cute_compile(self.impl, args.A, args.B, args.out, stream)\n",
" return cutlass_api.artifact.CompiledArtifact(compiled_gemm, self)"
@@ -227,12 +246,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "2ae7c009",
"metadata": {},
"outputs": [],
"source": [
"def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None):\n",
"def _run(\n",
" self,\n",
" args: GemmArguments,\n",
" artifact: cutlass_api.artifact.CompiledArtifact,\n",
" stream,\n",
" workspace=None,\n",
"):\n",
" stream = cutlass_api.utils.to_cuda_stream(stream)\n",
" compiled_gemm = artifact.compiled_obj\n",
" self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)"
@@ -249,7 +274,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"id": "968906ea",
"metadata": {},
"outputs": [],
@@ -286,7 +311,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"id": "47dc2f20",
"metadata": {},
"outputs": [],
@@ -297,7 +322,6 @@
" epilogue_args: cutlass_api.arguments.EpilogueArguments = None,\n",
" cc: int = None,\n",
") -> list[\"F64GemmKernel\"]:\n",
"\n",
" # The tile shapes this kernel supports/exposes\n",
" supported_tile_shapes = [(32, 32, 1), (16, 16, 1)]\n",
"\n",
@@ -306,10 +330,12 @@
"\n",
" row_major_stride = (0, 0, 1)\n",
" col_major_stride = (0, 1, 0)\n",
" stride_combos = list(itertools.product([row_major_stride, col_major_stride], repeat=3))\n",
" alignment = 1\n",
" stride_combos = list(\n",
" itertools.product([row_major_stride, col_major_stride], repeat=3)\n",
" )\n",
" divisibility = 1\n",
"\n",
" def stride_name(stride): \n",
" def stride_name(stride):\n",
" return \"T\" if stride == row_major_stride else \"N\"\n",
"\n",
" kernels = []\n",
@@ -317,10 +343,18 @@
" design_metadata = cutlass_api.metadata.BLASDesignMetadata(tile_shape, (1, 1, 1))\n",
" for stride_A, stride_B, stride_out in stride_combos:\n",
" # Create TensorAttributes for A, B, and out tensors\n",
" a_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_A, alignment)\n",
" b_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_B, alignment)\n",
" out_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_out, alignment)\n",
" layout_str = cutlass_api.utils.strides_to_layout_string(stride_A, stride_B, stride_out)\n",
" a_attrs = cutlass_api.metadata.TensorAttributes(\n",
" cutlass.Float64, stride_A, divisibility\n",
" )\n",
" b_attrs = cutlass_api.metadata.TensorAttributes(\n",
" cutlass.Float64, stride_B, divisibility\n",
" )\n",
" out_attrs = cutlass_api.metadata.TensorAttributes(\n",
" cutlass.Float64, stride_out, divisibility\n",
" )\n",
" layout_str = cutlass_api.utils.strides_to_layout_string(\n",
" stride_A, stride_B, stride_out\n",
" )\n",
"\n",
" name = f\"F64GemmKernel_tile{tile_shape[0]}x{tile_shape[1]}_{layout_str}\"\n",
"\n",
@@ -355,24 +389,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "54067d47",
"metadata": {},
"outputs": [],
"source": [
"def _supports(self, args: GemmArguments) -> Status:\n",
" if not (\n",
" len(args.A.shape) == 3 and # A should be (L, M, K)\n",
" len(args.B.shape) == 3 and # B should be (L, K, N)\n",
" len(args.out.shape) == 3 # out should be (L, M, N)\n",
" len(args.A.shape) == 3 # A should be (L, M, K)\n",
" and len(args.B.shape) == 3 # B should be (L, K, N)\n",
" and len(args.out.shape) == 3 # out should be (L, M, N)\n",
" ):\n",
" return Status.fail(\"All operands must be rank 3.\")\n",
" return Status.success()\n"
" return Status.success()"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"id": "edaf2cba",
"metadata": {},
"outputs": [],
@@ -404,7 +438,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"id": "cec5431d",
"metadata": {},
"outputs": [],
@@ -420,9 +454,11 @@
"\n",
"args = GemmArguments(A, B, out, accumulator_type=torch.float64)\n",
"\n",
"\n",
"def is_f64gemm_kernel(metadata):\n",
" return metadata.kernel_class == F64GemmKernel\n",
"\n",
"\n",
"kernels = cutlass_api.get_kernels(args, metadata_filter=is_f64gemm_kernel)"
]
},
@@ -437,10 +473,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"id": "cdb92b5e",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"F64GemmKernel_tile32x32_ttt\n",
"F64GemmKernel_tile16x16_ttt\n"
]
}
],
"source": [
"print(kernels[0].metadata.kernel_name)\n",
"print(kernels[1].metadata.kernel_name)"
@@ -456,7 +501,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"id": "f5486244",
"metadata": {},
"outputs": [],
@@ -478,17 +523,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"id": "917c74e3",
"metadata": {},
"outputs": [],
"source": [
"def my_filter(metadata):\n",
" return (\n",
" is_f64gemm_kernel(metadata) and\n",
" isinstance(metadata.design, cutlass_api.metadata.BLASDesignMetadata) and\n",
" metadata.design.tile_shape[0] == 256\n",
" is_f64gemm_kernel(metadata)\n",
" and isinstance(metadata.design, cutlass_api.metadata.BLASDesignMetadata)\n",
" and metadata.design.tile_shape[0] == 256\n",
" )\n",
"\n",
"\n",
"kernels_ctam256 = cutlass_api.get_kernels(args, metadata_filter=my_filter)\n",
"\n",
"# No kernels should be found\n",
@@ -539,8 +586,22 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,

View File

@@ -17,12 +17,15 @@
"This notebook focuses on the latter: techniques to minimize any overheads incurred from the CUTLASS API and underlying\n",
"DSL runtimes.\n",
"\n",
"This notebook does not discuss techniques for improving device-side performance. A future notebook may cover this topic."
"This notebook does not discuss techniques for improving device-side performance. A future notebook may cover this topic.\n",
"\n",
"**Note**: Latency measurements can vary from system to system. You may see different results on your system than shown\n",
"in the pre-populated fields of this notebook."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "e3ca9e40",
"metadata": {},
"outputs": [],
@@ -34,13 +37,15 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "efaac09c",
"metadata": {},
"outputs": [],
"source": [
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
" print(f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\")\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({80, 89, 90, 100, 103})):\n",
" print(\n",
" f\"This notebook requires a GPU with compute capability >= 80.\\n{status.error}\"\n",
" )\n",
" import sys\n",
"\n",
" sys.exit(0)"
@@ -61,7 +66,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "b8c44947",
"metadata": {},
"outputs": [],
@@ -76,19 +81,34 @@
"# We use different operands in each iteration. Though not particularly relevant for\n",
"# host latency, this is a best practice when benchmarking GPU kernels to avoid\n",
"# unrealistic caching effects.\n",
"As = [torch.randint(-1, 2, (M, K), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
"Bs = [torch.randint(-1, 2, (K, N), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
"outs = [torch.empty((M, N), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
"As = [\n",
" torch.randint(-1, 2, (M, K), device=\"cuda\", dtype=torch.float16)\n",
" for _ in range(total_iterations)\n",
"]\n",
"Bs = [\n",
" torch.randint(-1, 2, (K, N), device=\"cuda\", dtype=torch.float16)\n",
" for _ in range(total_iterations)\n",
"]\n",
"outs = [\n",
" torch.empty((M, N), device=\"cuda\", dtype=torch.float16)\n",
" for _ in range(total_iterations)\n",
"]\n",
"\n",
"# Construct arguments outside of the benchmarking loop. We will later also consider\n",
"# cases in which they are constructed inside the benchmarking loop.\n",
"args = [cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16) for i in range(total_iterations)]\n",
"args = [\n",
" cutlass_api.arguments.GemmArguments(\n",
" A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float32\n",
" )\n",
" for i in range(total_iterations)\n",
"]\n",
"\n",
"references = [(As[i] @ Bs[i]).to(outs[i].dtype) for i in range(total_iterations)]\n",
"\n",
"kernels = cutlass_api.get_kernels(args[0], cc=100)\n",
"\n",
"cc = cutlass_api.utils.device_cc()\n",
"kernels = cutlass_api.get_kernels(args[0], cc=cc)\n",
"assert len(kernels) > 0\n",
"\n",
"kernel = kernels[0]"
]
},
@@ -102,14 +122,18 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "2472eafa",
"metadata": {},
"outputs": [],
"source": [
"def benchmark(label, code, warmup_it=warmup_iterations, profiling_it=profiling_iterations):\n",
"def benchmark(\n",
" label, code, warmup_it=warmup_iterations, profiling_it=profiling_iterations\n",
"):\n",
" total_it = warmup_it + profiling_it\n",
" assert total_it <= total_iterations, f\"Benchmark-local iteration count must be less than or equal to total iterations: {total_it} > {total_iterations}\"\n",
" assert total_it <= total_iterations, (\n",
" f\"Benchmark-local iteration count must be less than or equal to total iterations: {total_it} > {total_iterations}\"\n",
" )\n",
" # warmup\n",
" rets = [None] * total_it\n",
" for i in range(warmup_it):\n",
@@ -151,7 +175,7 @@
"### Compile once, run many times\n",
"The `kernel.run` method takes in an optional `compiled_artifact` argument of type\n",
"`cutlass_api.artifact.CompiledArtifact`. When this argument is set, the kernel\n",
"will directly use the precompiled function within `compiled_argument`. When\n",
"will directly use the precompiled function within `compiled_artifact`. When\n",
"it is not set, the call to `kernel.run` will JIT compile the kernel on each\n",
"invocation.\n",
"\n",
@@ -160,25 +184,29 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "6de11f56",
"metadata": {},
"outputs": [],
"source": [
"stream = torch.cuda.current_stream()\n",
"\n",
"\n",
"def no_compiled_artifact(i: int):\n",
" return kernel.run(args[i], stream=stream)\n",
"\n",
"\n",
"# Compile the kernel once, reuse for each iterations\n",
"compiled_artifact = kernel.compile(args[0])\n",
"\n",
"\n",
"def with_compiled_artifact(i: int):\n",
" return kernel.run(args[i], stream=stream, compiled_artifact=compiled_artifact)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "350c9bd6",
"metadata": {},
"outputs": [
@@ -192,8 +220,12 @@
}
],
"source": [
"time_no_artifact, _ = benchmark(f\"Without compiled artifact\", no_compiled_artifact, warmup_it=2, profiling_it=5)\n",
"time_w_artifact, _ = benchmark(f\"With compiled artifact\", with_compiled_artifact, warmup_it=2, profiling_it=5)"
"time_no_artifact, _ = benchmark(\n",
" f\"Without compiled artifact\", no_compiled_artifact, warmup_it=2, profiling_it=5\n",
")\n",
"time_w_artifact, _ = benchmark(\n",
" f\"With compiled artifact\", with_compiled_artifact, warmup_it=2, profiling_it=5\n",
")"
]
},
{
@@ -215,21 +247,32 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "5b93dfae",
"metadata": {},
"outputs": [],
"source": [
"def with_supports_check(i: int):\n",
" return kernel.run(args[i], compiled_artifact=compiled_artifact, stream=stream, assume_supported_args=False)\n",
" return kernel.run(\n",
" args[i],\n",
" compiled_artifact=compiled_artifact,\n",
" stream=stream,\n",
" assume_supported_args=False,\n",
" )\n",
"\n",
"\n",
"def without_supports_check(i: int):\n",
" return kernel.run(args[i], compiled_artifact=compiled_artifact, stream=stream, assume_supported_args=True)"
" return kernel.run(\n",
" args[i],\n",
" compiled_artifact=compiled_artifact,\n",
" stream=stream,\n",
" assume_supported_args=True,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "b282f437",
"metadata": {},
"outputs": [
@@ -244,7 +287,7 @@
}
],
"source": [
"time_w_supports, _ = benchmark(\"With supports check\", with_supports_check)\n",
"time_w_supports, _ = benchmark(\"With supports check\", with_supports_check)\n",
"time_wo_supports, _ = benchmark(\"Bypass supports check\", without_supports_check)\n",
"print(f\"Speedup with skip supports: {time_w_supports / time_wo_supports:.2f}x\")"
]
@@ -262,14 +305,16 @@
"id": "656d5e2c",
"metadata": {},
"source": [
"CUTLASS API supports [CUDA Graphs](https://developer.nvidia.com/blog/cuda-graphs/) usage with PyTorch as usual.\n",
"[CUDA Graphs](https://developer.nvidia.com/blog/cuda-graphs/) allow a sequence of GPU operations to be defined as a dependency graph and then launched as a single unit, significantly reducing CPU launch overhead and enabling whole-graph optimizations.\n",
"\n",
"CUTLASS API supports CUDA Graphs usage with PyTorch as usual.\n",
"\n",
"The kernel compilation must happen outside the CUDA graph. Then, we create a graph using usual PyTorch idioms to launch a kernel several times on the graph's stream."
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "e614509f",
"metadata": {},
"outputs": [],
@@ -279,6 +324,10 @@
"# Create a CUDA Graph to run our compiled kernel N times\n",
"g = torch.cuda.CUDAGraph()\n",
"with torch.cuda.graph(g):\n",
"\n",
" ### NOTE! Kernel compilation must happen outside the graph\n",
" ### kernel.compile(args)\n",
"\n",
" # Run N iterations of our compiled kernel on the current stream\n",
" for i in range(num_launches):\n",
" kernel.run(\n",
@@ -286,10 +335,7 @@
" compiled_artifact=compiled_artifact,\n",
" stream=torch.cuda.current_stream(),\n",
" assume_supported_args=True,\n",
" )\n",
"\n",
"# Zero the output so we don't refcheck stale results\n",
"_ = outs[0].zero_()"
" )"
]
},
{
@@ -297,8 +343,12 @@
"id": "8fc69c6e",
"metadata": {},
"source": [
"Once captured, we can replay the graph. This will only replay the kernel launches placed on the CUDA stream.\n",
"Any other prepratory work on the host and arguments passed in from python are cached during the capture."
"This records/captures all the kernel launches to the CUDA Stream associated with the graph `g`, without actually launching them.\n",
"Once captured, we can replay the graph.\n",
"\n",
"Note that graph replay will only replay the kernel launches placed on the graph's stream\n",
"* During graph capture, we must be careful to capture to the correct stream (`torch.cuda.current_stream()` under the graph context)\n",
"* Any other preparatory work on the host and arguments passed in from Python are cached during the capture. Changing them would require re-capturing the graph"
]
},
{
@@ -324,7 +374,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"id": "45d4e739",
"metadata": {},
"outputs": [
@@ -348,12 +398,23 @@
" assume_supported_args=True,\n",
" )\n",
"\n",
"\n",
"def with_cuda_graph(x: int):\n",
" g.replay()\n",
"\n",
"\n",
"time_wo_cuda_graph, _ = benchmark(f\"{num_launches} launches without CUDA Graph\", without_cuda_graph, warmup_it=0, profiling_it=1)\n",
"time_w_cuda_graph, _ = benchmark(f\"{num_launches} launches with CUDA Graph\", with_cuda_graph, warmup_it=0, profiling_it=1)\n",
"time_wo_cuda_graph, _ = benchmark(\n",
" f\"{num_launches} launches without CUDA Graph\",\n",
" without_cuda_graph,\n",
" warmup_it=0,\n",
" profiling_it=1,\n",
")\n",
"time_w_cuda_graph, _ = benchmark(\n",
" f\"{num_launches} launches with CUDA Graph\",\n",
" with_cuda_graph,\n",
" warmup_it=0,\n",
" profiling_it=1,\n",
")\n",
"\n",
"print(f\"Speedup with CUDA Graph: {time_wo_cuda_graph / time_w_cuda_graph:.2f}x\")"
]
@@ -371,8 +432,8 @@
"id": "ee7f9fd2",
"metadata": {},
"source": [
"When applicable, CUTLASS API uses [Apache TVM FFI](https://tvm.apache.org/ffi/) under the hood for invoking compiled DSL kernels from Python.\n",
"Apache TVM FFI is an open ABI and FFI for machine learning systems.\n",
"[Apache TVM FFI](https://tvm.apache.org/ffi/) is an open ABI and FFI for machine learning systems.\n",
"When available, CUTLASS API uses Apache TVM-FFI under the hood as its interface for invoking compiled DSL kernels from Python.\n",
"\n",
"TVM FFI is enabled by default in CUTLASS API, and is recommended for best performance."
]
@@ -413,7 +474,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"id": "e8f56be3",
"metadata": {},
"outputs": [
@@ -428,10 +489,15 @@
}
],
"source": [
"original_use_tvm_ffi = cutlass_api.config.GlobalOptions().use_tvm_ffi\n",
"\n",
"cutlass_api.config.GlobalOptions().use_tvm_ffi = True\n",
"\n",
"\n",
"def run_iteration(i):\n",
" args = cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16)\n",
" args = cutlass_api.arguments.GemmArguments(\n",
" A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16\n",
" )\n",
" return kernel.run(\n",
" args,\n",
" compiled_artifact=compiled_artifact,\n",
@@ -439,18 +505,35 @@
" assume_supported_args=True,\n",
" )\n",
"\n",
"\n",
"def create_arguments(i: int):\n",
" return cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16)\n",
" return cutlass_api.arguments.GemmArguments(\n",
" A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16\n",
" )\n",
"\n",
"\n",
"args_creation_on, args = benchmark(\"[TVM-FFI ON ] Create args\", create_arguments)\n",
"compilation_on, compiled = benchmark(\"[TVM-FFI ON ] Compile kernel\", lambda i: kernel.compile(args[i]), warmup_it=2, profiling_it=5)\n",
"compilation_on, compiled = benchmark(\n",
" \"[TVM-FFI ON ] Compile kernel\",\n",
" lambda i: kernel.compile(args[i]),\n",
" warmup_it=2,\n",
" profiling_it=5,\n",
")\n",
"compiled_artifact = compiled[0]\n",
"run_on, _ = benchmark(\"[TVM-FFI ON ] Run kernel\", lambda i: kernel.run(args[i], compiled_artifact=compiled_artifact, assume_supported_args=True, stream=stream))"
"run_on, _ = benchmark(\n",
" \"[TVM-FFI ON ] Run kernel\",\n",
" lambda i: kernel.run(\n",
" args[i],\n",
" compiled_artifact=compiled_artifact,\n",
" assume_supported_args=True,\n",
" stream=stream,\n",
" ),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"id": "5a4c2db4",
"metadata": {},
"outputs": [
@@ -467,9 +550,25 @@
"source": [
"cutlass_api.config.GlobalOptions().use_tvm_ffi = False\n",
"args_creation_off, args = benchmark(\"[TVM-FFI OFF ] Create args\", create_arguments)\n",
"compilation_off, compiled = benchmark(\"[TVM-FFI OFF ] Compile kernel\", lambda i: kernel.compile(args[i]), warmup_it=2, profiling_it=5)\n",
"compilation_off, compiled = benchmark(\n",
" \"[TVM-FFI OFF ] Compile kernel\",\n",
" lambda i: kernel.compile(args[i]),\n",
" warmup_it=2,\n",
" profiling_it=5,\n",
")\n",
"compiled_artifact = compiled[0]\n",
"run_off, _ = benchmark(\"[TVM-FFI OFF ] Run kernel\", lambda i: kernel.run(args[i], compiled_artifact=compiled_artifact, assume_supported_args=True, stream=stream))"
"run_off, _ = benchmark(\n",
" \"[TVM-FFI OFF ] Run kernel\",\n",
" lambda i: kernel.run(\n",
" args[i],\n",
" compiled_artifact=compiled_artifact,\n",
" assume_supported_args=True,\n",
" stream=stream,\n",
" ),\n",
")\n",
"\n",
"# Restore original setting\n",
"cutlass_api.config.GlobalOptions().use_tvm_ffi = original_use_tvm_ffi"
]
},
{

View File

@@ -0,0 +1,166 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "4620d513",
"metadata": {},
"source": [
"# Using fake tensors with the CUTLASS API\n",
"Fake tensors (e.g., [torch's FakeTensor](https://docs.pytorch.org/docs/2.8/torch.compiler_fake_tensor.html))\n",
"are useful for describing the properties of a tensor without actually allocating backing data.\n",
"\n",
"This example shows how fake tensors can be used within the CUTLASS API\n",
"for discovering and compiling a GEMM kernel."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d231b32e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"import cutlass_api\n",
"\n",
"torch.manual_seed(2025)\n",
"\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({80, 89, 90, 100, 103})):\n",
" print(f\"This notebook requires a GPU with compute capability >= 80.\\n{status.error}\")\n",
" import sys\n",
" sys.exit(0)"
]
},
{
"cell_type": "markdown",
"id": "f7af2d90",
"metadata": {},
"source": [
"We first set up operands `A`, `B`, and `out` in torch's `FakeTensorMode`.\n",
"These will have all the properties needed for CUTLASS API to construct\n",
"the internal representations of tensors used for discovering and compiling\n",
"kernels."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9426b66f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FakeTensor(..., device='cuda:0', size=(128, 512), dtype=torch.float16)\n",
"FakeTensor(..., device='cuda:0', size=(512, 256), dtype=torch.float16)\n",
"FakeTensor(..., device='cuda:0', size=(128, 256), dtype=torch.float16)\n"
]
}
],
"source": [
"M, N, K = 128, 256, 512\n",
"\n",
"with torch._subclasses.fake_tensor.FakeTensorMode():\n",
" A_fake = torch.randn(M, K, device=\"cuda\", dtype=torch.float16)\n",
" B_fake = torch.randn(K, N, device=\"cuda\", dtype=torch.float16)\n",
" out_fake = torch.empty(M, N, device=\"cuda\", dtype=torch.float16)\n",
"\n",
"print(A_fake)\n",
"print(B_fake)\n",
"print(out_fake)"
]
},
{
"cell_type": "markdown",
"id": "4f540c78",
"metadata": {},
"source": [
"We can now use these fake tensors to create `GemmArguments`, and use\n",
"these to discover and compile a compatible kernel. Note that the same APIs are\n",
"used in creating `GemmArguments` as would be used if using\n",
"\"real\" tensors."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e32b700d",
"metadata": {},
"outputs": [],
"source": [
"args_fake = cutlass_api.arguments.GemmArguments(\n",
" A_fake, B_fake, out_fake, accumulator_type=torch.float32)\n",
"\n",
"cc = cutlass_api.utils.device_cc()\n",
"kernels = cutlass_api.get_kernels(args_fake, cc=cc)\n",
"assert len(kernels) > 0\n",
"\n",
"kernel = kernels[0]\n",
"compiled_artifact = kernel.compile(args_fake)"
]
},
{
"cell_type": "markdown",
"id": "07fff511",
"metadata": {},
"source": [
"The `kernel` and `compiled_artifact` discovered using fake tensors\n",
"above can now used for running the kernel using real tensors."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b3034bf1",
"metadata": {},
"outputs": [],
"source": [
"# Create real tensors\n",
"A_real = torch.randn(M, K, device=\"cuda\", dtype=torch.float16)\n",
"B_real = torch.randn(K, N, device=\"cuda\", dtype=torch.float16)\n",
"out_real = torch.empty(M, N, device=\"cuda\", dtype=torch.float16)\n",
"\n",
"args_real = cutlass_api.arguments.GemmArguments(\n",
" A_real, B_real, out_real, accumulator_type=torch.float32)\n",
"\n",
"# Run the kernel using the compiled_artifact from resulting\n",
"# from compiling with fake tensors.\n",
"kernel.run(args_real, compiled_artifact)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "09871eca",
"metadata": {},
"outputs": [],
"source": [
"ref = A_real @ B_real\n",
"torch.testing.assert_close(out_real, ref)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}