mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 14:28:59 +00:00
* v4.3 update. * Update the cute_dsl_api changelog's doc link * Update version to 4.3.0 * Update the example link * Update doc to encourage user to install DSL from requirements.txt --------- Co-authored-by: Larry Wu <larwu@nvidia.com>
461 lines
18 KiB
Plaintext
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"import cutlass\n",
"import cutlass.cute as cute\n",
"import cutlass.cute.testing as testing\n",
"import cutlass.torch as cutlass_torch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The Usage of Benchmark and Autotune Utilities in CuTe DSL\n",
"\n",
"CuTe DSL provides autotune and benchmark utilities to help users evaluate and optimize kernel performance. This notebook demonstrates how to use these tools.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"### Autotune\n",
"\n",
"CuTe DSL provides two autotune utilities: the `testing.autotune_jit` decorator and the `testing.tune` function. The former is applied as a decorator on top of `@cute.jit`, while the latter is called as a standalone function.\n",
"\n",
"#### @testing.autotune_jit\n",
"\n",
"We use `elementwise_add_kernel` as an example. After writing the JIT host function and the kernel, add the `@testing.autotune_jit` decorator on top of the JIT host function to enable autotuning:\n",
"```python\n",
"@testing.autotune_jit(\n",
"    params_dict={\"copy_bits\": [64, 128]},\n",
"    update_on_change=[\"M\", \"N\"],\n",
"    warmup_iterations=100,\n",
"    iterations=100,\n",
")\n",
"```\n",
"\n",
"The `autotune_jit` decorator provides several parameters to control the autotuning process:\n",
"\n",
"- `params_dict`: A dictionary mapping each parameter to be tuned to its list of candidate values\n",
"- `update_on_change`: A list of argument names; re-tuning is triggered when the values of these arguments change\n",
"- `warmup_iterations`: Number of warmup iterations before timing\n",
"- `iterations`: Number of timed iterations for each parameter combination\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"@cute.kernel\n",
"def elementwise_add_kernel(\n",
"    gA: cute.Tensor,\n",
"    gB: cute.Tensor,\n",
"    gC: cute.Tensor,\n",
"    cC: cute.Tensor,  # coordinate tensor\n",
"    shape: cute.Shape,\n",
"    thr_layout: cute.Layout,\n",
"    val_layout: cute.Layout,\n",
"):\n",
"    tidx, _, _ = cute.arch.thread_idx()\n",
"    bidx, _, _ = cute.arch.block_idx()\n",
"\n",
"    # slice for CTAs\n",
"    # logical id -> address\n",
"    blk_coord = ((None, None), bidx)\n",
"    blkA = gA[blk_coord]  # (TileM,TileN)\n",
"    blkB = gB[blk_coord]  # (TileM,TileN)\n",
"    blkC = gC[blk_coord]  # (TileM,TileN)\n",
"    blkCrd = cC[blk_coord]  # (TileM, TileN)\n",
"\n",
"    # declare the atoms which will be used later for memory copy\n",
"    copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)\n",
"    copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type)\n",
"\n",
"    tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)\n",
"    tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)\n",
"    tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout)\n",
"\n",
"    thr_copy_A = tiled_copy_A.get_slice(tidx)\n",
"    thr_copy_B = tiled_copy_B.get_slice(tidx)\n",
"    thr_copy_C = tiled_copy_C.get_slice(tidx)\n",
"\n",
"    thrA = thr_copy_A.partition_S(blkA)\n",
"    thrB = thr_copy_B.partition_S(blkB)\n",
"    thrC = thr_copy_C.partition_S(blkC)\n",
"\n",
"    # allocate fragments for gmem->rmem\n",
"    frgA = cute.make_fragment_like(thrA)\n",
"    frgB = cute.make_fragment_like(thrB)\n",
"    frgC = cute.make_fragment_like(thrC)\n",
"\n",
"    # Predicate mask: True where this thread's coordinate lies inside the tensor\n",
"    thrCrd = thr_copy_C.partition_S(blkCrd)\n",
"    frgPred = cute.make_rmem_tensor(thrCrd.shape, cutlass.Boolean)\n",
"\n",
"    for i in range(0, cute.size(frgPred), 1):\n",
"        val = cute.elem_less(thrCrd[i], shape)\n",
"        frgPred[i] = val\n",
"\n",
"    ##########################################################\n",
"    # Move data to reg address space\n",
"    ##########################################################\n",
"\n",
"    cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)\n",
"    cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)\n",
"\n",
"    # Load data before use. The compiler will optimize the copy and load\n",
"    # operations to convert some memory ld/st into register uses.\n",
"    result = frgA.load() + frgB.load()\n",
"\n",
"    # Store the results into C's register fragment.\n",
"    frgC.store(result)\n",
"\n",
"    # Copy the results back to c\n",
"    cute.copy(copy_atom_store, frgC, thrC, pred=frgPred)\n",
"\n",
"\n",
"@testing.autotune_jit(\n",
"    params_dict={\"copy_bits\": [64, 128]},\n",
"    update_on_change=[\"M\", \"N\"],\n",
"    warmup_iterations=100,\n",
"    iterations=100,\n",
")\n",
"@cute.jit\n",
"def elementwise_add_autotune(mA, mB, mC, M, N, copy_bits: cutlass.Constexpr = 128):\n",
"    dtype = mA.element_type\n",
"    vector_size = copy_bits // dtype.width\n",
"\n",
"    thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))\n",
"    val_layout = cute.make_ordered_layout((4, vector_size), order=(1, 0))\n",
"    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)\n",
"\n",
"    gA = cute.zipped_divide(mA, tiler_mn)  # ((TileM,TileN),(RestM,RestN))\n",
"    gB = cute.zipped_divide(mB, tiler_mn)  # ((TileM,TileN),(RestM,RestN))\n",
"    gC = cute.zipped_divide(mC, tiler_mn)  # ((TileM,TileN),(RestM,RestN))\n",
"    idC = cute.make_identity_tensor(mC.shape)\n",
"    cC = cute.zipped_divide(idC, tiler=tiler_mn)\n",
"\n",
"    elementwise_add_kernel(gA, gB, gC, cC, mC.shape, thr_layout, val_layout).launch(\n",
"        grid=[cute.size(gC, mode=[1]), 1, 1],\n",
"        block=[cute.size(tv_layout, mode=[0]), 1, 1],\n",
"    )"
]
},
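{
"cell_type": "markdown",
"metadata": {},
"source": [
"In `elementwise_add_autotune`, `copy_bits` only enters through `vector_size`. That value sets how many elements each thread moves per copy, and therefore the tile shape and the number of CTAs. The plain-Python sketch below reproduces that host-side arithmetic so the two candidates can be compared without a GPU. `launch_geometry` is a hypothetical helper, not a CuTe DSL API. It assumes two things: that `cute.make_layout_tv(thr_layout, val_layout)` yields a tile of shape `(thr_m * val_m, thr_n * val_n)`, and that `zipped_divide` rounds partial tiles up, with the predicate `frgPred` masking the out-of-bounds elements.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def launch_geometry(M, N, copy_bits, dtype_width=32, thr_shape=(4, 32), val_rows=4):\n",
"    # Hypothetical helper mirroring elementwise_add_autotune's host-side math (not CuTe DSL code).\n",
"    vector_size = copy_bits // dtype_width  # elements per thread per copy along N\n",
"    tile_m = thr_shape[0] * val_rows  # 4 threads x 4 values along M\n",
"    tile_n = thr_shape[1] * vector_size  # 32 threads x vector_size values along N\n",
"    num_ctas = math.ceil(M / tile_m) * math.ceil(N / tile_n)  # grid size (partial tiles rounded up)\n",
"    threads_per_cta = thr_shape[0] * thr_shape[1]  # block size\n",
"    return {\n",
"        'vector_size': vector_size,\n",
"        'tile': (tile_m, tile_n),\n",
"        'num_ctas': num_ctas,\n",
"        'threads_per_cta': threads_per_cta,\n",
"    }\n",
"\n",
"\n",
"for bits in (64, 128):\n",
"    print(bits, launch_geometry(1024, 1024, bits))\n",
"```\n",
"\n",
"Under these assumptions, with `Float32` and `M = N = 1024`, 64-bit copies give a `(16, 64)` tile and 1024 CTAs. 128-bit copies give a `(16, 128)` tile and 512 CTAs. Both use 128 threads per CTA. Which variant is faster depends on the hardware, and that is what the autotuner measures."
]
},
```python
import math


def launch_geometry(M, N, copy_bits, dtype_width=32, thr_shape=(4, 32), val_rows=4):
    # Hypothetical helper mirroring elementwise_add_autotune's host-side math (not CuTe DSL code).
    vector_size = copy_bits // dtype_width  # elements per thread per copy along N
    tile_m = thr_shape[0] * val_rows  # 4 threads x 4 values along M
    tile_n = thr_shape[1] * vector_size  # 32 threads x vector_size values along N
    num_ctas = math.ceil(M / tile_m) * math.ceil(N / tile_n)  # grid size (partial tiles rounded up)
    threads_per_cta = thr_shape[0] * thr_shape[1]  # block size
    return {
        'vector_size': vector_size,
        'tile': (tile_m, tile_n),
        'num_ctas': num_ctas,
        'threads_per_cta': threads_per_cta,
    }


for bits in (64, 128):
    print(bits, launch_geometry(1024, 1024, bits))
```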
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we call the JIT function `elementwise_add_autotune`, CuTe DSL tunes the kernel by iterating over the specified configurations, and then runs the kernel with the best one.\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"\n",
"M, N = 1024, 1024\n",
"dtype = cutlass.Float32\n",
"skip_ref_check = False\n",
"\n",
"print(f\"\\nRunning Elementwise Add test with:\")\n",
"print(f\"Tensor dimensions: [{M}, {N}]\")\n",
"print(f\"Input and Output Data type: {dtype}\")\n",
"\n",
"torch_dtype = cutlass_torch.dtype(dtype)\n",
"\n",
"a = torch.randn(M, N, device=torch.device(\"cuda\"), dtype=torch_dtype)\n",
"b = torch.randn(M, N, device=torch.device(\"cuda\"), dtype=torch_dtype)\n",
"\n",
"c = torch.zeros_like(a)\n",
"\n",
"print(f\"Input tensor shapes:\")\n",
"print(f\"a: {a.shape}, dtype: {a.dtype}\")\n",
"print(f\"b: {b.shape}, dtype: {b.dtype}\")\n",
"print(f\"c: {c.shape}, dtype: {c.dtype}\\n\")\n",
"\n",
"elementwise_add_autotune(a, b, c, M, N)\n",
"\n",
"if not skip_ref_check:\n",
"    print(\"Verifying results for autotuned function ...\")\n",
"    torch.testing.assert_close(a + b, c)\n",
"    print(\"Results verified successfully!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The output is as follows:\n",
"\n",
"```\n",
"Running Elementwise Add test with:\n",
"Tensor dimensions: [1024, 1024]\n",
"Input and Output Data type: Float32\n",
"Input tensor shapes:\n",
"a: torch.Size([1024, 1024]), dtype: torch.float32\n",
"b: torch.Size([1024, 1024]), dtype: torch.float32\n",
"c: torch.Size([1024, 1024]), dtype: torch.float32\n",
"Verifying results for autotuned function ...\n",
"Results verified successfully!\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"To monitor the autotuning process in detail, enable logging by setting the environment variable `CUTE_DSL_LOG_AUTOTUNE`:\n",
"```shell\n",
"export CUTE_DSL_LOG_AUTOTUNE=1\n",
"```\n",
"This displays detailed information, including:\n",
"- Each configuration being evaluated and its execution time\n",
"- The optimal configuration that was selected\n",
"- The total time spent on tuning\n",
"- Cache hit/miss statistics\n",
"\n",
"\n",
"Below is sample output showing the autotuning process with different configurations:\n",
"```\n",
"2025-07-23 06:17:03,978 - cutlass.cute.testing_Autotune - INFO - Tuning configuration: {'copy_bits': 64}\n",
"2025-07-23 06:17:04,519 - cutlass.cute.testing_Autotune - INFO - Execution time: 0.010857919985428453 us\n",
"2025-07-23 06:17:04,519 - cutlass.cute.testing_Autotune - INFO - Tuning configuration: {'copy_bits': 128}\n",
"2025-07-23 06:17:04,683 - cutlass.cute.testing_Autotune - INFO - Execution time: 0.011117440033704042 us\n",
"2025-07-23 06:17:04,683 - cutlass.cute.testing_Autotune - INFO - Best configuration: {'copy_bits': 64}, execution time: 0.010857919985428453 us\n",
"2025-07-23 06:17:04,683 - cutlass.cute.testing_Autotune - INFO - Total tuning time: 0.7053244113922119 s\n",
"...\n",
"2025-07-23 06:17:04,700 - cutlass.cute.testing_Autotune - INFO - Using cached best configuration: {'copy_bits': 64}\n",
"```\n"
]
},
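{
"cell_type": "markdown",
"metadata": {},
"source": [
"The log shows the caching behavior. The first call with a given `(M, N)` times every configuration, and later calls with the same `(M, N)` reuse the stored winner. The toy decorator below mimics that behavior in plain Python to make `params_dict` and `update_on_change` concrete. It is a hypothetical sketch: `toy_autotune` and `fake_kernel` are not CuTe DSL code. The sketch tries every combination of the values in `params_dict`, times a Python callable with `time.perf_counter`, and only accepts keyword arguments.\n",
"\n",
"```python\n",
"import itertools\n",
"import time\n",
"\n",
"\n",
"def toy_autotune(params_dict, update_on_change, warmup_iterations=1, iterations=3):\n",
"    # Hypothetical stand-in for an autotuning decorator (not CuTe DSL code).\n",
"    names = list(params_dict)\n",
"    # Every combination of candidate values, e.g. [{'copy_bits': 64}, {'copy_bits': 128}]\n",
"    configs = [dict(zip(names, values)) for values in itertools.product(*params_dict.values())]\n",
"\n",
"    def decorator(func):\n",
"        cache = {}  # maps the values of the update_on_change arguments -> best config\n",
"\n",
"        def wrapper(**kwargs):\n",
"            key = tuple(kwargs[name] for name in update_on_change)\n",
"            if key not in cache:  # cache miss: time every configuration\n",
"                timings = []\n",
"                for cfg in configs:\n",
"                    for _ in range(warmup_iterations):\n",
"                        func(**kwargs, **cfg)\n",
"                    start = time.perf_counter()\n",
"                    for _ in range(iterations):\n",
"                        func(**kwargs, **cfg)\n",
"                    timings.append(((time.perf_counter() - start) / iterations, cfg))\n",
"                cache[key] = min(timings, key=lambda t: t[0])[1]\n",
"            return func(**kwargs, **cache[key])  # cache hit: reuse the stored config\n",
"\n",
"        wrapper.cache = cache\n",
"        return wrapper\n",
"\n",
"    return decorator\n",
"\n",
"\n",
"@toy_autotune(params_dict={'copy_bits': [64, 128]}, update_on_change=['M', 'N'])\n",
"def fake_kernel(M, N, copy_bits):\n",
"    # Pretend the 64-bit variant is slower by sleeping a little longer.\n",
"    time.sleep(0.002 if copy_bits == 64 else 0.0)\n",
"    return copy_bits\n",
"\n",
"\n",
"print(fake_kernel(M=1024, N=1024))  # tunes, then runs with the best config\n",
"print(fake_kernel(M=1024, N=1024))  # same (M, N): reuses the cached config\n",
"print(fake_kernel(M=2048, N=1024))  # new (M, N): re-tunes\n",
"print(fake_kernel.cache)\n",
"```\n",
"\n",
"Keying the cache only on the `update_on_change` arguments is what makes repeated calls cheap. The tensors themselves can change from call to call without triggering another tuning pass."
]
},
```python
import itertools
import time


def toy_autotune(params_dict, update_on_change, warmup_iterations=1, iterations=3):
    # Hypothetical stand-in for an autotuning decorator (not CuTe DSL code).
    names = list(params_dict)
    # Every combination of candidate values, e.g. [{'copy_bits': 64}, {'copy_bits': 128}]
    configs = [dict(zip(names, values)) for values in itertools.product(*params_dict.values())]

    def decorator(func):
        cache = {}  # maps the values of the update_on_change arguments -> best config

        def wrapper(**kwargs):
            key = tuple(kwargs[name] for name in update_on_change)
            if key not in cache:  # cache miss: time every configuration
                timings = []
                for cfg in configs:
                    for _ in range(warmup_iterations):
                        func(**kwargs, **cfg)
                    start = time.perf_counter()
                    for _ in range(iterations):
                        func(**kwargs, **cfg)
                    timings.append(((time.perf_counter() - start) / iterations, cfg))
                cache[key] = min(timings, key=lambda t: t[0])[1]
            return func(**kwargs, **cache[key])  # cache hit: reuse the stored config

        wrapper.cache = cache
        return wrapper

    return decorator


@toy_autotune(params_dict={'copy_bits': [64, 128]}, update_on_change=['M', 'N'])
def fake_kernel(M, N, copy_bits):
    # Pretend the 64-bit variant is slower by sleeping a little longer.
    time.sleep(0.002 if copy_bits == 64 else 0.0)
    return copy_bits


print(fake_kernel(M=1024, N=1024))  # tunes, then runs with the best config
print(fake_kernel(M=1024, N=1024))  # same (M, N): reuses the cached config
print(fake_kernel(M=2048, N=1024))  # new (M, N): re-tunes
print(fake_kernel.cache)
```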
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### tune\n",
"\n",
"We also provide a `tune` function. Its interface is as follows:\n",
"\n",
"```python\n",
"def tune(\n",
"    func: Callable[[Any], Callable[[], Any]],\n",
"    params_dict: Dict[str, List[Any]] = None,\n",
"    kernel_arguments: JitArguments = JitArguments(),\n",
"    warmup_iterations=10,\n",
"    iterations=100,\n",
"    stream: Optional[cuda_driver.CUstream] = None,\n",
") -> Dict[str, Any]:\n",
"```\n",
"\n",
"The `tune` function takes the following parameters:\n",
"\n",
"- `func`: A callable that takes configuration parameters and returns a zero-argument callable that runs the kernel\n",
"- `params_dict`: Dictionary mapping parameter names to lists of candidate values to tune\n",
"- `kernel_arguments`: The arguments, wrapped in `JitArguments`, that are passed to `func` together with each candidate configuration\n",
"- `warmup_iterations`: Number of warmup iterations before timing (default: 10)\n",
"- `iterations`: Number of timing iterations per configuration (default: 100)\n",
"- `stream`: Optional CUDA stream to use for execution (default: the default CUDA stream). It must match the stream passed to the kernel; mismatched streams result in an error.\n",
"\n",
"It returns a dictionary containing the best kernel configuration found.\n",
"\n",
"\n",
"Here is an example of using the `tune` function:\n",
"\n",
"1. First remove the `@testing.autotune_jit` decorator from the `elementwise_add_autotune` function:\n",
"   ```python\n",
"   @testing.autotune_jit(\n",
"       params_dict={\"copy_bits\": [64, 128]},\n",
"       update_on_change=[\"M\", \"N\"],\n",
"       warmup_iterations=100,\n",
"       iterations=100,\n",
"   )\n",
"   ```\n",
"\n",
"2. Define a `tune_func` that:\n",
"   - Takes the input tensors (`a`, `b`, `c`), the dimensions (`M`, `N`), and the tuning parameter `copy_bits`\n",
"   - Compiles the `elementwise_add_autotune` function with `cute.compile()`, forwarding `copy_bits`\n",
"   - Returns a lambda that executes the compiled kernel\n",
"\n",
"3. Pass `tune_func` to the `testing.tune` function along with:\n",
"   - The parameter space to explore (the `copy_bits` values)\n",
"   - The kernel arguments wrapped in `JitArguments`\n",
"\n",
"The `tune` function then finds the optimal parameters automatically.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"def tune_func(a, b, c, M, N, copy_bits=128):\n",
"    # Forward the candidate copy_bits so that each configuration compiles a different kernel\n",
"    compiled_func = cute.compile(elementwise_add_autotune, a, b, c, M, N, copy_bits=copy_bits)\n",
"    return lambda: compiled_func(a, b, c, M, N)\n",
"\n",
"params = testing.tune(\n",
"    tune_func,\n",
"    params_dict={\"copy_bits\": [64, 128]},\n",
"    kernel_arguments=testing.JitArguments(a, b, c, M, N),\n",
")\n",
"print(f\"The best kernel configs found: {params}\")\n",
"\n",
"# run the kernel with the best config\n",
"compiled_func = cute.compile(elementwise_add_autotune, a, b, c, M, N, **params)\n",
"compiled_func(a, b, c, M, N)"
]
},
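{
"cell_type": "markdown",
"metadata": {},
"source": [
"Judging from the example above, `tune` calls `func` with the contents of `kernel_arguments` plus one candidate configuration, then times the zero-argument callable that `func` returns. The toy below illustrates that contract in plain Python and shows why `tune_func` must forward `copy_bits` to `cute.compile`: otherwise every candidate would compile the same kernel. `toy_tune` and `fake_tune_func` are hypothetical names, not CuTe DSL code. Timing uses `time.perf_counter` on a callable that only sleeps.\n",
"\n",
"```python\n",
"import itertools\n",
"import time\n",
"\n",
"\n",
"def toy_tune(func, params_dict, args=(), warmup_iterations=1, iterations=3):\n",
"    # Hypothetical stand-in for testing.tune (not the CuTe DSL implementation).\n",
"    # func(*args, **config) must return a zero-argument callable, like tune_func above.\n",
"    names = list(params_dict)\n",
"    best_cfg, best_time = None, float('inf')\n",
"    for values in itertools.product(*params_dict.values()):\n",
"        cfg = dict(zip(names, values))\n",
"        run = func(*args, **cfg)  # one 'compilation' per configuration\n",
"        for _ in range(warmup_iterations):\n",
"            run()\n",
"        start = time.perf_counter()\n",
"        for _ in range(iterations):\n",
"            run()\n",
"        elapsed = (time.perf_counter() - start) / iterations\n",
"        if elapsed < best_time:\n",
"            best_cfg, best_time = cfg, elapsed\n",
"    return best_cfg\n",
"\n",
"\n",
"compiled = []  # records which copy_bits each 'compilation' received\n",
"\n",
"\n",
"def fake_tune_func(M, N, copy_bits=128):\n",
"    compiled.append(copy_bits)  # forwarding copy_bits is what makes the candidates differ\n",
"    delay = 0.002 if copy_bits == 64 else 0.0  # pretend 64-bit copies are slower\n",
"    return lambda: time.sleep(delay)\n",
"\n",
"\n",
"best = toy_tune(fake_tune_func, {'copy_bits': [64, 128]}, args=(1024, 1024))\n",
"print(compiled, best)\n",
"```\n",
"\n",
"Because `toy_tune` returns a plain dictionary, the result can be splatted straight back into the compile call, as the `**params` in the cell above does."
]
},
```python
import itertools
import time


def toy_tune(func, params_dict, args=(), warmup_iterations=1, iterations=3):
    # Hypothetical stand-in for testing.tune (not the CuTe DSL implementation).
    # func(*args, **config) must return a zero-argument callable, like tune_func above.
    names = list(params_dict)
    best_cfg, best_time = None, float('inf')
    for values in itertools.product(*params_dict.values()):
        cfg = dict(zip(names, values))
        run = func(*args, **cfg)  # one 'compilation' per configuration
        for _ in range(warmup_iterations):
            run()
        start = time.perf_counter()
        for _ in range(iterations):
            run()
        elapsed = (time.perf_counter() - start) / iterations
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg


compiled = []  # records which copy_bits each 'compilation' received


def fake_tune_func(M, N, copy_bits=128):
    compiled.append(copy_bits)  # forwarding copy_bits is what makes the candidates differ
    delay = 0.002 if copy_bits == 64 else 0.0  # pretend 64-bit copies are slower
    return lambda: time.sleep(delay)


best = toy_tune(fake_tune_func, {'copy_bits': [64, 128]}, args=(1024, 1024))
print(compiled, best)
```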
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Benchmark\n",
"\n",
"In CuTe DSL, the `benchmark` utility measures kernel execution time. Its interface is as follows:\n",
"\n",
"```python\n",
"def benchmark(\n",
"    callable: Callable,\n",
"    *,\n",
"    warmup_iterations: int = 10,\n",
"    iterations: int = 100,\n",
"    stream: Optional[cuda_driver.CUstream] = None,\n",
"    kernel_arguments: Optional[JitArguments] = None,\n",
"    workspace_generator: Optional[Callable[[], JitArguments]] = None,\n",
"    workspace_count: int = 1,\n",
"    use_cuda_graphs: bool = False,\n",
") -> float:\n",
"```\n",
"\n",
"The benchmark utility exposes several configuration parameters to control profiling behavior:\n",
"\n",
"- `callable`: The function to be benchmarked\n",
"- `warmup_iterations`: Number of warmup iterations before measurement begins (default: 10)\n",
"- `iterations`: Number of iterations to profile for the measurement (default: 100)\n",
"- `stream`: The CUDA stream on which to execute the kernel (default: the default stream)\n",
"- `kernel_arguments`: The arguments, wrapped in `JitArguments`, used to launch the callable\n",
"- `workspace_generator`: A function that returns a fresh set of kernel arguments (a `JitArguments`) for each workspace, to avoid caching effects\n",
"- `workspace_count`: How many different workspaces to cycle through during profiling (default: 1)\n",
"- `use_cuda_graphs`: Whether to capture the callable in a CUDA graph to minimize kernel launch overhead (default: False)\n",
"\n",
"These parameters fall into four groups:\n",
"\n",
"1. Core parameters:\n",
"   - The function to profile (`callable`)\n",
"   - Number of warmup iterations before measurement\n",
"   - Number of profiling iterations for measurement\n",
"\n",
"2. Stream configuration:\n",
"   - For kernels running on a non-default stream, the stream must be specified\n",
"   - The `stream` parameter must match the stream passed to the kernel; mismatched streams result in an error\n",
"\n",
"3. Cache effects mitigation:\n",
"   - To prevent L2 cache effects from skewing results, multiple workspaces can be cycled through\n",
"   - This is controlled via the `workspace_count` and `workspace_generator` parameters\n",
"   - Each workspace provides fresh kernel arguments\n",
"\n",
"4. CUDA Graph support:\n",
"   - Enables measuring kernel execution time without host launch overhead\n",
"   - Requires the callable to be decorated with `@cute.jit`\n",
"   - Requires a non-default CUDA stream\n",
"\n",
"The function returns the average execution time of the callable per iteration, in microseconds. GPU clock frequencies can vary dynamically, so locking the SM and memory clocks before running the benchmark gives more stable and reproducible results. This can be done with `nvidia-smi --lock-gpu-clocks` and `nvidia-smi --lock-memory-clocks`. Next, we use the `benchmark` function to measure the execution time of the elementwise_add kernel defined above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"def generate_kernel_arguments():\n",
"    a = torch.randn(\n",
"        M, N, device=torch.device(\"cuda\"), dtype=torch_dtype\n",
"    )\n",
"    b = torch.randn(\n",
"        M, N, device=torch.device(\"cuda\"), dtype=torch_dtype\n",
"    )\n",
"\n",
"    c = torch.zeros_like(a)\n",
"\n",
"    return testing.JitArguments(a, b, c, M, N)\n",
"\n",
"avg_time_us = testing.benchmark(\n",
"    elementwise_add_autotune,\n",
"    workspace_generator=generate_kernel_arguments,\n",
"    workspace_count=10,\n",
"    warmup_iterations=10,\n",
"    iterations=100,\n",
")\n",
"\n",
"# Print execution results\n",
"print(\n",
"    f\"Kernel execution time for cute.jit kernel with M={M}, N={N}: {avg_time_us / 1e3:.4f} ms\"\n",
")\n",
"# Bytes moved per call: read a and b, write c (3 tensors of M*N elements)\n",
"print(\n",
"    f\"Achieved memory throughput for M={M}, N={N}: {(3 * a.numel() * dtype.width // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After running the code, we get output similar to the following:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"Kernel execution time for cute.jit kernel with M=1024, N=1024: 0.0403 ms\n",
"Achieved memory throughput for M=1024, N=1024: 312.37 GB/s\n",
"```"
]
}
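,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The throughput figure above can be checked by hand. The kernel reads `a` and `b` and writes `c`, so for `Float32` it moves `3 * M * N * 4` bytes per call. The small plain-Python sketch below redoes that arithmetic with the same formula as the benchmark cell. It also reports how much memory the 10 workspaces occupy together, which is the number to compare against your GPU's L2 size when choosing `workspace_count`. The helper names `elementwise_add_traffic` and `throughput_gb_per_s` are hypothetical, and the sketch assumes each workspace holds its own `a`, `b`, and `c` (as `generate_kernel_arguments` does). It does not assume any particular GPU's L2 size.\n",
"\n",
"```python\n",
"def elementwise_add_traffic(M, N, dtype_bytes=4, workspace_count=10):\n",
"    # Two input tensors (a, b) are read and one output tensor (c) is written.\n",
"    bytes_per_call = 3 * M * N * dtype_bytes\n",
"    # Each workspace holds its own a, b and c, so the footprint grows linearly.\n",
"    footprint_bytes = bytes_per_call * workspace_count\n",
"    return bytes_per_call, footprint_bytes\n",
"\n",
"\n",
"def throughput_gb_per_s(bytes_moved, time_us):\n",
"    # Same formula as the benchmark cell: bytes / seconds / 1e9.\n",
"    return bytes_moved / (time_us / 1e6) / 1e9\n",
"\n",
"\n",
"bytes_per_call, footprint = elementwise_add_traffic(1024, 1024)\n",
"print(f'{bytes_per_call / 2**20:.0f} MiB per call, {footprint / 2**20:.0f} MiB across 10 workspaces')\n",
"print(f'{throughput_gb_per_s(bytes_per_call, 40.3):.1f} GB/s at 40.3 us')\n",
"```\n",
"\n",
"This gives 12 MiB per call, 120 MiB across the 10 workspaces, and about 312.2 GB/s at 40.3 us. The sample output reports 312.37 GB/s because its time is printed rounded to 0.0403 ms, while the throughput is computed from the unrounded value."
]
}
```python
def elementwise_add_traffic(M, N, dtype_bytes=4, workspace_count=10):
    # Two input tensors (a, b) are read and one output tensor (c) is written.
    bytes_per_call = 3 * M * N * dtype_bytes
    # Each workspace holds its own a, b and c, so the footprint grows linearly.
    footprint_bytes = bytes_per_call * workspace_count
    return bytes_per_call, footprint_bytes


def throughput_gb_per_s(bytes_moved, time_us):
    # Same formula as the benchmark cell: bytes / seconds / 1e9.
    return bytes_moved / (time_us / 1e6) / 1e9


bytes_per_call, footprint = elementwise_add_traffic(1024, 1024)
print(f'{bytes_per_call / 2**20:.0f} MiB per call, {footprint / 2**20:.0f} MiB across 10 workspaces')
print(f'{throughput_gb_per_s(bytes_per_call, 40.3):.1f} GB/s at 40.3 us')
```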
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}