From b6a914488fcd295ab43511c41028788e818c8175 Mon Sep 17 00:00:00 2001
From: John Afaganis
Date: Wed, 4 Mar 2026 13:15:34 -0700
Subject: [PATCH] Add Operation Support Matrix to Dispatcher README (#5071)

Added an Operation Support Matrix to the Dispatcher README, detailing CK Tile operations with support status for various data types, layouts, and GPU targets.

## Motivation

Provide a clear understanding of which operators (and variants) are supported by the dispatcher.

## Technical Details

Entirely generated by a skill.

## Test Plan

N/A. This is a documentation-only change.

## Test Result

N/A. This is a documentation-only change.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
---
 dispatcher/README.md | 107 ++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 105 insertions(+), 2 deletions(-)

diff --git a/dispatcher/README.md b/dispatcher/README.md
index fa3fbd3a59..d1ca299d78 100644
--- a/dispatcher/README.md
+++ b/dispatcher/README.md
@@ -16,8 +16,9 @@ A unified kernel dispatch system for AMD GPUs with C++ and Python frontends.
 5. [Running Examples](#running-examples)
 6. [External Integration](#external-integration)
 7. [Core Concepts](#core-concepts)
-8. [Troubleshooting](#troubleshooting)
-9. [File Structure](#file-structure)
+8. [Operation Support Matrix](#operation-support-matrix)
+9. [Troubleshooting](#troubleshooting)
+10. [File Structure](#file-structure)
 
 ---
 
@@ -618,6 +619,108 @@ auto problem = ProblemBuilder()
 
 ---
 
+## Operation Support Matrix
+
+This matrix shows all CK Tile operations with per-data-type, per-layout, and per-GPU support status. It uses a three-state convention: ✅ = supported by both CK Tile and the dispatcher, ❌ = supported by CK Tile but not yet in the dispatcher, blank = not supported by CK Tile itself.
+
+| | | | | | **Data Types** | | | | | **Layouts** | | | | **GPU Targets** | | |
+|:---:|---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+| **Op** | **CK Tile Kernel** | **fp16** | **fp8** | **bf16** | **bf8** | **int8** | **fp4** | **fp6** | **rcr** | **rrr** | **ccr** | **crr** | **90a** | **942** | **950** | **1201** |
+| GEMM | gemm_multi_d [5]<br>engine: `dispatcher/`<br>example: `19_gemm_multi_d/` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| GEMM | gemm_preshuffle [1][2]<br>engine: `dispatcher/` | ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | | | | ✅ | ✅ | ✅ | ❌ |
+| GEMM | gemm_universal [3][4][7][8]<br>engine: `dispatcher/`<br>example: `03_gemm/` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| GEMM | batched_contraction<br>example: `41_batched_contraction/` | ❌ | | | | | | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| GEMM | batched_gemm<br>example: `16_batched_gemm/` | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| GEMM | block_scale_gemm<br>example: `38_block_scale_gemm/` | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| GEMM | flatmm<br>example: `18_flatmm/` | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| GEMM | gemm_multi_abd<br>example: `22_gemm_multi_abd/` | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| GEMM | gemm_quant | | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| GEMM | grouped_gemm<br>example: `17_grouped_gemm/` | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| GEMM | grouped_gemm_quant | | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| GEMM | streamk_gemm<br>example: `40_streamk_gemm/` | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| Reduce | multi_reduce2d<br>example: `05_reduce/` | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Reduce | reduce2d<br>example: `05_reduce/` | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Attention | fmha<br>example: `01_fmha/` | ❌ | ❌ | ❌ | ❌ | ❌ | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Attention | sparse_attn<br>example: `50_sparse_attn/` | ❌ | | ❌ | | ❌ | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Activation | softmax | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Activation | topk_softmax<br>example: `09_topk_softmax/` | ❌ | ❌ | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Conv | grouped_conv [6]<br>example: `20_grouped_convolution/` | ❌ | ❌ | ❌ | ❌ | | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Data Move | batched_transpose<br>example: `35_batched_transpose/` | ❌ | ❌ | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Data Move | image_to_column<br>example: `04_img2col/` | ❌ | | | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Data Move | permute<br>example: `06_permute/` | ❌ | ❌ | ❌ | | ❌ | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Elementwise | elementwise<br>example: `21_elementwise/` | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | | | | | ❌ | ❌ | ❌ | ❌ |
+| MoE | fused_moe<br>example: `15_fused_moe/` | ❌ | ❌ | ❌ | ❌ | ❌ | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Norm | add_rmsnorm2d_rdquant<br>example: `11_add_rmsnorm2d_rdquant/` | ❌ | ❌ | ❌ | ❌ | ❌ | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Norm | layernorm2d<br>example: `02_layernorm2d/` | ❌ | ❌ | ❌ | ❌ | ❌ | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Norm | norm_reduce | ❌ | | ❌ | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Norm | rmsnorm2d<br>example: `10_rmsnorm2d/` | ❌ | ❌ | ❌ | ❌ | ❌ | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Pooling | pooling<br>example: `36_pooling/` | ❌ | | | | | | | | | | | ❌ | ❌ | ❌ | ❌ |
+| Quant | smoothquant<br>example: `12_smoothquant/` | ❌ | ❌ | ❌ | ❌ | ❌ | | | | | | | ❌ | ❌ | ❌ | ❌ |
+
+**Notes:**
+
+- [1] **gemm_preshuffle:** Supports only `rcr` layout. Uses fixed `preshufflev2` pipeline, `Auto` scheduler, and `cshuffle` epilogue.
+- [2] **gemm_preshuffle:** `int8` preshuffle support is limited to gfx942 and gfx950 (entries in `preshuffle_warp_tile_combos`).
+- [3] **gemm_universal:** `fp4` (pk_fp4) support is only available on gfx950.
+- [4] **gemm_universal:** `fp32` GEMM is supported by the dispatcher (`fp32_fp32_fp32` warp tile combos exist) but is omitted from the matrix columns for consistency with the tile engine matrix format.
+- [5] **gemm_multi_d:** Codegen supports `MultiDAdd` and `MultiDMultiply` element-wise ops. Preselected kernel sets also test `Relu`, `Gelu`, `FastGelu`.
+- [6] **grouped_conv:** `arch_filter.py` defines conv operator types (`CONV_FWD`, `CONV_BWD_DATA`, `CONV_BWD_WEIGHT`, `CONV3D_*`), but the dispatcher infrastructure is incomplete (ctypes bindings are stubs, `conv_utils.hpp` does not exist).
+- [7] **(all dispatcher ops):** gfx908, gfx1100, and gfx1200 also have `warp_tile_combos` in `arch_specs.json` but are not shown in the matrix's 4 GPU columns.
+- [8] **(all dispatcher ops):** `int4`, `fp32`, and `fp64` are valid dispatcher data types (defined in the `kernel_key.hpp` `DataType` enum) but have no dedicated matrix columns.
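As an editorial aside (not part of the patch), the three-state cell convention above maps naturally onto an optional boolean when the matrix is consumed programmatically. A minimal Python sketch, purely hypothetical and not dispatcher API; the few `SUPPORT` entries below are hand-copied from the matrix:

```python
from typing import Optional

# Hypothetical model of one matrix cell (not dispatcher API):
#   True  = ✅  supported by both CK Tile and the dispatcher
#   False = ❌  supported by CK Tile, not yet in the dispatcher
#   None  = blank: no CK Tile implementation for this combination
SUPPORT: dict[tuple[str, str], Optional[bool]] = {
    ("gemm_universal", "fp16"): True,   # ✅ in the matrix
    ("batched_gemm", "fp16"): False,    # ❌ in the matrix
    ("gemm_preshuffle", "fp4"): None,   # blank in the matrix
}

def describe(op: str, dtype: str) -> str:
    """Render one cell's state as the legend describes it."""
    state = SUPPORT.get((op, dtype))  # missing entries behave like blank
    if state is True:
        return "supported by CK Tile and the dispatcher"
    if state is False:
        return "CK Tile only (not yet in the dispatcher)"
    return "no CK Tile implementation"
```

The `Optional[bool]` keeps "not yet dispatched" (`False`) distinct from "does not exist in CK Tile" (`None`), which a plain boolean table would conflate.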
+
+### Dispatcher GEMM Configuration Detail
+
+#### Per-Variant Configuration
+
+| GEMM Variant | Pipelines | Schedulers | Epilogues | Element-wise Ops | Output Dtype |
+|:---:|:---:|:---:|:---:|:---:|:---:|
+| gemm_universal | mem, compv3, compv4 | intrawave, interwave | cshuffle, default | PassThrough | Same as input (fp8/bf8 -> fp16) |
+| gemm_preshuffle | preshufflev2 | Auto | cshuffle | PassThrough | Same as input (fp8/bf8 -> fp16) |
+| gemm_multi_d | mem, compv3, compv4 | intrawave, interwave | cshuffle, default | MultiDAdd, MultiDMultiply | Same as input (fp8/bf8 -> fp16) |
+
+#### Warp Tile Combinations per GPU
+
+| GPU | fp16 | bf16 | fp8 | bf8 | int8 | pk_fp4 |
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+| gfx1100 | 16x16x16 | 16x16x16 | -- | -- | 16x16x16 | -- |
+| gfx1200 | 16x16x16 | 16x16x16 | 16x16x16 | 16x16x16 | 16x16x16 | -- |
+| gfx1201 | 16x16x16 | 16x16x16 | 16x16x16 | 16x16x16 | 16x16x16 | -- |
+| gfx908 | 32x32x8, 16x16x16, 32x32x16, 16x16x32 | 32x32x8, 16x16x16, 32x32x16, 16x16x32 | -- | -- | 32x32x16, 16x16x32 | -- |
+| gfx90a | 32x32x8, 16x16x16, 32x32x16, 16x16x32, 4x64x16, 64x4x16 | 32x32x8, 16x16x16, 32x32x16, 16x16x32, 4x64x16, 64x4x16 | 32x32x16, 32x32x32 | 32x32x16, 32x32x32 | 32x32x16, 16x16x32 | -- |
+| gfx942 | 32x32x8, 16x16x16, 32x32x16, 16x16x32, 4x64x16, 64x4x16 | 32x32x8, 16x16x16, 32x32x16, 16x16x32, 4x64x16, 64x4x16 | 32x32x16, 32x32x32, 16x16x32, 16x16x64 | 32x32x16, 32x32x32, 16x16x32, 16x16x64 | 32x32x16, 16x16x32 | -- |
+| gfx950 | 32x32x8, 16x16x16, 32x32x16, 16x16x32, 4x64x16, 64x4x16 | 32x32x8, 16x16x16, 32x32x16, 16x16x32, 4x64x16, 64x4x16 | 32x32x16, 32x32x32, 16x16x32, 16x16x64, 16x16x128, 32x32x64 | 32x32x16, 32x32x32, 16x16x32, 16x16x64, 16x16x128, 32x32x64 | 32x32x16, 16x16x32 | 16x16x128 |
+
+#### Preshuffle Warp Tile Combinations
+
+| GPU | fp16 | bf16 | fp8 | bf8 | int8 |
+|:---:|:---:|:---:|:---:|:---:|:---:|
+| gfx90a | 32x32x8, 16x16x16, 32x32x16, 16x16x32, 64x4x16 | 32x32x8, 16x16x16, 32x32x16, 16x16x32, 64x4x16 | 32x32x16, 32x32x32 | 32x32x16, 32x32x32 | -- |
+| gfx942 | 32x32x8, 16x16x16, 32x32x16, 16x16x32, 64x4x16 | 32x32x8, 16x16x16, 32x32x16, 16x16x32, 64x4x16 | 32x32x16, 32x32x32, 16x16x32, 16x16x64 | 32x32x16, 32x32x32, 16x16x64, 16x16x32 | 16x16x32, 32x32x16 |
+| gfx950 | 32x32x8, 16x16x16, 32x32x16, 16x16x32, 64x4x16 | 32x32x8, 16x16x16, 32x32x16, 16x16x32, 64x4x16 | 32x32x16, 32x32x32, 16x16x32, 16x16x64, 16x16x128, 32x32x64 | 32x32x16, 32x32x32, 16x16x64, 16x16x32, 16x16x128, 32x32x64 | -- |
+
+**Legend:**
+- **CK Tile Kernel column:** First line is the kernel name. Lines prefixed with "engine:" show the dispatcher directory. Lines prefixed with "example:" show the CK Tile example directory under `example/ck_tile/`.
+- **Green cell** (✅): CK Tile implementation exists **and** the dispatcher supports it.
+- **Red cell** (❌): CK Tile implementation exists **but** the dispatcher does **not** support it.
+- **Grey cell** (blank): No CK Tile implementation exists for this combination.
+
+**Layout codes:** Each 3-character layout code specifies the memory layout for tensors A, B, and C:
+- `r` = row-major, `c` = column-major
+- Example: `rcr` means A is row-major, B is column-major, C is row-major
+- `gemm_multi_d` uses 4-character codes internally (e.g., `rcrr`) where the 4th character is the D tensor layout (always `r`). The matrix shows only the 3-character A/B/C portion.
+
+**Data type mapping per config label:**
+
+| Config Label | A (source) | B (source) | Acc | C (output) |
+|:---:|:---:|:---:|:---:|:---:|
+| fp16 | fp16 | fp16 | fp32 | fp16 |
+| bf16 | bf16 | bf16 | fp32 | bf16 |
+| int8 | int8 | int8 | int32 | int32 |
+| fp8 | fp8 | fp8 | fp32 | fp16 |
+| bf8 | bf8 | bf8 | fp32 | fp16 |
+| fp6 | fp6 | fp6 | fp32 | fp32 |
+| fp4 | fp16 or bf16 | fp4 | fp32 | fp16 or bf16 |
+
 ## Troubleshooting
 
 ### Build Issues
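As a closing editorial aside (not part of the patch), the layout-code convention described in the legend above can be sketched in a few lines of Python. `decode_layout` is a hypothetical illustration, not a dispatcher function:

```python
def decode_layout(code: str) -> dict[str, str]:
    """Map a layout code such as 'rcr' or 'rcrr' to per-tensor majors.

    Hypothetical helper for illustration only, not dispatcher API.
    'r' = row-major, 'c' = column-major; the characters apply to
    tensors A, B, C in order (and D for 4-character gemm_multi_d codes).
    """
    majors = {"r": "row-major", "c": "column-major"}
    tensors = ["A", "B", "C", "D"]
    # zip stops at the shorter sequence, so 3-character codes yield A/B/C only
    return {t: majors[ch] for t, ch in zip(tensors, code)}
```

For example, `decode_layout("rcr")` reads as: A row-major, B column-major, C row-major, matching the worked example in the legend.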