mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-24 08:54:34 +00:00
[CKTile] Fix MX GEMM: num_loop==3 dispatch, split-K, unsupported-shape guard (#6663)

Three independent MX GEMM correctness bugs reported against
example/ck_tile/42_mx_gemm (fp8xfp8, A=Row/B=Col) on MI350X, plus one
host-side atomic-add accumulation bug in the example's repeat loop.

- Pipeline (gemm_pipeline_ag_bg_cr_comp_async.hpp): BlockHasHotloop
  required num_loop > PrefetchStages, which let num_loop == 3 enter a
  hot loop that produced 5 gemm accumulations instead of 3
  (K == 3*K_Tile, e.g. K=768, deterministically wrong). Require
  num_loop >= 4 instead: pre-pipeline + TailNumber::Three already
  totals exactly 3.

- Kernel (gemm_mx_kernel.hpp): split-K was silently broken because
  GridSize did not thread k_batch into blockIdx.z and the scale tile
  windows were anchored at K=0 for every k_id. Every k_id >= 1
  therefore read the wrong packed scales. Fix:
  * GridSize returns dim3(grid_x, 1, k_batch) (persistent and
    non-persistent).
  * MakeScaleA/BBlockWindows accept a k_elem_offset and translate it
    to a packed-scale K offset (also apply pad_tensor_view so OOB
    scale loads return zero, matching A/B padding).
  * operator() derives k_id from blockIdx.z, uses GetSplitKElemOffset
    (matches Underlying::SplitKBatchOffset's K1-aligned formula), and
    dispatches the epilogue with memory_operation_enum::atomic_add for
    k_batch > 1 and memory_operation_enum::set for k_batch == 1. Same
    fp16/bf16 even-vector-size guard as UniversalGemmKernel.
  * MakeCBlockWindows is templated on DstInMemOp and unconditionally
    applies pad_tensor_view using kPadM/kPadN so partial trailing M/N
    tiles are handled correctly.

- Compile- and runtime unsupported-shape guards (gemm_mx_kernel.hpp):
  add IsSupportedArgument and a static_assert for configurations that
  produce silent wrong results:
  * static_assert(!kPadK) -- the MX comp-async pipeline uses
    async_load_tile whose OOB check is per-vector-start, so a vector
    straddling the K pad boundary reads garbage. Until the async path
    learns per-element pad masking, reject kPadK at compile time.
  * Runtime: k_batch >= 1; M/N multiples of MPerBlock/NPerBlock when
    kPadM/kPadN are false; M >= MPerBlock and N >= NPerBlock always
    (CShuffleEpilogue cannot safely run with a single partial tile);
    K % (KPerBlock * k_batch) == 0; and for k_batch > 1, K must be a
    multiple of WarpTile_K * k_batch so every split lands on a
    packed-scale boundary.
  * All error paths log under CK_TILE_LOGGING with actionable
    messages.

- Example (example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp):
  * Call Kernel::IsSupportedArgument up front and throw a clear
    runtime_error for rejected shapes (was silently launching an
    unsupported kernel).
  * Switch to launch_kernel_time_mask with a clear_gemm_output
    preprocess that zeroes C between iterations when k_batch > 1
    (mirrors universal_gemm_invoker). Without this, the default
    -warmup=50 -repeat=100 accumulated 150 atomic_adds into C after
    the kernel-side split-K fix.

Tests (test/ck_tile/gemm_mx/):
- Add MXfp8_GemmConfig16_PadMN (kPadM = kPadN = true).
- test_mx_gemm_fp8.cpp: HotLoopTailNumLoopThree (K=768 regression),
  SplitK (k_batch=2,4 across full_k/partial_k paths),
  TestMxGemmFp8PadMN::{MNPaddingAligned, MPadding, NPadding, MNPadding}
  covering trailing partial tiles along M, N, or both.
- Run(...) now takes k_batch.
- packScalesMNxK: guard against OOB (mn, k) reads from src and
  initialise e8m0 bytes to the zero exponent (0x00) instead of the
  default-constructed NaN (0xFF), so padded lanes don't poison the
  packed int32_t shared with in-range lanes.
- test_mx_gemm_instance.hpp: call IsSupportedArgument before launch.

Verification on gfx950, ROCm 7.2.0:
- ctest -R test_ck_tile_mx_gemm -> 100% (2/2).
- Example sweep over the original bug-report shapes: all K-aligned
  shapes now validate correctly (including 4096^3 sk=2 and the K=768
  cases); all K=128 shapes are cleanly rejected with the new error
  message instead of producing silent wrong results.

Made-with: Cursor
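
The runtime shape rules listed under the unsupported-shape guards can be sketched as a standalone host-side check. This is an illustrative sketch only, not the code in gemm_mx_kernel.hpp: `TileConfig`, `GemmShape` and `CheckMxGemmShape` are hypothetical names, and the tile sizes are passed in as plain integers rather than taken from the kernel's template parameters.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical tile configuration. Field names follow the commit text;
// the real kernel derives these from its template parameters.
struct TileConfig {
    std::int64_t MPerBlock;
    std::int64_t NPerBlock;
    std::int64_t KPerBlock;
    std::int64_t WarpTileK;
    bool kPadM;
    bool kPadN;
};

// Hypothetical problem description.
struct GemmShape {
    std::int64_t M;
    std::int64_t N;
    std::int64_t K;
    std::int64_t k_batch;
};

// Returns an empty string when the shape passes every rule from the commit
// message, otherwise a short reason for the rejection.
std::string CheckMxGemmShape(const TileConfig& c, const GemmShape& s)
{
    assert(c.MPerBlock > 0 && c.NPerBlock > 0 && c.KPerBlock > 0 && c.WarpTileK > 0);

    if (s.k_batch < 1) return "k_batch must be >= 1";
    // Required even with padding enabled: a single partial tile is rejected.
    if (s.M < c.MPerBlock) return "M must be >= MPerBlock";
    if (s.N < c.NPerBlock) return "N must be >= NPerBlock";
    if (!c.kPadM && s.M % c.MPerBlock != 0)
        return "M must be a multiple of MPerBlock when kPadM is false";
    if (!c.kPadN && s.N % c.NPerBlock != 0)
        return "N must be a multiple of NPerBlock when kPadN is false";
    if (s.K % (c.KPerBlock * s.k_batch) != 0)
        return "K must be a multiple of KPerBlock * k_batch";
    // Split-K only: every split must start on a packed-scale boundary.
    if (s.k_batch > 1 && s.K % (c.WarpTileK * s.k_batch) != 0)
        return "K must be a multiple of WarpTile_K * k_batch";
    return {};
}
```

With an assumed configuration of MPerBlock = NPerBlock = 128, KPerBlock = 256 and WarpTileK = 128 (illustrative values, not taken from the commit), K = 768 with k_batch = 1 is accepted and K = 128 is rejected by the KPerBlock rule.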
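
The host-side accumulation bug in the example's repeat loop can be reproduced with a small CPU model. The sketch below is an assumption-laden stand-in: `FakeSplitKLaunch` and `RunTimed` are hypothetical helpers, not the signatures of launch_kernel_time_mask or clear_gemm_output, and the "kernel" simply adds each split's share of a known result into C to model atomic_add.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// Stand-in for one split-K launch: each k_id adds its partial result into c,
// modelling an epilogue that uses atomic_add when k_batch > 1.
void FakeSplitKLaunch(std::vector<float>& c, float full_result, int k_batch)
{
    assert(k_batch >= 1);
    for (int k_id = 0; k_id < k_batch; ++k_id)
        for (float& v : c)
            v += full_result / static_cast<float>(k_batch);
}

// Hypothetical timing loop: warmup + repeat launches, with an optional
// preprocess step that runs before every launch.
void RunTimed(int warmup,
              int repeat,
              const std::function<void()>& preprocess,
              const std::function<void()>& launch)
{
    for (int i = 0; i < warmup + repeat; ++i) {
        if (preprocess) preprocess();
        launch();
    }
}

// Runs the timing loop and returns C[0]; clear_between selects whether C is
// zeroed before each launch.
float RunExample(bool clear_between)
{
    std::vector<float> c(4, 0.0f);
    std::function<void()> clear;
    if (clear_between)
        clear = [&c] { std::fill(c.begin(), c.end(), 0.0f); };
    RunTimed(50, 100, clear, [&c] { FakeSplitKLaunch(c, 4.0f, 2); });
    return c[0];
}
```

In this model `RunExample(false)` leaves 150 launches' worth of results in C, while `RunExample(true)` leaves exactly one, which is the behaviour the clear_gemm_output preprocess restores.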