Fix flatmm kernel for larger sizes with the fp16 datatype (#2302)

This commit is contained in:
Khushbu Agarwal
2025-06-10 11:13:40 -07:00
committed by GitHub
parent aed0f5880c
commit bd270fe4bc
6 changed files with 91 additions and 98 deletions

View File

@@ -75,7 +75,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
{
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) || defined(USING_MFMA_32x32x16)
#if defined(USING_MFMA_16x16x32) || defined(USING_MFMA_32x32x16)
constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
@@ -92,7 +92,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp;
constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp;
#endif
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
#if defined(USING_MFMA_16x16x32)
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read

View File

@@ -19,7 +19,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
#if defined(USING_MFMA_16x16x32)
/* Reduced number of transform layers compared with the old CK implementation. */
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;