[GFX1250][MX GEMM] Unified FLATMM GroupedGemm Implementation
for MX Data Types (#8325)
## Motivation
Design and test a unified FLATMM GroupedGemm interface so that it
supports all MX FP8, FP6, and FP4 data types on both the gfx950 and
gfx1250 architectures and works seamlessly across these platforms.
## Technical Details
Implementation exposes Grouped Gemm interface for MX FLATMM and MX TDM
FLATMM pipelines.
## Test Plan
Add the following tests:
- ck_tile/grouped_gemm_mx/test_grouped_gemm_mx_flatmm_non_tdm.cpp
- ck_tile/grouped_gemm_mx/test_grouped_gemm_mx_flatmm_tdm.cpp
- ck_tile/flatmm/test_mx_flatmm_persistent.cpp
Verify on the gfx950 and gfx1250 architectures.
## Test Result
All tests pass. Verified on A0 hardware with rocm-7.14.0a20260517
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
[CK] Fix gfx950 AITER Sync Regressions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
## Summary
Fixes three gfx950 regressions in the AITER downstream CI that surfaced
after the internal/gfx1250 re-sync (ROCm/rocm-libraries#6978):
> **Companion aiter PR:** ROCm/aiter#3392 — host-side adaptations
(`Kernel::BlockSize()` `constexpr` drops, blockscale `KBatch=1` clamp)
plus the CK submodule bump used to validate these fixes together.
- **FlyDSL MoE AOT cache miss** — the AITER MoE tests run with
`check_aot_cache=True` and fail on any FlyDSL JIT cache miss, but the CI
never pre-compiles the FlyDSL MoE kernels, so gfx950 always misses.
Pre-compile them at the start of the AITER test stage.
- **`buffer.load.lds.v4i32` link error** — ROCm/rocm-libraries#6978
reintroduced a clang-version guard mapping
`llvm.amdgcn.raw.buffer.load.lds` to a `.v4i32`-suffixed name. That name
exists in no LLVM (the rsrc operand is a fixed, non-overloaded `<4 x
i32>`, so the intrinsic is never type-mangled), so gfx950 4-DWORD
direct-to-LDS (e.g. fp4 MoE bpreshuffle) fails to link with `lld:
undefined symbol: llvm.amdgcn.raw.buffer.load.lds.v4i32`. Use the
canonical plain name unconditionally.
- **mixed-precision flatmm warp-GEMM call** — ROCm/rocm-libraries#6978
generalized the scaled `WarpGemmImpl::operator()` from a fixed `<index_t
opselA, index_t opselB>` signature to a variadic `<typename... Params>`
one and updated the `mx_flatmm` pipeline to pass the op-selectors as
`OpSelA<>`/`OpSelB<>` types, but missed the mixed-precision flatmm
pipeline (`F8xMXF4`/`F16xMXF4`), which still passed raw integer
op-selectors. These no longer bind to `typename... Params` (`error: no
matching member function for call to 'operator()'`), breaking
compilation of the fp8/bf16 × fp4 cktile MoE gemm1 instances on gfx950
(aiter `test_moe_2stage`). Wrap the op-selectors in
`OpSelA<>`/`OpSelB<>`.
## Changes
- `Jenkinsfile`: pre-compile the FlyDSL MoE AOT cache (`python3
aiter/aot/flydsl/moe.py`) before the AITER tests.
- `include/ck/utility/amd_buffer_addressing_builtins.hpp` and
`include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp`: drop the
`__clang_major__` guard and always use
`__asm("llvm.amdgcn.raw.buffer.load.lds")`. The plain name is the
canonical one for all sizes including the gfx950 16-byte form, as the
upstream LLVM gfx950 tests confirm.
-
`include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp`:
wrap the warp-GEMM op-selectors in `OpSelA<>`/`OpSelB<>` at the five
call sites, matching the `mx_flatmm` pipeline.
## Test plan
Validated via CI.
[CK_TILE][GFX1250] Enable MX GEMM FLATMM with ASYNC
## Motivation
Enables MX GEMM FLATMM pipeline on gfx1250. The pipeline uses an async
load instruction for tensor A, which complements the existing MX GEMM
FLATMM pipeline with TDM load. At this time, only FLATMM MX pipelines
are enabled on gfx1250.
## Technical Details
The existing gfx950 implementation was extended to support gfx1250
architecture. All three MX FP data types are supported across the two
ASICs.
It should be noted that while the TDM pipeline uses an emulated
32x32x128 warp-tile instruction, the present submission relies on the
built-in 16x16x128 instruction, called 4 times per warp.
## Test Plan
Existing `test/ck_tile/flatmm` tests were extended to cover new gfx1250
functionality.
To help facilitate the testing in development,
`example/ck_tile/18_flatmm/script/smoke_test_mx.sh` script was
introduced to verify various combinations of supported data types and
pipeline versions.
## Test Result
The present submission is expected to work on both gfx950 and gfx1250
hardware for all reasonable sizes and all MX FP8/FP6/FP4 data types.
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
- [x] Relies on #6978 and should only be merged after the changes are
merged to the `develop`.
[CK] add composable kernel support on gfx1250 (#6978)
## Motivation
Add composable kernel support on gfx1250.
## Technical Details
<!-- Explain the changes along with any relevant GitHub links. -->
## Test Plan
<!-- Explain any relevant testing done to verify this PR. -->
## Test Result
<!-- Briefly summarize test outcomes. -->
## Submission Checklist
- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
---------
Co-authored-by: Qun Lin <qlin@amd.com>
Co-authored-by: jialuo12_amdeng <jia.luo@amd.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: hsivasun_amdeng <haresh.sivasuntharampillai@amd.com>
* solve compiler issue
* solve the gfx950 mfma shuffle regression
* refactor jenkinsfile to handle arch name better
* [CK TILE] set divisor to count of thread along k dimension
* fix the compiler error
* solve degradation
* Finish the multiplies fix
* fix the scales
* solve compilation error
* solve the composes
* solve the error of tile sweeper
* fix the test and example
* fix for gfx950
---------
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: Cong Ma <congma13@amd.com>
Replace `decltype(TailHandler<>(...)){}` with direct function call
to fix compilation error when return type is void.
Co-authored-by: Yi DING <yi.ding@amd.com>
* something khushbu can help with
* v1 v2 works with flatmm develop
* v0 v1 v2 numerical error gone
* Fixing numerical error, and interchange preshuffle configs to match with flatmm
* Refactor GEMM pipeline configurations and integrate preshuffle support
- Updated preshuffle pipeline definitions to include multiple versions (V1, V2, V3).
- Changed the pipeline constant from CK_TILE_PIPELINE_PRESHUFFLE to CK_TILE_PIPELINE_PRESHUFFLE_V3 in relevant configurations.
- Removed obsolete code and comments
* clang format
* fix vectorloadsize bug
* add the Preshuffle3
* update kwarp calculation in gemm utils
* update vector size A and B correctly in V2 pipeline; Added few more changes to align with dteng's branch
* fix: add CK_GFX950_SUPPORT macro for gfx950 detection
* default disable rotating buffer
* docs(CHANGELOG): update changelog for rocm 7.0
* Revert "docs(CHANGELOG): update changelog for rocm 7.0"
This reverts commit 2bc16fff84.
* Remove unused Preshuffle V3 pipeline and related code; update gemm function to use Preshuffle V2; clean up comments and formatting in various files.
* revert example/ck_tile/flatmm to its original state
* remove comment added by second author
* switch to xor ALDSDescriptor
* modify the MakeALdsDescriptor()
* temporary profiling script
* getting rid of line marker compiler error
* UniversalWeightPreshufflePipelineAgBgCrPolicy now derives from UniversalGemmBasePolicy
* add a minor fix for the config
* typo fix
* Fix formatting in lambda function for WeightPreshufflePipelineAGmemBGmemCRegV2
* revert change in include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp
* revert change in include/ck_tile/core/arch/amd_buffer_addressing.hpp
* reenable the GemmSpatiallyLocalTilePartitioner
* make GemmConfigPreshuffle_1 for v1 pipeline, GemmConfigPreshuffle_2 for v2 pipeline
* remove hardcoded true for preshuffle bool template argument
* rename script
* remove gemm_profilie.sh script
* merge conflict resolve
* clang formatted
* typo fix
* Remove duplicate include of block_gemm_areg_bsmem_creg_v2r1.hpp in gemm.hpp
* Remove commented-out code in UniversalWeightPreshufflePipelineAgBgCrPolicy
* Fix missing newline at end of file in run_gemm_example.inc
* Remove unused barrier call in BlockWeightPreshuffleASmemBSmemCRegV1
* addressing review comments
* removing debug code
* addressing review comments
* Revert "addressing review comments"
This reverts commit 29c45192ba.
* updating tile_engine code
* addressing review comments
---------
Co-authored-by: amd-khushbu <khuagarw@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
* ck_tile kernel for gemm with groupwise quantized A or B tensor.
This change introduces new pipelines with Intrawave scheduler and block gemm primitives that loads the scale tensor to registers to perform dequantization post MFMA on C tensor in registers.
Scale tensor data, AQ/BQ is spliced across threads in registers and not stored in LDS.
Current support is for the following combinations, but it should be fairly straightforward to extend support to more formats.
1. fp8, fp8 -> f32
2. bf8, bf8 -> f32
3. i4, fp8 -> f32
4. i4, bf8 -> f32
Group size can go down to as low as K length of underlying WarpGemm primitive.
For Gemm problems with quantized B tensor, this change also introduces preliminary support for flatmm pipeline which loads B tensor directly into registers.
* [Block Scale Gemm] Only run gemm quant examples on __gfx94__
- Only run gemm quant examples on __gfx94__ for usage of
`v_cvt_pk_fp8_f32`
- Format the code
* [Block Scale Gemm] Remove Bquant Gemm BlockScale
This cleanup is in preparation for future development of bquant. By
isolating Aquant-related code, we can streamline the codebase and make
it easier to add and maintain bquant functionality in subsequent
updates.
* [Block Scale Gemm] Format code with clang-format-12
The latest clang-format (v19) in ROCm 7.0 generate different result than
clang-format-12 which is used in CK CI.
Format code with clang-format-12 for consistency.
* [Block Scale Gemm] Split the k direction loop
- Split the k direction loop in block_universal_gemm_as_quant_bs_cr.hpp
to make the logic clearer.
- Disable C transposition.
* [Block Scale Gemm] Move block scale gemm example to 38_block_scale_gemm
* [Block Scale Gemm] Update copyright
* test
* Add TailHandler
* Move TileDistributionEncodingPatternAQ
* Refactor
* refactor
* fix bug
* fix bug
* help solve the PR comment
* Format the code
* [Block Scale Gemm] Add unit tests
* [Block Scale Gemm] Add support to 16x16x32 MFMA
- Add support to 16x16x32 MFMA
- Fix a bug when exchange data crossing lanes
---------
Co-authored-by: Vijay Krishnamoorthy <vjkrish@meta.com>
Co-authored-by: Cong MA <congma13@ctr2-alola-ctrl-01.amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
* Initial commit
* Adding new tile partitioner to flatmm
* intermediate changes
* debugging kernels
* Updating flatmm example to universal gemm example
* updated flatmm kernel to run via gemmKernel
* update universal gemm to incorporate flatmm
* debug
* Fix flatmm call
* Fixing other kernels and tests for API changes
* clang formatted
* fixing gemm tests
* added test for flatmm and simplify kernel arguments
* adding flatmm test
* fix test for flatmm
* simplify gemm kernel with flatmm
* remove flatmm related files
* addressing review comments and code clean up
* resolving empty file
* resolving empty file
* clang formatted
* addressing review comments
* enable persistent kernel for flatmm
* reverted the removed files for flatmm
* reverted the removed files for flatmm
* changed flatmm to weightPReshuffle; removed the _1 added in teh faltmm example
* some more renames
* clang formatted
* [CK_TILE] Refine fp8 in flatmm
1. Replace USING_MFMA_16x16x32 & USING_MFMA_16x16x32 with constexpr
2. Add an additional const check to avoid build error in HotLoopScheduler
3. Refine shuffleb to support both tile 32x32 and 16x16
4. Support command option -init
5. Move Gemm warp defintion to a separate struct
* fix clang format
* fix clang format
* keep default bhavior unchanged (warp tile = 16x16)
* fix tile engine build error
* fix a typo in codegen_utils.py
* address review comments
* address review comments
---------
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
* sync with function interface of cshuffleepiloge,fix flatmm build fail
* move code from solin/flatmm which add mfma16*16*32fp8 and optimize flatmm
---------
Co-authored-by: solin <bingzhou@amd.com>
* add ck tile examples to package
* Update jenkinsfile
* fix for jenkinsfile
* fix for building ck tile code on non gfx9
* compile ck tile examples only for gfx94
* include ck tile examples in all target
* fix for basic gemm UseStructuredSparsity
* Update CMakeLists.txt
* Update gemm_pipeline_problem.hpp
* add targets to rocm install
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>