[CK_TILE] Refine fp8 support in flatmm (#2239)

* [CK_TILE] Refine fp8 in flatmm

1. Replace USING_MFMA_16x16x32 & USING_MFMA_16x16x32 with constexpr
2. Add an additional const check to avoid build error in HotLoopScheduler
3. Refine shuffleb to support both tile 32x32 and 16x16
4. Support command option -init
5. Move Gemm warp defintion to a separate struct

* fix clang format

* fix clang format

* keep default bhavior unchanged (warp tile = 16x16)

* fix tile engine build error

* fix a typo in codegen_utils.py

* address review comments

* address review comments

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
linqunAMD
2025-06-25 16:07:45 +08:00
committed by GitHub
parent 50fad03524
commit 37e1a27537
10 changed files with 313 additions and 198 deletions

View File

@@ -59,14 +59,23 @@ struct GemmHostArgs
const void* a_ptr;
const void* b_ptr;
const std::array<const void*, NumDTensor> ds_ptr;
void* e_ptr;
union
{
void* e_ptr;
void* c_ptr;
};
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
const std::array<index_t, NumDTensor> stride_Ds;
index_t stride_E;
union
{
index_t stride_E;
index_t stride_C;
};
index_t k_batch;
};