## Motivation

This PR adds multiple new GPU kernel benchmarking operations to the CK Tile Engine, expanding its coverage of GEMM-family operations:

- **gemm_multi_abd**: GEMM with multiple A, B, and D tensors, enabling epilogue patterns such as scale/bias fusion.
- **batched_contraction**: Batched tensor contraction supporting multi-dimensional batch (G), M, N, and K dimensions, targeting workloads where the contraction indices span more than one logical axis.
- **mx_gemm**: MX-format GEMM with microscaling (e8m0) scale tensors.
- **gemm_rowcolquant**: Block-scale GEMM with row/column quantization.
- **gemm_tensor_quant**: Block-scale GEMM with tensor quantization.
- **grouped_gemm_rowcolquant**: Grouped GEMM with row/column quantization.
- **grouped_gemm_tensorquant**: Grouped GEMM with tensor quantization.
- **batched_gemm**: Batched GEMM benchmarking support.

## Technical Details

### gemm_multi_abd

- New subdirectory: `tile_engine/ops/gemm/gemm_multi_abd/`
- `CMakeLists.txt` follows the same individual-target pattern as `gemm_universal` / `gemm_multi_d`.
- `gemm_multi_abd_instance_builder.py` subclasses `GemmKernelBuilder` from the shared `gemm_instance_builder.py`.
- `gemm_multi_abd_benchmark.py` delegates to the shared `GemmBenchmark` parent class.
- Configs: `default_config.json`, `default_ci_config.json`, `user_provided_config.json`.
- Supported GPU targets: gfx90a, gfx942, gfx950, gfx1201.

### batched_contraction

- New subdirectory: `tile_engine/ops/gemm/batched_contraction/`
- Extends `GemmKernelBuilder` via `BatchedContractionKernelBuilder`, adding the `num_dim_g`, `num_dim_m`, `num_dim_n`, `num_dim_k`, `num_d_tensors`, and `elementwise_function` parameters.
- The layout string uses a 3-character encoding (A+B+E), e.g. `rcr`.
- Self-contained benchmark sweep driver (`batched_contraction_benchmark.py`) with JSON/CSV export and best-kernel selection.
- Supported GPU targets: gfx90a, gfx942, gfx950.
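The `num_dim_g`/`num_dim_m`/`num_dim_n`/`num_dim_k` parameters say how many tensor axes belong to each index group. For sizing purposes, a contraction over grouped axes behaves like a batched GEMM. Its G, M, N, and K sizes are the products of the extents in each group. The sketch below shows that folding under an assumed axis order: A as `[G..., M..., K...]` and B as `[G..., N..., K...]`. `fold_contraction_dims` is a hypothetical helper, not part of `BatchedContractionKernelBuilder`.

```python
from math import prod


def fold_contraction_dims(a_shape, b_shape, num_dim_g, num_dim_m, num_dim_n, num_dim_k):
    """Collapse grouped G/M/N/K axes into flat batched-GEMM sizes (G, M, N, K).

    Hypothetical helper; assumes A is [G..., M..., K...] and B is [G..., N..., K...].
    """
    a_shape, b_shape = tuple(a_shape), tuple(b_shape)
    if len(a_shape) != num_dim_g + num_dim_m + num_dim_k:
        raise ValueError("rank of A must equal num_dim_g + num_dim_m + num_dim_k")
    if len(b_shape) != num_dim_g + num_dim_n + num_dim_k:
        raise ValueError("rank of B must equal num_dim_g + num_dim_n + num_dim_k")

    # Split each shape into its index groups.
    a_g = a_shape[:num_dim_g]
    a_m = a_shape[num_dim_g:num_dim_g + num_dim_m]
    a_k = a_shape[num_dim_g + num_dim_m:]
    b_g = b_shape[:num_dim_g]
    b_n = b_shape[num_dim_g:num_dim_g + num_dim_n]
    b_k = b_shape[num_dim_g + num_dim_n:]
    if a_g != b_g or a_k != b_k:
        raise ValueError("G and K extents must match between A and B")

    # prod(()) == 1, so an empty group folds to size 1.
    return prod(a_g), prod(a_m), prod(b_n), prod(a_k)


# Two batch axes, two M axes, one N axis, two K axes.
print(fold_contraction_dims((4, 2, 16, 8, 32, 2), (4, 2, 64, 32, 2),
                            num_dim_g=2, num_dim_m=2, num_dim_n=1, num_dim_k=2))
# (8, 128, 64, 64)
```

In this sketch, an empty group folds to 1. A plain single-batch GEMM is therefore the case `num_dim_g=0` with one M, one N, and one K axis.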
### mx_gemm

- New subdirectory: `tile_engine/ops/gemm/mx_gemm/`
- Supports MX-format microscaling with e8m0 scale tensors for A and B.

### block_scale_gemm (gemm_rowcolquant, gemm_tensor_quant)

- New subdirectory: `tile_engine/ops/gemm/block_scale_gemm/`
- `gemm_rowcolquant`: row/column quantization epilogue.
- `gemm_tensor_quant`: tensor-level quantization epilogue.

### grouped_gemm_quant (grouped_gemm_rowcolquant, grouped_gemm_tensorquant)

- New subdirectory: `tile_engine/ops/gemm/grouped_gemm_quant/`
- `grouped_gemm_rowcolquant`: grouped GEMM with row/column quantization.
- `grouped_gemm_tensorquant`: grouped GEMM with tensor quantization.

### batched_gemm

- New subdirectory: `tile_engine/ops/gemm/batched_gemm/`
- Batched GEMM benchmark support, wired into the sampling/active-op lists.

All new ops are registered in `op_weights.json` for budget allocation and wired into the active-op sampling lists in `CMakeLists.txt`.

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
CK Tile operation support by data type, layout, and GPU target:
| Op | CK Tile Kernel | fp16 | fp8 | bf16 | bf8 | int8 | fp4 | fp6 | rcr | rrr | ccr | crr | 90a | 942 | 950 | 1201 | Op Weight |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| GEMM | gemm_universal [1][2]<br>engine: gemm_universal/<br>example: 03_gemm/ | ✅ | ✅ | ✅ | ✅ | ❌ | | | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 0.0834 |
| GEMM | gemm_multi_d [3]<br>engine: gemm_multi_d/<br>example: 19_gemm_multi_d/ | ✅ | | | | | | | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | 0.0834 |
| GEMM | gemm_preshuffle [4]<br>engine: gemm_preshuffle/ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | 0.0834 | ||||||
| GEMM | streamk_gemm [5][6][7]<br>engine: gemm_streamk/<br>example: 40_streamk_gemm/ | ✅ | ✅ | ❌ | ❌ | | | | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | |
| GEMM | batched_gemm [11]<br>engine: batched_gemm/<br>example: 16_batched_gemm/ | ✅ | | | | | | | ✅ | | | | ✅ | ✅ | ✅ | ✅ | 0.0833 |
| GEMM | batched_contraction<br>example: 41_batched_contraction/ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | 0.0833 | |||||||||
| GEMM | block_scale_gemm/gemm_rowcolquant [9]<br>engine: block_scale_gemm/gemm_rowcolquant/<br>example: 38_block_scale_gemm/ | | ✅ | | ✅ | | | | ✅ | | | | ❌ | ✅ | ✅ | ✅ | 0.0833 |
| GEMM | block_scale_gemm/gemm_tensor_quant [9]<br>engine: block_scale_gemm/gemm_tensor_quant/<br>example: 38_block_scale_gemm/ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | 0.0833 | |||||
| GEMM | flatmm<br>example: 18_flatmm/ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | |||||
| GEMM | gemm_multi_abd<br>example: 22_gemm_multi_abd/ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 0.0833 | |||||||||
| GEMM | gemm_quant | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | |||||||||
| GEMM | grouped_gemm [10]<br>engine: grouped_gemm/<br>example: 17_grouped_gemm/ | ✅ | ✅ | | | | | | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | 0.0834 |
| GEMM | grouped_gemm_quant/grouped_gemm_rowcolquant<br>engine: grouped_gemm_quant/grouped_gemm_rowcolquant/ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 0.0833 | |||||||||
| GEMM | grouped_gemm_quant/grouped_gemm_tensorquant<br>engine: grouped_gemm_quant/grouped_gemm_tensorquant/ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 0.0833 | |||||||||
| GEMM | mx_gemm [12]<br>engine: gemm/mx_gemm/ | | ✅ | | | | ✅ | | ✅ | | | | | ✅ | ✅ | ❌ | 0.0833 |
| Reduce | multi_reduce2d [8]<br>engine: reduce/<br>example: 05_reduce/ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ||||||||||
| Reduce | reduce2d<br>example: 05_reduce/ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ||||||||||
| Attention | fmha<br>engine: fmha/<br>example: 01_fmha/ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ||||||||
| Attention | sparse_attn<br>example: 50_sparse_attn/ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ||||||||||
| Activation | softmax | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ||||||||||
| Activation | topk_softmax<br>example: 09_topk_softmax/ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ||||||||||
| Conv | grouped_conv<br>example: 20_grouped_convolution/ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ||||||||||
| Data Move | batched_transpose<br>example: 35_batched_transpose/ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | |||||||||
| Data Move | image_to_column<br>example: 04_img2col/ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ||||||||||
| Data Move | permute<br>example: 06_permute/ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ||||||||||
| Elementwise | elementwise<br>example: 21_elementwise/ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ||||||||||
| MoE | fused_moe<br>example: 15_fused_moe/ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ||||||||||
| Norm | add_rmsnorm2d_rdquant<br>example: 11_add_rmsnorm2d_rdquant/ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ||||||||||
| Norm | layernorm2d<br>example: 02_layernorm2d/ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ||||||||||
| Norm | norm_reduce | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ||||||||||
| Norm | rmsnorm2d<br>example: 10_rmsnorm2d/ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ||||||||||
| Pooling | pooling<br>example: 36_pooling/ | ❌ | ❌ | ❌ | ❌ | ❌ | |||||||||||
| Quant | smoothquant<br>example: 12_smoothquant/ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
Legend:
- CK Tile Kernel column: The first line is the kernel name. Lines prefixed with "engine:" show the tile engine directory under `ops/`. Lines prefixed with "example:" show the CK Tile example directory under `example/ck_tile/`.
- Green cell (✅): CK Tile implementation exists and the tile engine supports it.
- Red cell (❌): CK Tile implementation exists but the tile engine does not support it.
- Grey cell (blank): No CK Tile implementation exists for this combination.
- Op Weight column: Sampling budget weight assigned to each op in `sampling/op_weights.json`. Ops without a weight are not part of the daily sampling tier.
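The twelve weights in the Op Weight column are each about 1/12 and sum to 1.0, so they read naturally as shares of a per-run sampling budget. The sketch below shows one way such shares could be turned into integer kernel counts, using the largest-remainder method. The dictionary keys and the allocation policy are assumptions for illustration. They are not the actual schema of `sampling/op_weights.json` or the tile engine's sampler.

```python
import math

# Hypothetical stand-in for sampling/op_weights.json, copied from the Op Weight column.
OP_WEIGHTS = {
    "gemm_universal": 0.0834, "gemm_multi_d": 0.0834, "gemm_preshuffle": 0.0834,
    "grouped_gemm": 0.0834, "batched_gemm": 0.0833, "batched_contraction": 0.0833,
    "gemm_rowcolquant": 0.0833, "gemm_tensor_quant": 0.0833, "gemm_multi_abd": 0.0833,
    "grouped_gemm_rowcolquant": 0.0833, "grouped_gemm_tensorquant": 0.0833,
    "mx_gemm": 0.0833,
}


def allocate_budget(weights, total):
    """Split an integer kernel budget across ops in proportion to their weights.

    Largest-remainder method: an assumed policy, not the tile engine's.
    """
    weight_sum = sum(weights.values())
    quotas = {op: total * w / weight_sum for op, w in weights.items()}
    alloc = {op: math.floor(q) for op, q in quotas.items()}
    leftover = total - sum(alloc.values())
    # Give the remaining units to the largest fractional parts; ties break by op name.
    ranked = sorted(quotas, key=lambda op: (alloc[op] - quotas[op], op))
    for op in ranked[:leftover]:
        alloc[op] += 1
    return alloc


alloc = allocate_budget(OP_WEIGHTS, 1200)
print(alloc["mx_gemm"], sum(alloc.values()))  # 100 1200
```

Rounding each quota independently could over- or under-spend the budget. The largest-remainder step guarantees that the per-op counts always add up to `total`.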
Notes:
- All CK Tile GEMM and reduce kernels are architecturally generic (no compile-time GPU guards). The gfx filtering in the tile engine is a validation/testing scope decision, not a code limitation.
- [1] gemm_universal: CMake defaults to `fp8;fp16`. Building bf16/bf8 requires `-DGEMM_UNIVERSAL_DATATYPE="fp16;fp8;bf16;bf8"`.
- [2] gemm_universal: CK Tile supports int8 GEMM (with int32 output) but the tile engine has no int8 configuration.
- [3] gemm_multi_d: CK Tile kernel is type-generic but example and tile engine are fp16-only. Adding other types requires new tile engine configurations.
- [4] gemm_preshuffle: Only supports rcr layout (A=row, B=column, C=row) due to the pre-shuffle data format requirement.
- [5] streamk_gemm: CK Tile supports bf16 and bf8 for streamk, but the tile engine has no default tile configs for them.
- [6] streamk_gemm: The builder and default configs support all 4 layouts, but CMake defaults to `rcr` only. Building the others requires `-DGEMM_STREAMK_LAYOUT="rcr;rrr;ccr;crr"`.
- [7] streamk_gemm: CK Tile kernels have no arch-specific guards; tile engine filtering is pending validation for gfx950/gfx1201.
- [8] multi_reduce2d: CK Tile's reduce example supports bf16 input but the tile engine only configures fp16. The reduce kernel adapts to wave32/wave64 at runtime via `is_wave32()`.
- [9] block_scale_gemm/gemm_rowcolquant: Supports row-column quantized GEMM with fp8/bf8 inputs. Only the rcr layout is supported. Not validated on gfx90a.
- [10] grouped_gemm: Tile engine filters to gfx942, gfx950, and gfx12-generic (gfx1201) targets only. Supports fp16 and fp8 datatypes with all 4 layouts.
- [11] batched_gemm: Tile engine supports fp16 with rcr layout only. The engine filters to gfx90a, gfx942, gfx950, and gfx1201 targets.
- [12] mx_gemm: Microscaling GEMM supporting fp4 and fp8 MX datatypes with rcr layout. Validated on gfx942 and gfx950 only.
- Reduce operations do not use matrix layouts.
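Notes [9] and [12] refer to two different scaling schemes. The minimal reference sketch below rests on three assumptions:

- Row/column quantization carries one scale per row of A (length M) and one scale per column of B (length N).
- Tensor quantization carries a single scale per tensor.
- MX scales are E8M0 codes (value `2**(code - 127)`, with `0xFF` meaning NaN), each shared by a 32-element block, as in the OCP Microscaling (MX) v1.0 specification.

Decoding fp4/fp8 elements is out of scope, so `dequantize_mx` takes already-decoded values along a single 1-D block axis. All function names are hypothetical, and the tile engine's actual scale granularity may differ.

```python
import math

MX_BLOCK_SIZE = 32  # block size of the concrete formats in the OCP MX v1.0 spec


def matmul(a, b):
    """Plain reference GEMM: a is M x K, b is K x N, both as row-major nested lists."""
    return [[sum(a[m][k] * b[k][n] for k in range(len(b))) for n in range(len(b[0]))]
            for m in range(len(a))]


def rowcol_dequant(acc, a_row_scale, b_col_scale):
    """Row/column epilogue: one scale per row of A (M) and one per column of B (N)."""
    return [[acc[m][n] * a_row_scale[m] * b_col_scale[n] for n in range(len(acc[0]))]
            for m in range(len(acc))]


def tensor_dequant(acc, a_scale, b_scale):
    """Tensor epilogue: a single scale for all of A and one for all of B."""
    return [[v * a_scale * b_scale for v in row] for row in acc]


def decode_e8m0(code):
    """E8M0: 8-bit biased exponent, no sign or mantissa; 0xFF encodes NaN."""
    if not 0 <= code <= 0xFF:
        raise ValueError("e8m0 code must fit in 8 bits")
    if code == 0xFF:
        return math.nan
    return 2.0 ** (code - 127)


def dequantize_mx(elements, scale_codes, block_size=MX_BLOCK_SIZE):
    """Scale each already-decoded element by the e8m0 scale of the block it falls in."""
    if len(scale_codes) < -(-len(elements) // block_size):
        raise ValueError("not enough scales for the number of blocks")
    return [x * decode_e8m0(scale_codes[i // block_size]) for i, x in enumerate(elements)]


# Row/column scales do not depend on k, so they factor out of the K-sum and can be
# applied once to the accumulator instead of to every input element.
acc = matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(rowcol_dequant(acc, [0.5, 2.0], [4.0, 0.25]))  # [[38.0, 2.75], [344.0, 25.0]]
print(dequantize_mx([1.0, -0.5, 3.0, 1.5], [128, 126], block_size=2))  # [2.0, -1.0, 1.5, 0.75]
```

Because the row and column scales are independent of k, scaling the accumulator gives the same result as dequantizing A and B before the GEMM. This is why these scales can be applied in the epilogue.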
Layout codes: Each layout code specifies the memory layout of tensors A, B, and C as row-major (r) or column-major (c). For example, rcr means A is row-major, B is column-major, and C is row-major. For gemm_multi_d, the instance builder uses 4-character codes (e.g., rcrr) where the 4th character specifies the D tensor layout; in the table above, the 3-character A/B/C portion is shown since the D layout is always row-major (r) for all supported configurations.
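To make the layout codes concrete, the sketch below maps a code to per-tensor layouts. It then computes each tensor's leading dimension: the element stride between consecutive rows for row-major, or between consecutive columns for column-major. The helper names are hypothetical; the A/B/C/D character order follows the description above.

```python
NAMES = ("A", "B", "C", "D")


def parse_layout(code):
    """Map a 3-character (A/B/C) or 4-character (A/B/C/D) layout code to per-tensor layouts."""
    if len(code) not in (3, 4) or any(ch not in "rc" for ch in code):
        raise ValueError(f"unsupported layout code: {code!r}")
    return {name: ("row" if ch == "r" else "col") for name, ch in zip(NAMES, code)}


def leading_dim(layout, rows, cols):
    """Stride in elements between consecutive rows (row-major) or columns (column-major)."""
    return cols if layout == "row" else rows


def offset(layout, rows, cols, i, j):
    """Linear offset of element (i, j) in a dense rows x cols matrix."""
    ld = leading_dim(layout, rows, cols)
    return i * ld + j if layout == "row" else j * ld + i


# rcr GEMM with M=4, N=3, K=8: A is M x K, B is K x N, C is M x N.
M, N, K = 4, 3, 8
lay = parse_layout("rcr")
print(leading_dim(lay["A"], M, K), leading_dim(lay["B"], K, N), leading_dim(lay["C"], M, N))  # 8 8 3
```

With `rcr`, A (M x K, row-major) and B (K x N, column-major) both have a leading dimension of K. Both are therefore contiguous along the reduction axis.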
Data type mapping: The column labels (fp16, fp8, bf16, bf8, int8, fp4, fp6) refer to the input configuration label passed to the tile engine or CK Tile example. Each label determines the actual types used for the source tensors (A, B), the accumulator, and the output tensor (C). For the 16-bit and 8-bit float types, A and B use the label type and the accumulator is fp32. The output C matches the input type for fp16 and bf16, but is promoted to fp16 for fp8 and bf8 because 8-bit precision is insufficient for output storage. int8 uses int32 for both accumulation and output. fp4 is a mixed-precision weight type: B is fp4 and A uses the activation type (fp16 or bf16). fp6 is used by the microscaling (MX) flatmm pipeline, where both A and B are fp6 with fp32 accumulation and fp32 output.
Data type mapping per config label:
| Config Label | A (source) | B (source) | Acc | C (output) |
|---|---|---|---|---|
| fp16 | fp16 | fp16 | fp32 | fp16 |
| bf16 | bf16 | bf16 | fp32 | bf16 |
| int8 | int8 | int8 | int32 | int32 |
| fp8 | fp8 | fp8 | fp32 | fp16 |
| bf8 | bf8 | bf8 | fp32 | fp16 |
| fp6 | fp6 | fp6 | fp32 | fp32 |
| fp4 | fp16 or bf16 | fp4 | fp32 | fp16 or bf16 |
For gemm_multi_d, the D tensors (D0, D1) use the same type as the config label (fp16).
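The mapping table can also be written as a lookup, which makes the fp16 promotion rule for 8-bit outputs and the fp4 special case explicit. This is a hypothetical helper, not a tile engine API. It assumes that for fp4 the output C uses the same activation type as A; the table itself lists C only as "fp16 or bf16".

```python
# Hypothetical helper mirroring the "Data type mapping per config label" table.
TYPE_TABLE = {
    # label: (A, B, Acc, C)
    "fp16": ("fp16", "fp16", "fp32", "fp16"),
    "bf16": ("bf16", "bf16", "fp32", "bf16"),
    "int8": ("int8", "int8", "int32", "int32"),
    "fp8": ("fp8", "fp8", "fp32", "fp16"),  # 8-bit output promoted to fp16
    "bf8": ("bf8", "bf8", "fp32", "fp16"),  # 8-bit output promoted to fp16
    "fp6": ("fp6", "fp6", "fp32", "fp32"),  # MX flatmm pipeline
}


def resolve_types(label, activation="fp16"):
    """Return the A/B/Acc/C types for a config label; fp4 takes A and C from `activation`."""
    if label == "fp4":
        if activation not in ("fp16", "bf16"):
            raise ValueError("fp4 needs an fp16 or bf16 activation type")
        a, b, acc, c = activation, "fp4", "fp32", activation  # assumed: C matches A
    elif label in TYPE_TABLE:
        a, b, acc, c = TYPE_TABLE[label]
    else:
        raise ValueError(f"unknown config label: {label!r}")
    return {"A": a, "B": b, "Acc": acc, "C": c}


print(resolve_types("fp8"))  # {'A': 'fp8', 'B': 'fp8', 'Acc': 'fp32', 'C': 'fp16'}
```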