mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-20 21:08:57 +00:00
2026-01-12 updates
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
"id": "3dd45ef2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Basic GEMM using CUTLASS Python API"
|
||||
"# Basic GEMM using CUTLASS API"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -229,7 +229,7 @@
|
||||
" self, args: GemmArguments, cc: int = None\n",
|
||||
") -> cutlass_api.artifact.CompiledArtifact:\n",
|
||||
" stream = cutlass.cute.runtime.make_fake_stream()\n",
|
||||
" compiled_gemm = self.cute_compile(self.impl, args.A, args.B, args.out, stream)\n",
|
||||
" compiled_gemm = self.cute_compile(self.impl, args.A.tensor, args.B.tensor, args.out.tensor, stream)\n",
|
||||
" return cutlass_api.artifact.CompiledArtifact(compiled_gemm, self)"
|
||||
]
|
||||
},
|
||||
@@ -260,7 +260,7 @@
|
||||
"):\n",
|
||||
" stream = cutlass_api.utils.to_cuda_stream(stream)\n",
|
||||
" compiled_gemm = artifact.compiled_obj\n",
|
||||
" self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)"
|
||||
" self.cute_run(compiled_gemm, args.A.tensor, args.B.tensor, args.out.tensor, stream)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -343,13 +343,13 @@
|
||||
" design_metadata = cutlass_api.metadata.BLASDesignMetadata(tile_shape, (1, 1, 1))\n",
|
||||
" for stride_A, stride_B, stride_out in stride_combos:\n",
|
||||
" # Create TensorAttributes for A, B, and out tensors\n",
|
||||
" a_attrs = cutlass_api.metadata.TensorAttributes(\n",
|
||||
" a_attrs = cutlass_api.metadata.DenseTensorAttributes(\n",
|
||||
" cutlass.Float64, stride_A, divisibility\n",
|
||||
" )\n",
|
||||
" b_attrs = cutlass_api.metadata.TensorAttributes(\n",
|
||||
" b_attrs = cutlass_api.metadata.DenseTensorAttributes(\n",
|
||||
" cutlass.Float64, stride_B, divisibility\n",
|
||||
" )\n",
|
||||
" out_attrs = cutlass_api.metadata.TensorAttributes(\n",
|
||||
" out_attrs = cutlass_api.metadata.DenseTensorAttributes(\n",
|
||||
" cutlass.Float64, stride_out, divisibility\n",
|
||||
" )\n",
|
||||
" layout_str = cutlass_api.utils.strides_to_layout_string(\n",
|
||||
|
||||
@@ -0,0 +1,240 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "91d43c2b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Grouped GEMM with contiguous tensors via the CUTLASS API\n",
|
||||
"\n",
|
||||
"Note: this notebook requires a GPU with compute capability 100:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "f671f602",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass_api\n",
|
||||
"\n",
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({100})):\n",
|
||||
" print(\n",
|
||||
" f\"This notebook requires a GPU with compute capability 100.\\n{status.error}\"\n",
|
||||
" )\n",
|
||||
" import sys\n",
|
||||
" sys.exit(0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bc4adf7d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This notebook shows how to use the CUTLASS API to discover, compile, and execute\n",
|
||||
"kernels supporting contiguous offset grouped GEMMs.\n",
|
||||
"\n",
|
||||
"In a \"contiguous offset\" grouped GEMM, `G` different problems are executed\n",
|
||||
"in which problems differ only in the `M` mode. Their problem sizes are thus\n",
|
||||
"represented as:\n",
|
||||
"\n",
|
||||
"```text\n",
|
||||
"M0 x N x K\n",
|
||||
"M1 x N x K\n",
|
||||
"M2 x N x K\n",
|
||||
"...\n",
|
||||
"M(G-1) x N x K\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The grouped GEMM is referred to as \"contiguous\" because operands for different\n",
|
||||
"problems in the group are contained within contiguous tensors.\n",
|
||||
"\n",
|
||||
"Rather than having `G` different tensors for each of operands `A` and `B`, tensors\n",
|
||||
"for different problems in the group are packed together:\n",
|
||||
"* `A` is of shape `(TotalM, K)`, where `TotalM` is the sum of all `M` modes for problems in the group.\n",
|
||||
"The `A` operands for each problem in the group are stacked along the `M` mode to form this input. More on this below.\n",
|
||||
"* `B` is of shape `(G, K, N)`, where `B[i, :, :]` represents the GEMM `B` operand for the `i`th problem in the group.\n",
|
||||
"\n",
|
||||
"For example, with `G=3` (three problems in the group), with `M` modes of M0, M1, and M2,\n",
|
||||
"respectively, the tensor `A` would be laid out as follows:\n",
|
||||
"\n",
|
||||
"```text\n",
|
||||
"\n",
|
||||
" +----------------------------------+ ^ \n",
|
||||
" | | | | \n",
|
||||
" | A0 | M0 | \n",
|
||||
" | | | | \n",
|
||||
" |- - - - - - - - - - - -| | \n",
|
||||
" | | | |\n",
|
||||
" | | | TotalM \n",
|
||||
" | A1 | M1 |\n",
|
||||
" | | | |\n",
|
||||
" | | | | \n",
|
||||
" |- - - - - - - - - - - -| | \n",
|
||||
" | A2 | M2 | \n",
|
||||
" +----------------------------------+ v \n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The extents of individual `A` operands packed within the overall contiguous offset `A` tensor\n",
|
||||
"are provided by an auxiliary `offsets` vector of shape `(G,)`. `offsets[i]` indicates the ending\n",
|
||||
"M coordinate (exclusive) for the `i`th `A` operand.\n",
|
||||
"\n",
|
||||
"Thus, for the example above, `offsets = [M0, M0 + M1, M0 + M1 + M2]`.\n",
|
||||
"\n",
|
||||
"The output of the operation is of shape `(TotalM, N)`. The `i`th output occupies `out[start:end, :]`,\n",
|
||||
"where `start` and `end` are `offsets[i-1]` and `offsets[i]`, respectively (unless `i=0`, in which case\n",
|
||||
"`start` is 0).\n",
|
||||
"\n",
|
||||
"The reference code below shows the computation of this kernel."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "6185f60a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"def reference_contiguous_offset_grouped_gemm(A, B, offsets, out_dtype):\n",
|
||||
" G, K, N = B.shape\n",
|
||||
" TotalM = A.shape[0]\n",
|
||||
"\n",
|
||||
" out = torch.empty((TotalM, N), dtype=out_dtype, device=A.device)\n",
|
||||
"\n",
|
||||
" start = 0\n",
|
||||
" for i in range(G):\n",
|
||||
" end = offsets[i]\n",
|
||||
" out[start:end, :] = A[start:end, :] @ B[i, :, :]\n",
|
||||
" start = end\n",
|
||||
"\n",
|
||||
" return out"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d0bf2f91",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Contiguous offset grouped GEMM in PyTorch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4308a6a2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The same operation is performed by `torch`'s `torch._grouped_mm` (torch < 2.10)\n",
|
||||
"and `torch.nn.functional.grouped_mm` (torch >= 2.10)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "043906af",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"TotalM = 8192\n",
|
||||
"G = 12\n",
|
||||
"K = 1024\n",
|
||||
"N = 2048\n",
|
||||
"\n",
|
||||
"offsets = torch.arange(TotalM // G, TotalM, TotalM // G, device=\"cuda\", dtype=torch.int32)\n",
|
||||
"offsets[-1] = TotalM\n",
|
||||
"\n",
|
||||
"A = torch.randn(TotalM, K, device=\"cuda\", dtype=torch.bfloat16)\n",
|
||||
"B = torch.randn(G, N, K, device=\"cuda\", dtype=torch.bfloat16).permute(0, 2, 1)\n",
|
||||
"\n",
|
||||
"out_torch = torch._grouped_mm(A, B, offsets, out_dtype=torch.bfloat16)\n",
|
||||
"reference = reference_contiguous_offset_grouped_gemm(A, B, offsets, out_dtype=torch.bfloat16)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(out_torch, reference)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0d0e9479",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Contiguous offset grouped GEMM in CUTLASS API\n",
|
||||
"\n",
|
||||
"CUTLASS API exposes this contiguous offset grouped GEMM via `GroupedGemmArguments`,\n",
|
||||
"which are constructed similarly to `GemmArguments`, but take in an `offsets`\n",
|
||||
"tensor as well:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "ff8d3ef1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"out = torch.empty((TotalM, N), device=\"cuda\", dtype=torch.bfloat16)\n",
|
||||
"\n",
|
||||
"args = cutlass_api.arguments.GroupedGemmArguments(\n",
|
||||
" A,\n",
|
||||
" B,\n",
|
||||
" out,\n",
|
||||
" accumulator_type=torch.float32,\n",
|
||||
" offsets=offsets,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0dc6d1cb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"One can then use the same APIs for finding, compiling, and executing a\n",
|
||||
"kernel supporting this operation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "80213e1e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
|
||||
"\n",
|
||||
"assert kernels, \"No kernels found\"\n",
|
||||
"\n",
|
||||
"# Select the first kernel found for simplicity\n",
|
||||
"kernel = kernels[0]\n",
|
||||
"\n",
|
||||
"compiled_kernel = kernel.compile(args)\n",
|
||||
"\n",
|
||||
"# Execute the kernel\n",
|
||||
"kernel.run(args, compiled_artifact=compiled_kernel)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(out, reference)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user