2026-01-12 updates

This commit is contained in:
jkosaian
2026-01-12 18:51:25 -08:00
parent 7c09485e25
commit 87cab7bae2
27 changed files with 7619 additions and 234 deletions

View File

@@ -5,7 +5,7 @@
"id": "3dd45ef2",
"metadata": {},
"source": [
"# Basic GEMM using CUTLASS Python API"
"# Basic GEMM using CUTLASS API"
]
},
{

View File

@@ -229,7 +229,7 @@
" self, args: GemmArguments, cc: int = None\n",
") -> cutlass_api.artifact.CompiledArtifact:\n",
" stream = cutlass.cute.runtime.make_fake_stream()\n",
" compiled_gemm = self.cute_compile(self.impl, args.A, args.B, args.out, stream)\n",
" compiled_gemm = self.cute_compile(self.impl, args.A.tensor, args.B.tensor, args.out.tensor, stream)\n",
" return cutlass_api.artifact.CompiledArtifact(compiled_gemm, self)"
]
},
@@ -260,7 +260,7 @@
"):\n",
" stream = cutlass_api.utils.to_cuda_stream(stream)\n",
" compiled_gemm = artifact.compiled_obj\n",
" self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)"
" self.cute_run(compiled_gemm, args.A.tensor, args.B.tensor, args.out.tensor, stream)"
]
},
{
@@ -343,13 +343,13 @@
" design_metadata = cutlass_api.metadata.BLASDesignMetadata(tile_shape, (1, 1, 1))\n",
" for stride_A, stride_B, stride_out in stride_combos:\n",
" # Create TensorAttributes for A, B, and out tensors\n",
" a_attrs = cutlass_api.metadata.TensorAttributes(\n",
" a_attrs = cutlass_api.metadata.DenseTensorAttributes(\n",
" cutlass.Float64, stride_A, divisibility\n",
" )\n",
" b_attrs = cutlass_api.metadata.TensorAttributes(\n",
" b_attrs = cutlass_api.metadata.DenseTensorAttributes(\n",
" cutlass.Float64, stride_B, divisibility\n",
" )\n",
" out_attrs = cutlass_api.metadata.TensorAttributes(\n",
" out_attrs = cutlass_api.metadata.DenseTensorAttributes(\n",
" cutlass.Float64, stride_out, divisibility\n",
" )\n",
" layout_str = cutlass_api.utils.strides_to_layout_string(\n",

View File

@@ -0,0 +1,240 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "91d43c2b",
"metadata": {},
"source": [
"# Grouped GEMM with contiguous tensors via the CUTLASS API\n",
"\n",
"Note: this notebook requires a GPU with compute capability 100:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f671f602",
"metadata": {},
"outputs": [],
"source": [
"import cutlass_api\n",
"\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({100})):\n",
" print(\n",
" f\"This notebook requires a GPU with compute capability 100.\\n{status.error}\"\n",
" )\n",
" import sys\n",
" sys.exit(0)"
]
},
{
"cell_type": "markdown",
"id": "bc4adf7d",
"metadata": {},
"source": [
"This notebook shows how to use the CUTLASS API to discover, compile, and execute\n",
"kernels supporting contiguous offset grouped GEMMs.\n",
"\n",
"In a \"contiguous offset\" grouped GEMM, `G` different problems are executed\n",
"in which problems differ only in the `M` mode. Their problem sizes are thus\n",
"represented as:\n",
"\n",
"```text\n",
"M0 x N x K\n",
"M1 x N x K\n",
"M2 x N x K\n",
"...\n",
"M(G-1) x N x K\n",
"```\n",
"\n",
"The grouped GEMM is referred to as \"contiguous\" because operands for different\n",
"problems in the group are contained within contiguous tensors.\n",
"\n",
"Rather than having `G` different tensors for each of operands `A` and `B`, tensors\n",
"for different problems in the group are packed together:\n",
"* `A` is of shape `(TotalM, K)`, where `TotalM` is the sum of all `M` modes for problems in the group.\n",
"The `A` operands for each problem in the group are stacked along the `M` mode to form this input. More on this below.\n",
"* `B` is of shape `(G, K, N)`, where `B[i, :, :]` represents the GEMM `B` operand for the `i`th problem in the group.\n",
"\n",
"For example, with `G=3` (three problems in the group), with `M` modes of M0, M1, and M2,\n",
"respectively, the tensor `A` would be laid out as follows:\n",
"\n",
"```text\n",
"\n",
" +----------------------------------+ ^ \n",
" | | | | \n",
" | A0 | M0 | \n",
" | | | | \n",
" |- - - - - - - - - - - -| | \n",
" | | | |\n",
" | | | TotalM \n",
" | A1 | M1 |\n",
" | | | |\n",
" | | | | \n",
" |- - - - - - - - - - - -| | \n",
" | A2 | M2 | \n",
" +----------------------------------+ v \n",
"```\n",
"\n",
"The extents of individual `A` operands packed within the overall contiguous offset `A` tensor\n",
"are provided by an auxiliary `offsets` vector of shape `(G,)`. `offsets[i]` indicates the ending\n",
"M coordinate (exclusive) for the `i`th `A` operand.\n",
"\n",
"Thus, for the example above, `offsets = [M0, M0 + M1, M0 + M1 + M2]`.\n",
"\n",
"The output of the operation is of shape `(TotalM, N)`. The `i`th output occupies `out[start:end, :]`,\n",
"where `start` and `end` are `offsets[i-1]` and `offsets[i]`, respectively (unless `i=0`, in which case\n",
"`start` is 0).\n",
"\n",
"The reference code below shows the computation of this kernel."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6185f60a",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def reference_contiguous_offset_grouped_gemm(A, B, offsets, out_dtype):\n",
" G, K, N = B.shape\n",
" TotalM = A.shape[0]\n",
"\n",
" out = torch.empty((TotalM, N), dtype=out_dtype, device=A.device)\n",
"\n",
" start = 0\n",
" for i in range(G):\n",
" end = offsets[i]\n",
" out[start:end, :] = A[start:end, :] @ B[i, :, :]\n",
" start = end\n",
"\n",
" return out"
]
},
{
"cell_type": "markdown",
"id": "d0bf2f91",
"metadata": {},
"source": [
"## Contiguous offset grouped GEMM in PyTorch"
]
},
{
"cell_type": "markdown",
"id": "4308a6a2",
"metadata": {},
"source": [
"The same operation is performed by `torch`'s `torch._grouped_mm` (torch < 2.10)\n",
"and `torch.nn.functional.grouped_mm` (torch >= 2.10)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "043906af",
"metadata": {},
"outputs": [],
"source": [
"TotalM = 8192\n",
"G = 12\n",
"K = 1024\n",
"N = 2048\n",
"\n",
"offsets = torch.arange(TotalM // G, TotalM, TotalM // G, device=\"cuda\", dtype=torch.int32)\n",
"offsets[-1] = TotalM\n",
"\n",
"A = torch.randn(TotalM, K, device=\"cuda\", dtype=torch.bfloat16)\n",
"B = torch.randn(G, N, K, device=\"cuda\", dtype=torch.bfloat16).permute(0, 2, 1)\n",
"\n",
"out_torch = torch._grouped_mm(A, B, offsets, out_dtype=torch.bfloat16)\n",
"reference = reference_contiguous_offset_grouped_gemm(A, B, offsets, out_dtype=torch.bfloat16)\n",
"\n",
"torch.testing.assert_close(out_torch, reference)"
]
},
{
"cell_type": "markdown",
"id": "0d0e9479",
"metadata": {},
"source": [
"## Contiguous offset grouped GEMM in CUTLASS API\n",
"\n",
"CUTLASS API exposes this contiguous offset grouped GEMM via `GroupedGemmArguments`,\n",
"which are constructed similarly to `GemmArguments`, but take in an `offsets`\n",
"tensor as well:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ff8d3ef1",
"metadata": {},
"outputs": [],
"source": [
"out = torch.empty((TotalM, N), device=\"cuda\", dtype=torch.bfloat16)\n",
"\n",
"args = cutlass_api.arguments.GroupedGemmArguments(\n",
" A,\n",
" B,\n",
" out,\n",
" accumulator_type=torch.float32,\n",
" offsets=offsets,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "0dc6d1cb",
"metadata": {},
"source": [
"One can then use the same APIs for finding, compiling, and executing a\n",
"kernel supporting this operation"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "80213e1e",
"metadata": {},
"outputs": [],
"source": [
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
"\n",
"assert kernels, \"No kernels found\"\n",
"\n",
"# Select the first kernel found for simplicity\n",
"kernel = kernels[0]\n",
"\n",
"compiled_kernel = kernel.compile(args)\n",
"\n",
"# Execute the kernel\n",
"kernel.run(args, compiled_artifact=compiled_kernel)\n",
"\n",
"torch.testing.assert_close(out, reference)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}