mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-17 19:38:58 +00:00
2026-01-06 updates
This commit is contained in:
@@ -31,9 +31,9 @@
|
||||
"\n",
|
||||
"import cutlass_api\n",
|
||||
"\n",
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({80, 90, 100, 103})):\n",
|
||||
" print(\n",
|
||||
" f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\"\n",
|
||||
" f\"This notebook requires a GPU with compute capability >= 80.\\n{status.error}\"\n",
|
||||
" )\n",
|
||||
" import sys\n",
|
||||
"\n",
|
||||
@@ -67,7 +67,7 @@
|
||||
"source": [
|
||||
"M, N, K, L = 128, 256, 64, 2\n",
|
||||
"ab_type = torch.float16\n",
|
||||
"out_type = torch.float32\n",
|
||||
"out_type = torch.float16\n",
|
||||
"acc_type = torch.float32\n",
|
||||
"\n",
|
||||
"A = torch.randint(-1, 2, (L, M, K), device=\"cuda\", dtype=ab_type)\n",
|
||||
@@ -118,9 +118,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"kernels = cutlass_api.get_kernels(args)\n",
|
||||
"cc = cutlass_api.utils.device_cc()\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
|
||||
"assert kernels, \"No kernels found for the given arguments!\"\n",
|
||||
"\n",
|
||||
"kernel = kernels[0]"
|
||||
]
|
||||
},
|
||||
@@ -148,28 +148,6 @@
|
||||
"torch.testing.assert_close(out, reference)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4d7ad85b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"One can also explicitly compile the kernel and pass this in to `kernel.run` to avoid\n",
|
||||
"JIT compilation on future invocations. Additional details related to this will be\n",
|
||||
"described below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "06f9f844",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"artifact = kernel.compile(args)\n",
|
||||
"kernel.run(args, compiled_artifact=artifact)\n",
|
||||
"torch.testing.assert_close(out, reference)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "630e9e4b",
|
||||
@@ -179,7 +157,7 @@
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"### Understanding the core interfaces"
|
||||
"## Understanding the core interfaces"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -187,7 +165,7 @@
|
||||
"id": "2d8b8e94",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### 1. `RuntimeArguments` / `GemmArguments`\n",
|
||||
"### 1. `RuntimeArguments` / `GemmArguments`\n",
|
||||
"\n",
|
||||
"`RuntimeArguments` describe the operation a user wants to perform, and all the runtime operands or other runtime parameters needed for it. \n",
|
||||
"This includes primary runtime operands to the operation, as well as any custom epilogue fusions and runtime performance knobs.\n",
|
||||
@@ -220,7 +198,7 @@
|
||||
"id": "e7eda0dd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### 2. Kernel Discovery\n",
|
||||
"### 2. Kernel Discovery\n",
|
||||
"\n",
|
||||
"There are several kernels available in CUTLASS DSLs that are registered with, and discoverable via, the CUTLASS API.\n",
|
||||
"\n",
|
||||
@@ -254,6 +232,11 @@
|
||||
"kernels = cutlass_api.get_kernels(args)\n",
|
||||
"print(f\"Of these, {len(kernels)} support the given arguments.\")\n",
|
||||
"\n",
|
||||
"# we can limit the search to kernels supporting given args + current device compute capability\n",
|
||||
"cc = cutlass_api.utils.device_cc()\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
|
||||
"print(f\"Of these, {len(kernels)} support the given arguments.\")\n",
|
||||
"\n",
|
||||
"kernel = kernels[0]\n",
|
||||
"print(f\"Picked kernel with name: {kernel.metadata.kernel_name}\")"
|
||||
]
|
||||
@@ -263,7 +246,7 @@
|
||||
"id": "252a4d38",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### 3. `Kernel` execution"
|
||||
"### 3. `Kernel` execution"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -286,7 +269,8 @@
|
||||
"id": "e8945aa6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* `kernel.supports(args)` checks if the kernel supports the given `args`\n",
|
||||
"#### Verify that the kernel supports the given `args`\n",
|
||||
"`kernel.supports(args)` checks if the kernel supports the given `args`\n",
|
||||
" * this is relevant if the kernel was not picked just for these `args`"
|
||||
]
|
||||
},
|
||||
@@ -338,7 +322,9 @@
|
||||
"id": "c2db8f20",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* `kernel.compile(args)` compiles the kernel, and returns a `CompiledArtifact`\n",
|
||||
"#### JIT compiling the kernel\n",
|
||||
"\n",
|
||||
"`kernel.compile(args)` compiles the kernel, and returns a `CompiledArtifact`\n",
|
||||
"\n",
|
||||
"This compiled artifact is a lightweight wrapper over the result of compiling a kernel (e.g., via `cute.compile()`).\n",
|
||||
"\n",
|
||||
@@ -361,7 +347,8 @@
|
||||
"id": "4dfb8d51",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* `kernel.run(args)` launches the compiled kernel function. This example uses:\n",
|
||||
"#### Launching the compiled kernel function\n",
|
||||
"`kernel.run(args)` launches the compiled kernel function. The next example uses:\n",
|
||||
" * the precompiled artifact\n",
|
||||
" * a custom stream to launch to\n",
|
||||
" * bypasses the supports check already performed above (`assume_supported_args=True`)."
|
||||
@@ -386,6 +373,24 @@
|
||||
"torch.testing.assert_close(out, reference)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7956813d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Passing in a precompiled kernel is critical to achieving good performance because it avoids\n",
|
||||
"JIT compiling the kernel on each invocation. JIT compilation always occurs when a precompiled\n",
|
||||
"kernel is not provided in the call to `kernel.run()`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c228495a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Workspace Buffers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f67eeb8f",
|
||||
@@ -416,7 +421,7 @@
|
||||
"id": "baffaf12",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Advanced: Filtering on Metadata"
|
||||
"## Advanced: Filtering on Metadata"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"id": "bb450878",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -20,7 +20,9 @@
|
||||
"import cutlass_api\n",
|
||||
"\n",
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
|
||||
" print(f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\")\n",
|
||||
" print(\n",
|
||||
" f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\"\n",
|
||||
" )\n",
|
||||
" import sys\n",
|
||||
"\n",
|
||||
" sys.exit(0)"
|
||||
@@ -42,7 +44,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"id": "e6d77d53",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -56,13 +58,15 @@
|
||||
"B = torch.randn(L, K, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"C = torch.randn(L, M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def my_epilogue(accum, C, alpha, beta, extra_scalar):\n",
|
||||
" Aux = (alpha * accum) + (beta * C)\n",
|
||||
" D = extra_scalar * Aux\n",
|
||||
" return D, Aux\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"alpha, beta, extra_scalar = 1.0, 2.0, 0.5\n",
|
||||
"D, Aux = my_epilogue(A @ B, C, alpha, beta, extra_scalar)\n"
|
||||
"D, Aux = my_epilogue(A @ B, C, alpha, beta, extra_scalar)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -77,18 +81,28 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"id": "f079d9d6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass_api\n",
|
||||
"from cutlass_api.arguments import GemmArguments, EpilogueArguments\n",
|
||||
"from cutlass_api.arguments import EpilogueArguments, GemmArguments\n",
|
||||
"\n",
|
||||
"# Allocate buffers for D and Aux\n",
|
||||
"D_, Aux_ = [torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16) for _ in range(2)]\n",
|
||||
"D_, Aux_ = [\n",
|
||||
" torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16) for _ in range(2)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"epi_args = EpilogueArguments(my_epilogue, C=C, alpha=alpha, beta=beta, extra_scalar=extra_scalar, D=D_, Aux=Aux_)\n"
|
||||
"epi_args = EpilogueArguments(\n",
|
||||
" my_epilogue,\n",
|
||||
" C=C,\n",
|
||||
" alpha=alpha,\n",
|
||||
" beta=beta,\n",
|
||||
" extra_scalar=extra_scalar,\n",
|
||||
" D=D_,\n",
|
||||
" Aux=Aux_,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -101,14 +115,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"id": "60215c4e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"kernels = cutlass_api.get_kernels(args)\n",
|
||||
"assert len(kernels) > 0\n"
|
||||
"cc = cutlass_api.utils.device_cc()\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
|
||||
"assert len(kernels) > 0"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -122,7 +137,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"id": "150f3296",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -130,7 +145,7 @@
|
||||
"kernels[0].run(args)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(D, D_)\n",
|
||||
"torch.testing.assert_close(Aux, Aux_)\n"
|
||||
"torch.testing.assert_close(Aux, Aux_)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -289,17 +304,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"id": "171ac178",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cutlass_api.fusion.activation import relu\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def relu_aux_store(accum, alpha, C):\n",
|
||||
" F = (accum * alpha) + (C * 2.0) # Constant beta\n",
|
||||
" D = relu(F)\n",
|
||||
" return D, F\n",
|
||||
" F = (accum * alpha) + (C * 2.0) # Constant beta\n",
|
||||
" D = relu(F)\n",
|
||||
" return D, F\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"alpha = 3.0\n",
|
||||
@@ -308,14 +325,14 @@
|
||||
"\n",
|
||||
"epi_args = EpilogueArguments(relu_aux_store, alpha=alpha, C=C, D=D, F=F)\n",
|
||||
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
|
||||
"assert len(kernels) > 0\n",
|
||||
"kernels[0].run(args)\n",
|
||||
"\n",
|
||||
"D_ref, F_ref = relu_aux_store(A @ B, alpha, C)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(D, D_ref)\n",
|
||||
"torch.testing.assert_close(F, F_ref)\n"
|
||||
"torch.testing.assert_close(F, F_ref)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -328,15 +345,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 7,
|
||||
"id": "62c2b49b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def relu_scale_return_acc(accum, alpha, beta, C, scale):\n",
|
||||
" F = relu((accum * alpha) + (C * beta))\n",
|
||||
" D = F * scale\n",
|
||||
" return D, F, accum\n",
|
||||
" F = relu((accum * alpha) + (C * beta))\n",
|
||||
" D = F * scale\n",
|
||||
" return D, F, accum\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"alpha = 1.0\n",
|
||||
@@ -346,9 +364,18 @@
|
||||
"F = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"accum = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float32)\n",
|
||||
"\n",
|
||||
"epi_args = EpilogueArguments(relu_scale_return_acc, alpha=alpha, beta=beta, C=C, scale=scale, D=D, F=F, accum=accum)\n",
|
||||
"epi_args = EpilogueArguments(\n",
|
||||
" relu_scale_return_acc,\n",
|
||||
" alpha=alpha,\n",
|
||||
" beta=beta,\n",
|
||||
" C=C,\n",
|
||||
" scale=scale,\n",
|
||||
" D=D,\n",
|
||||
" F=F,\n",
|
||||
" accum=accum,\n",
|
||||
")\n",
|
||||
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
|
||||
"assert len(kernels) > 0\n",
|
||||
"kernels[0].run(args)\n",
|
||||
"\n",
|
||||
@@ -356,7 +383,7 @@
|
||||
"\n",
|
||||
"torch.testing.assert_close(D, D_ref)\n",
|
||||
"torch.testing.assert_close(F, F_ref)\n",
|
||||
"torch.testing.assert_close(accum, accum_ref.to(accum.dtype))\n"
|
||||
"torch.testing.assert_close(accum, accum_ref.to(accum.dtype))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -370,7 +397,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"id": "5987bf44",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -385,7 +412,7 @@
|
||||
"\n",
|
||||
"epi_args = EpilogueArguments(epi_str, alpha=alpha, beta=beta, C=C, D=D, F=F)\n",
|
||||
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
|
||||
"assert len(kernels) > 0\n",
|
||||
"kernels[0].run(args)\n",
|
||||
"\n",
|
||||
@@ -393,7 +420,7 @@
|
||||
"D_ref = torch.relu(F_ref)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(D, D_ref)\n",
|
||||
"torch.testing.assert_close(F, F_ref)\n"
|
||||
"torch.testing.assert_close(F, F_ref)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -407,54 +434,90 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"id": "1e3d0c89",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"accum must be an input to the epilogue function\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"####################################################\n",
|
||||
"# Epilogues must take in an accumulator\n",
|
||||
"####################################################\n",
|
||||
"def fail_missing_accum(alpha, beta, C):\n",
|
||||
" D = (C * beta)\n",
|
||||
" return D\n",
|
||||
" D = C * beta\n",
|
||||
" return D\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" epi_args = EpilogueArguments(fail_missing_accum, alpha=alpha, beta=beta, C=C, D=D)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
" epi_args = EpilogueArguments(\n",
|
||||
" fail_missing_accum,\n",
|
||||
" alpha=alpha,\n",
|
||||
" beta=beta,\n",
|
||||
" C=C,\n",
|
||||
" D=D,\n",
|
||||
" )\n",
|
||||
" args = GemmArguments(\n",
|
||||
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
|
||||
" )\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"accum must be an input to the epilogue function\"\n",
|
||||
" print(e)\n"
|
||||
" # \"accum must be an input to the epilogue function\"\n",
|
||||
" print(e)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 10,
|
||||
"id": "48a359f7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Output node D is not found in the epilogue function\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"####################################################\n",
|
||||
"# Epilogues must return an output named D\n",
|
||||
"####################################################\n",
|
||||
"def fail_missing_D(accum, alpha, beta, C):\n",
|
||||
" F = (accum * alpha) + (C * beta)\n",
|
||||
" return F\n",
|
||||
" F = (accum * alpha) + (C * beta)\n",
|
||||
" return F\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" epi_args = EpilogueArguments(fail_missing_D, alpha=alpha, beta=beta, C=C, F=F)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
" epi_args = EpilogueArguments(fail_missing_D, alpha=alpha, beta=beta, C=C, F=F)\n",
|
||||
" args = GemmArguments(\n",
|
||||
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
|
||||
" )\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"On SM90 or higher, D is expected to be a output node with 0 users to enable smem reuse between C and D, but got []\"\n",
|
||||
" print(e)\n"
|
||||
" # \"Output node D is not found in the epilogue function\n",
|
||||
" print(e)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 11,
|
||||
"id": "49d9ee94",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Variable 'tmp' cannot be defined twice.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"####################################################\n",
|
||||
"# Epilogues must use single-static assignment (SSA)\n",
|
||||
@@ -466,51 +529,81 @@
|
||||
" D = tmp / 4.0\n",
|
||||
" return D, tmp\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" epi_args = EpilogueArguments(fail_ssa, D=D, tmp=F)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
" epi_args = EpilogueArguments(fail_ssa, D=D, tmp=F)\n",
|
||||
" args = GemmArguments(\n",
|
||||
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
|
||||
" )\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"Variable 'tmp' cannot be defined twice.\"\n",
|
||||
" print(e)\n"
|
||||
" # \"Variable 'tmp' cannot be defined twice.\"\n",
|
||||
" print(e)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 12,
|
||||
"id": "871bb727",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Argument D is not provided in the kwargs of the EpilogueArguments constructor\n",
|
||||
"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"####################################################\n",
|
||||
"# Must provide all operands and outputs to\n",
|
||||
"# EpilogueArguments\n",
|
||||
"####################################################\n",
|
||||
"def my_epi(accum, alpha, beta, C):\n",
|
||||
" F = (accum * alpha) + (C * beta)\n",
|
||||
" D = relu(F)\n",
|
||||
" return D\n",
|
||||
" F = (accum * alpha) + (C * beta)\n",
|
||||
" D = relu(F)\n",
|
||||
" return D\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" # Missing D\n",
|
||||
" epi_args = EpilogueArguments(my_epi, alpha=alpha, beta=beta, C=C)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
" # Missing D\n",
|
||||
" epi_args = EpilogueArguments(my_epi, alpha=alpha, beta=beta, C=C)\n",
|
||||
" args = GemmArguments(\n",
|
||||
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
|
||||
" )\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"Argument D is not provided in the kwargs of the EpilogueArguments constructor\"\n",
|
||||
" print(e)\n",
|
||||
" # \"Argument D is not provided in the kwargs of the EpilogueArguments constructor\"\n",
|
||||
" print(e)\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" # Missing alpha\n",
|
||||
" epi_args = EpilogueArguments(my_epi, beta=beta, C=C, D=D)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
" # Missing alpha\n",
|
||||
" epi_args = EpilogueArguments(my_epi, beta=beta, C=C, D=D)\n",
|
||||
" args = GemmArguments(\n",
|
||||
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
|
||||
" )\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\"\n",
|
||||
" print(e)\n"
|
||||
" # \"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\"\n",
|
||||
" print(e)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -24,12 +24,12 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"id": "5a64b0be",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Callable\n",
|
||||
"from collections.abc import Callable\n",
|
||||
"\n",
|
||||
"import cuda.bindings.driver as cuda\n",
|
||||
"\n",
|
||||
@@ -110,7 +110,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"id": "1a2da869",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -148,18 +148,34 @@
|
||||
"class F64GemmKernel(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):\n",
|
||||
" # Empty versions of interface methods. These will be implemented later, interspersed\n",
|
||||
" # with notebook markdown. Normally, one would define them inline with the class definition.\n",
|
||||
" def __init__(self, metadata: KernelMetadata): pass\n",
|
||||
" def __init__(self, metadata: KernelMetadata):\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None): pass\n",
|
||||
" def _run(\n",
|
||||
" self,\n",
|
||||
" args: GemmArguments,\n",
|
||||
" artifact: cutlass_api.artifact.CompiledArtifact,\n",
|
||||
" stream,\n",
|
||||
" workspace=None,\n",
|
||||
" ):\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact: pass\n",
|
||||
" def compile(\n",
|
||||
" self, args: GemmArguments, cc: int = None\n",
|
||||
" ) -> cutlass_api.artifact.CompiledArtifact:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" @staticmethod\n",
|
||||
" def generate_kernels(metadata_filter, epilogue_args=None, cc=None) -> list[\"F64GemmKernel\"]: pass\n",
|
||||
" def generate_kernels(\n",
|
||||
" metadata_filter, epilogue_args=None, cc=None\n",
|
||||
" ) -> list[\"F64GemmKernel\"]:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" def _supports(self, args: GemmArguments) -> Status: pass\n",
|
||||
" def _supports(self, args: GemmArguments) -> Status:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" def get_workspace_size(self, args: GemmArguments) -> int: pass"
|
||||
" def get_workspace_size(self, args: GemmArguments) -> int:\n",
|
||||
" pass"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -175,13 +191,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"id": "785d1882",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def __init__(self, metadata: KernelMetadata):\n",
|
||||
" self.metadata = metadata\n",
|
||||
" # Using Python-2-style super() because we're defining this method outside of the class definition.\n",
|
||||
" super(F64GemmKernel, self).__init__(metadata)\n",
|
||||
" cta_tile_shape_mn = metadata.design.tile_shape[:2]\n",
|
||||
" self.impl = F64GemmKernelImplementation(cta_tile_shape_mn)"
|
||||
]
|
||||
@@ -203,12 +220,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"id": "63b4a129",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact:\n",
|
||||
"def compile(\n",
|
||||
" self, args: GemmArguments, cc: int = None\n",
|
||||
") -> cutlass_api.artifact.CompiledArtifact:\n",
|
||||
" stream = cutlass.cute.runtime.make_fake_stream()\n",
|
||||
" compiled_gemm = self.cute_compile(self.impl, args.A, args.B, args.out, stream)\n",
|
||||
" return cutlass_api.artifact.CompiledArtifact(compiled_gemm, self)"
|
||||
@@ -227,12 +246,18 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"id": "2ae7c009",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None):\n",
|
||||
"def _run(\n",
|
||||
" self,\n",
|
||||
" args: GemmArguments,\n",
|
||||
" artifact: cutlass_api.artifact.CompiledArtifact,\n",
|
||||
" stream,\n",
|
||||
" workspace=None,\n",
|
||||
"):\n",
|
||||
" stream = cutlass_api.utils.to_cuda_stream(stream)\n",
|
||||
" compiled_gemm = artifact.compiled_obj\n",
|
||||
" self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)"
|
||||
@@ -249,7 +274,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 7,
|
||||
"id": "968906ea",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -286,7 +311,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"id": "47dc2f20",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -297,7 +322,6 @@
|
||||
" epilogue_args: cutlass_api.arguments.EpilogueArguments = None,\n",
|
||||
" cc: int = None,\n",
|
||||
") -> list[\"F64GemmKernel\"]:\n",
|
||||
"\n",
|
||||
" # The tile shapes this kernel supports/exposes\n",
|
||||
" supported_tile_shapes = [(32, 32, 1), (16, 16, 1)]\n",
|
||||
"\n",
|
||||
@@ -306,10 +330,12 @@
|
||||
"\n",
|
||||
" row_major_stride = (0, 0, 1)\n",
|
||||
" col_major_stride = (0, 1, 0)\n",
|
||||
" stride_combos = list(itertools.product([row_major_stride, col_major_stride], repeat=3))\n",
|
||||
" alignment = 1\n",
|
||||
" stride_combos = list(\n",
|
||||
" itertools.product([row_major_stride, col_major_stride], repeat=3)\n",
|
||||
" )\n",
|
||||
" divisibility = 1\n",
|
||||
"\n",
|
||||
" def stride_name(stride): \n",
|
||||
" def stride_name(stride):\n",
|
||||
" return \"T\" if stride == row_major_stride else \"N\"\n",
|
||||
"\n",
|
||||
" kernels = []\n",
|
||||
@@ -317,10 +343,18 @@
|
||||
" design_metadata = cutlass_api.metadata.BLASDesignMetadata(tile_shape, (1, 1, 1))\n",
|
||||
" for stride_A, stride_B, stride_out in stride_combos:\n",
|
||||
" # Create TensorAttributes for A, B, and out tensors\n",
|
||||
" a_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_A, alignment)\n",
|
||||
" b_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_B, alignment)\n",
|
||||
" out_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_out, alignment)\n",
|
||||
" layout_str = cutlass_api.utils.strides_to_layout_string(stride_A, stride_B, stride_out)\n",
|
||||
" a_attrs = cutlass_api.metadata.TensorAttributes(\n",
|
||||
" cutlass.Float64, stride_A, divisibility\n",
|
||||
" )\n",
|
||||
" b_attrs = cutlass_api.metadata.TensorAttributes(\n",
|
||||
" cutlass.Float64, stride_B, divisibility\n",
|
||||
" )\n",
|
||||
" out_attrs = cutlass_api.metadata.TensorAttributes(\n",
|
||||
" cutlass.Float64, stride_out, divisibility\n",
|
||||
" )\n",
|
||||
" layout_str = cutlass_api.utils.strides_to_layout_string(\n",
|
||||
" stride_A, stride_B, stride_out\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" name = f\"F64GemmKernel_tile{tile_shape[0]}x{tile_shape[1]}_{layout_str}\"\n",
|
||||
"\n",
|
||||
@@ -355,24 +389,24 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"id": "54067d47",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _supports(self, args: GemmArguments) -> Status:\n",
|
||||
" if not (\n",
|
||||
" len(args.A.shape) == 3 and # A should be (L, M, K)\n",
|
||||
" len(args.B.shape) == 3 and # B should be (L, K, N)\n",
|
||||
" len(args.out.shape) == 3 # out should be (L, M, N)\n",
|
||||
" len(args.A.shape) == 3 # A should be (L, M, K)\n",
|
||||
" and len(args.B.shape) == 3 # B should be (L, K, N)\n",
|
||||
" and len(args.out.shape) == 3 # out should be (L, M, N)\n",
|
||||
" ):\n",
|
||||
" return Status.fail(\"All operands must be rank 3.\")\n",
|
||||
" return Status.success()\n"
|
||||
" return Status.success()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 10,
|
||||
"id": "edaf2cba",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -404,7 +438,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 11,
|
||||
"id": "cec5431d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -420,9 +454,11 @@
|
||||
"\n",
|
||||
"args = GemmArguments(A, B, out, accumulator_type=torch.float64)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def is_f64gemm_kernel(metadata):\n",
|
||||
" return metadata.kernel_class == F64GemmKernel\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"kernels = cutlass_api.get_kernels(args, metadata_filter=is_f64gemm_kernel)"
|
||||
]
|
||||
},
|
||||
@@ -437,10 +473,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 12,
|
||||
"id": "cdb92b5e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"F64GemmKernel_tile32x32_ttt\n",
|
||||
"F64GemmKernel_tile16x16_ttt\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(kernels[0].metadata.kernel_name)\n",
|
||||
"print(kernels[1].metadata.kernel_name)"
|
||||
@@ -456,7 +501,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 13,
|
||||
"id": "f5486244",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -478,17 +523,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 14,
|
||||
"id": "917c74e3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def my_filter(metadata):\n",
|
||||
" return (\n",
|
||||
" is_f64gemm_kernel(metadata) and\n",
|
||||
" isinstance(metadata.design, cutlass_api.metadata.BLASDesignMetadata) and\n",
|
||||
" metadata.design.tile_shape[0] == 256\n",
|
||||
" is_f64gemm_kernel(metadata)\n",
|
||||
" and isinstance(metadata.design, cutlass_api.metadata.BLASDesignMetadata)\n",
|
||||
" and metadata.design.tile_shape[0] == 256\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"kernels_ctam256 = cutlass_api.get_kernels(args, metadata_filter=my_filter)\n",
|
||||
"\n",
|
||||
"# No kernels should be found\n",
|
||||
@@ -539,8 +586,22 @@
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -17,12 +17,15 @@
|
||||
"This notebook focuses on the latter: techniques to minimize any overheads incurred from the CUTLASS API and underlying\n",
|
||||
"DSL runtimes.\n",
|
||||
"\n",
|
||||
"This notebook does not discuss techniques for improving device-side performance. A future notebook may cover this topic."
|
||||
"This notebook does not discuss techniques for improving device-side performance. A future notebook may cover this topic.\n",
|
||||
"\n",
|
||||
"**Note**: Latency measurements can vary from system to system. You may see different results on your system than shown\n",
|
||||
"in the pre-populated fields of this notebook."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"id": "e3ca9e40",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -34,13 +37,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"id": "efaac09c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
|
||||
" print(f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\")\n",
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({80, 89, 90, 100, 103})):\n",
|
||||
" print(\n",
|
||||
" f\"This notebook requires a GPU with compute capability >= 80.\\n{status.error}\"\n",
|
||||
" )\n",
|
||||
" import sys\n",
|
||||
"\n",
|
||||
" sys.exit(0)"
|
||||
@@ -61,7 +66,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"id": "b8c44947",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -76,19 +81,34 @@
|
||||
"# We use different operands in each iteration. Though not particularly relevant for\n",
|
||||
"# host latency, this is a best practice when benchmarking GPU kernels to avoid\n",
|
||||
"# unrealistic caching effects.\n",
|
||||
"As = [torch.randint(-1, 2, (M, K), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
|
||||
"Bs = [torch.randint(-1, 2, (K, N), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
|
||||
"outs = [torch.empty((M, N), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
|
||||
"As = [\n",
|
||||
" torch.randint(-1, 2, (M, K), device=\"cuda\", dtype=torch.float16)\n",
|
||||
" for _ in range(total_iterations)\n",
|
||||
"]\n",
|
||||
"Bs = [\n",
|
||||
" torch.randint(-1, 2, (K, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
" for _ in range(total_iterations)\n",
|
||||
"]\n",
|
||||
"outs = [\n",
|
||||
" torch.empty((M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
" for _ in range(total_iterations)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Construct arguments outside of the benchmarking loop. We will later also consider\n",
|
||||
"# cases in which they are constructed inside the benchmarking loop.\n",
|
||||
"args = [cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16) for i in range(total_iterations)]\n",
|
||||
"args = [\n",
|
||||
" cutlass_api.arguments.GemmArguments(\n",
|
||||
" A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float32\n",
|
||||
" )\n",
|
||||
" for i in range(total_iterations)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"references = [(As[i] @ Bs[i]).to(outs[i].dtype) for i in range(total_iterations)]\n",
|
||||
"\n",
|
||||
"kernels = cutlass_api.get_kernels(args[0], cc=100)\n",
|
||||
"\n",
|
||||
"cc = cutlass_api.utils.device_cc()\n",
|
||||
"kernels = cutlass_api.get_kernels(args[0], cc=cc)\n",
|
||||
"assert len(kernels) > 0\n",
|
||||
"\n",
|
||||
"kernel = kernels[0]"
|
||||
]
|
||||
},
|
||||
@@ -102,14 +122,18 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"id": "2472eafa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark(label, code, warmup_it=warmup_iterations, profiling_it=profiling_iterations):\n",
|
||||
"def benchmark(\n",
|
||||
" label, code, warmup_it=warmup_iterations, profiling_it=profiling_iterations\n",
|
||||
"):\n",
|
||||
" total_it = warmup_it + profiling_it\n",
|
||||
" assert total_it <= total_iterations, f\"Benchmark-local iteration count must be less than or equal to total iterations: {total_it} > {total_iterations}\"\n",
|
||||
" assert total_it <= total_iterations, (\n",
|
||||
" f\"Benchmark-local iteration count must be less than or equal to total iterations: {total_it} > {total_iterations}\"\n",
|
||||
" )\n",
|
||||
" # warmup\n",
|
||||
" rets = [None] * total_it\n",
|
||||
" for i in range(warmup_it):\n",
|
||||
@@ -151,7 +175,7 @@
|
||||
"### Compile once, run many times\n",
|
||||
"The `kernel.run` method takes in an optional `compiled_artifact` argument of type\n",
|
||||
"`cutlass_api.artifact.CompiledArtifact`. When this argument is set, the kernel\n",
|
||||
"will directly use the precompiled function within `compiled_argument`. When\n",
|
||||
"will directly use the precompiled function within `compiled_artifact`. When\n",
|
||||
"it is not set, the call to `kernel.run` will JIT compile the kernel on each\n",
|
||||
"invocation.\n",
|
||||
"\n",
|
||||
@@ -160,25 +184,29 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": null,
|
||||
"id": "6de11f56",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"stream = torch.cuda.current_stream()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def no_compiled_artifact(i: int):\n",
|
||||
" return kernel.run(args[i], stream=stream)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Compile the kernel once, reuse for each iteration\n",
|
||||
"compiled_artifact = kernel.compile(args[0])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def with_compiled_artifact(i: int):\n",
|
||||
" return kernel.run(args[i], stream=stream, compiled_artifact=compiled_artifact)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"id": "350c9bd6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -192,8 +220,12 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"time_no_artifact, _ = benchmark(f\"Without compiled artifact\", no_compiled_artifact, warmup_it=2, profiling_it=5)\n",
|
||||
"time_w_artifact, _ = benchmark(f\"With compiled artifact\", with_compiled_artifact, warmup_it=2, profiling_it=5)"
|
||||
"time_no_artifact, _ = benchmark(\n",
|
||||
" f\"Without compiled artifact\", no_compiled_artifact, warmup_it=2, profiling_it=5\n",
|
||||
")\n",
|
||||
"time_w_artifact, _ = benchmark(\n",
|
||||
" f\"With compiled artifact\", with_compiled_artifact, warmup_it=2, profiling_it=5\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -215,21 +247,32 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": null,
|
||||
"id": "5b93dfae",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def with_supports_check(i: int):\n",
|
||||
" return kernel.run(args[i], compiled_artifact=compiled_artifact, stream=stream, assume_supported_args=False)\n",
|
||||
" return kernel.run(\n",
|
||||
" args[i],\n",
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
" stream=stream,\n",
|
||||
" assume_supported_args=False,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def without_supports_check(i: int):\n",
|
||||
" return kernel.run(args[i], compiled_artifact=compiled_artifact, stream=stream, assume_supported_args=True)"
|
||||
" return kernel.run(\n",
|
||||
" args[i],\n",
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
" stream=stream,\n",
|
||||
" assume_supported_args=True,\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": null,
|
||||
"id": "b282f437",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -244,7 +287,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"time_w_supports, _ = benchmark(\"With supports check\", with_supports_check)\n",
|
||||
"time_w_supports, _ = benchmark(\"With supports check\", with_supports_check)\n",
|
||||
"time_wo_supports, _ = benchmark(\"Bypass supports check\", without_supports_check)\n",
|
||||
"print(f\"Speedup with skip supports: {time_w_supports / time_wo_supports:.2f}x\")"
|
||||
]
|
||||
@@ -262,14 +305,16 @@
|
||||
"id": "656d5e2c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"CUTLASS API supports [CUDA Graphs](https://developer.nvidia.com/blog/cuda-graphs/) usage with PyTorch as usual.\n",
|
||||
"[CUDA Graphs](https://developer.nvidia.com/blog/cuda-graphs/) allow a sequence of GPU operations to be defined as a dependency graph and then launched as a single unit, significantly reducing CPU launch overhead and enabling whole-graph optimizations.\n",
|
||||
"\n",
|
||||
"CUTLASS API supports CUDA Graphs usage with PyTorch as usual.\n",
|
||||
"\n",
|
||||
"The kernel compilation must happen outside the CUDA graph. Then, we create a graph using usual PyTorch idioms to launch a kernel several times on the graph's stream."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": null,
|
||||
"id": "e614509f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -279,6 +324,10 @@
|
||||
"# Create a CUDA Graph to run our compiled kernel N times\n",
|
||||
"g = torch.cuda.CUDAGraph()\n",
|
||||
"with torch.cuda.graph(g):\n",
|
||||
"\n",
|
||||
" ### NOTE! Kernel compilation must happen outside the graph\n",
|
||||
" ### kernel.compile(args)\n",
|
||||
"\n",
|
||||
" # Run N iterations of our compiled kernel on the current stream\n",
|
||||
" for i in range(num_launches):\n",
|
||||
" kernel.run(\n",
|
||||
@@ -286,10 +335,7 @@
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
" stream=torch.cuda.current_stream(),\n",
|
||||
" assume_supported_args=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"# Zero the output so we don't refcheck stale results\n",
|
||||
"_ = outs[0].zero_()"
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -297,8 +343,12 @@
|
||||
"id": "8fc69c6e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Once captured, we can replay the graph. This will only replay the kernel launches placed on the CUDA stream.\n",
|
||||
"Any other prepratory work on the host and arguments passed in from python are cached during the capture."
|
||||
"This records/captures all the kernel launches to the CUDA Stream associated with the graph `g`, without actually launching them.\n",
|
||||
"Once captured, we can replay the graph.\n",
|
||||
"\n",
|
||||
"Note that graph replay will only replay the kernel launches placed on the graph's stream:\n",
|
||||
"* During graph capture, we must be careful to capture to the correct stream (`torch.cuda.current_stream()` under the graph context)\n",
|
||||
"* Any other preparatory work on the host and arguments passed in from Python are cached during the capture. Changing them would require re-capturing the graph"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -324,7 +374,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": null,
|
||||
"id": "45d4e739",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -348,12 +398,23 @@
|
||||
" assume_supported_args=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def with_cuda_graph(x: int):\n",
|
||||
" g.replay()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"time_wo_cuda_graph, _ = benchmark(f\"{num_launches} launches without CUDA Graph\", without_cuda_graph, warmup_it=0, profiling_it=1)\n",
|
||||
"time_w_cuda_graph, _ = benchmark(f\"{num_launches} launches with CUDA Graph\", with_cuda_graph, warmup_it=0, profiling_it=1)\n",
|
||||
"time_wo_cuda_graph, _ = benchmark(\n",
|
||||
" f\"{num_launches} launches without CUDA Graph\",\n",
|
||||
" without_cuda_graph,\n",
|
||||
" warmup_it=0,\n",
|
||||
" profiling_it=1,\n",
|
||||
")\n",
|
||||
"time_w_cuda_graph, _ = benchmark(\n",
|
||||
" f\"{num_launches} launches with CUDA Graph\",\n",
|
||||
" with_cuda_graph,\n",
|
||||
" warmup_it=0,\n",
|
||||
" profiling_it=1,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"Speedup with CUDA Graph: {time_wo_cuda_graph / time_w_cuda_graph:.2f}x\")"
|
||||
]
|
||||
@@ -371,8 +432,8 @@
|
||||
"id": "ee7f9fd2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"When applicable, CUTLASS API uses [Apache TVM FFI](https://tvm.apache.org/ffi/) under the hood for invoking compiled DSL kernels from Python.\n",
|
||||
"Apache TVM FFI is an open ABI and FFI for machine learning systems.\n",
|
||||
"[Apache TVM FFI](https://tvm.apache.org/ffi/) is an open ABI and FFI for machine learning systems.\n",
|
||||
"When available, CUTLASS API uses Apache TVM FFI under the hood as its interface for invoking compiled DSL kernels from Python.\n",
|
||||
"\n",
|
||||
"TVM FFI is enabled by default in CUTLASS API, and is recommended for best performance."
|
||||
]
|
||||
@@ -413,7 +474,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": null,
|
||||
"id": "e8f56be3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -428,10 +489,15 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"original_use_tvm_ffi = cutlass_api.config.GlobalOptions().use_tvm_ffi\n",
|
||||
"\n",
|
||||
"cutlass_api.config.GlobalOptions().use_tvm_ffi = True\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def run_iteration(i):\n",
|
||||
" args = cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16)\n",
|
||||
" args = cutlass_api.arguments.GemmArguments(\n",
|
||||
" A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16\n",
|
||||
" )\n",
|
||||
" return kernel.run(\n",
|
||||
" args,\n",
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
@@ -439,18 +505,35 @@
|
||||
" assume_supported_args=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_arguments(i: int):\n",
|
||||
" return cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16)\n",
|
||||
" return cutlass_api.arguments.GemmArguments(\n",
|
||||
" A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"args_creation_on, args = benchmark(\"[TVM-FFI ON ] Create args\", create_arguments)\n",
|
||||
"compilation_on, compiled = benchmark(\"[TVM-FFI ON ] Compile kernel\", lambda i: kernel.compile(args[i]), warmup_it=2, profiling_it=5)\n",
|
||||
"compilation_on, compiled = benchmark(\n",
|
||||
" \"[TVM-FFI ON ] Compile kernel\",\n",
|
||||
" lambda i: kernel.compile(args[i]),\n",
|
||||
" warmup_it=2,\n",
|
||||
" profiling_it=5,\n",
|
||||
")\n",
|
||||
"compiled_artifact = compiled[0]\n",
|
||||
"run_on, _ = benchmark(\"[TVM-FFI ON ] Run kernel\", lambda i: kernel.run(args[i], compiled_artifact=compiled_artifact, assume_supported_args=True, stream=stream))"
|
||||
"run_on, _ = benchmark(\n",
|
||||
" \"[TVM-FFI ON ] Run kernel\",\n",
|
||||
" lambda i: kernel.run(\n",
|
||||
" args[i],\n",
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
" assume_supported_args=True,\n",
|
||||
" stream=stream,\n",
|
||||
" ),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": null,
|
||||
"id": "5a4c2db4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -467,9 +550,25 @@
|
||||
"source": [
|
||||
"cutlass_api.config.GlobalOptions().use_tvm_ffi = False\n",
|
||||
"args_creation_off, args = benchmark(\"[TVM-FFI OFF ] Create args\", create_arguments)\n",
|
||||
"compilation_off, compiled = benchmark(\"[TVM-FFI OFF ] Compile kernel\", lambda i: kernel.compile(args[i]), warmup_it=2, profiling_it=5)\n",
|
||||
"compilation_off, compiled = benchmark(\n",
|
||||
" \"[TVM-FFI OFF ] Compile kernel\",\n",
|
||||
" lambda i: kernel.compile(args[i]),\n",
|
||||
" warmup_it=2,\n",
|
||||
" profiling_it=5,\n",
|
||||
")\n",
|
||||
"compiled_artifact = compiled[0]\n",
|
||||
"run_off, _ = benchmark(\"[TVM-FFI OFF ] Run kernel\", lambda i: kernel.run(args[i], compiled_artifact=compiled_artifact, assume_supported_args=True, stream=stream))"
|
||||
"run_off, _ = benchmark(\n",
|
||||
" \"[TVM-FFI OFF ] Run kernel\",\n",
|
||||
" lambda i: kernel.run(\n",
|
||||
" args[i],\n",
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
" assume_supported_args=True,\n",
|
||||
" stream=stream,\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Restore original setting\n",
|
||||
"cutlass_api.config.GlobalOptions().use_tvm_ffi = original_use_tvm_ffi"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
166
python/cutlass_api/examples/004_fake_tensors.ipynb
Normal file
166
python/cutlass_api/examples/004_fake_tensors.ipynb
Normal file
@@ -0,0 +1,166 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4620d513",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Using fake tensors with the CUTLASS API\n",
|
||||
"Fake tensors (e.g., [torch's FakeTensor](https://docs.pytorch.org/docs/2.8/torch.compiler_fake_tensor.html))\n",
|
||||
"are useful for describing the properties of a tensor without actually allocating backing data.\n",
|
||||
"\n",
|
||||
"This example shows how fake tensors can be used within the CUTLASS API\n",
|
||||
"for discovering and compiling a GEMM kernel."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "d231b32e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"import cutlass_api\n",
|
||||
"\n",
|
||||
"torch.manual_seed(2025)\n",
|
||||
"\n",
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({80, 89, 90, 100, 103})):\n",
|
||||
" print(f\"This notebook requires a GPU with compute capability >= 80.\\n{status.error}\")\n",
|
||||
" import sys\n",
|
||||
" sys.exit(0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f7af2d90",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We first set up operands `A`, `B`, and `out` in torch's `FakeTensorMode`.\n",
|
||||
"These will have all the properties needed for CUTLASS API to construct\n",
|
||||
"the internal representations of tensors used for discovering and compiling\n",
|
||||
"kernels."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "9426b66f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"FakeTensor(..., device='cuda:0', size=(128, 512), dtype=torch.float16)\n",
|
||||
"FakeTensor(..., device='cuda:0', size=(512, 256), dtype=torch.float16)\n",
|
||||
"FakeTensor(..., device='cuda:0', size=(128, 256), dtype=torch.float16)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"M, N, K = 128, 256, 512\n",
|
||||
"\n",
|
||||
"with torch._subclasses.fake_tensor.FakeTensorMode():\n",
|
||||
" A_fake = torch.randn(M, K, device=\"cuda\", dtype=torch.float16)\n",
|
||||
" B_fake = torch.randn(K, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
" out_fake = torch.empty(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"print(A_fake)\n",
|
||||
"print(B_fake)\n",
|
||||
"print(out_fake)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4f540c78",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now use these fake tensors to create `GemmArguments`, and use\n",
|
||||
"these to discover and compile a compatible kernel. Note that the same APIs are\n",
|
||||
"used in creating `GemmArguments` as would be used if using\n",
|
||||
"\"real\" tensors."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "e32b700d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"args_fake = cutlass_api.arguments.GemmArguments(\n",
|
||||
" A_fake, B_fake, out_fake, accumulator_type=torch.float32)\n",
|
||||
"\n",
|
||||
"cc = cutlass_api.utils.device_cc()\n",
|
||||
"kernels = cutlass_api.get_kernels(args_fake, cc=cc)\n",
|
||||
"assert len(kernels) > 0\n",
|
||||
"\n",
|
||||
"kernel = kernels[0]\n",
|
||||
"compiled_artifact = kernel.compile(args_fake)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "07fff511",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The `kernel` and `compiled_artifact` discovered using fake tensors\n",
|
||||
"above can now be used for running the kernel using real tensors."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "b3034bf1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create real tensors\n",
|
||||
"A_real = torch.randn(M, K, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"B_real = torch.randn(K, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"out_real = torch.empty(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"args_real = cutlass_api.arguments.GemmArguments(\n",
|
||||
" A_real, B_real, out_real, accumulator_type=torch.float32)\n",
|
||||
"\n",
|
||||
"# Run the kernel using the compiled_artifact resulting\n",
|
||||
"# from compiling with fake tensors.\n",
|
||||
"kernel.run(args_real, compiled_artifact)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "09871eca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ref = A_real @ B_real\n",
|
||||
"torch.testing.assert_close(out_real, ref)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user