Fix a bug where K_Warp_Tile did not match between the two files.

This commit is contained in:
AMD-dteng
2025-06-18 19:51:57 -05:00
parent bb5a520324
commit 07b579d1dd
2 changed files with 64 additions and 52 deletions

View File

@@ -27,46 +27,18 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
#if defined(USING_MFMA_16x16x128)
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 256;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Tile = GemmConfig<BDataType>::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig<BDataType>::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig<BDataType>::K_Tile;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 128;
#endif
// This part comes from the Codegen
#if defined(USING_MFMA_16x16x32) || defined(ENABLE_FP16)
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 128;
constexpr ck_tile::index_t M_Warp = GemmConfig<BDataType>::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig<BDataType>::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig<BDataType>::K_Warp;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 64 : 16;
#elif defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8)
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 128;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 8;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 16;
#endif
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig<BDataType>::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig<BDataType>::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig<BDataType>::K_Warp_Tile;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;