[GEMM] Fix MFMA condition checks

This commit is contained in:
Clement Lin
2025-03-30 14:02:30 +08:00
parent 5dd8e4ae0c
commit de9385ba51
6 changed files with 3 additions and 21 deletions

View File

@@ -17,8 +17,7 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
#if defined(USING_MFMA_32x32x_8x2)
#pragma message ("mfma m32 n32 k16")
#if defined(NAIVE_IMPLEMENTATION)
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
@@ -31,8 +30,7 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
}
#elif defined(NAIVE_IMPLEMENTATION)
#pragma message ("mfma m32 n32 k8")
#elif defined(USING_MFMA_32x32x_8x2)
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
@@ -47,7 +45,6 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
}
#elif defined(USING_MFMA_16x16x16)
#pragma message("mfma m16 n16 k16")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
@@ -61,7 +58,6 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 4, 1);
}
#elif defined(USING_MFMA_16x16x_16x2)
#pragma message("mfma m16 n16 k32")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)

View File

@@ -115,7 +115,6 @@ struct BlockGemmPipelineAGmemBGmemCReg
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
#if defined(ENABLE_PREFETCH)
#pragma message ("prefetch")
// prefetch
// global read 0
auto a_block_tile = load_tile(a_copy_dram_window);
@@ -189,7 +188,6 @@ struct BlockGemmPipelineAGmemBGmemCReg
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
#else
#pragma message ("non-prefetch")
// non-prefetch
auto a_block_tile = load_tile(a_copy_dram_window);
auto b_block_tile = load_tile(b_copy_dram_window);

View File

@@ -24,7 +24,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
constexpr index_t kKPack = 8;
#if defined(NAIVE_IMPLEMENTATION)
#pragma message ("BANK_CONFLICT: K_FIRST")
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
@@ -39,7 +38,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(PADDING_K_FIRST)
#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
@@ -54,7 +52,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(PADDING_MN_FIRST)
#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
@@ -69,7 +66,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE)
#pragma message ("BANK_CONFLICT: XOR")
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr auto DataTypeSize = sizeof(ADataType);
@@ -123,7 +119,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
constexpr index_t kKPack = 8;
#if defined(PADDING_K_FIRST) || defined(NAIVE_IMPLEMENTATION)
#pragma message ("BANK_CONFLICT: K_FIRST")
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
@@ -138,7 +133,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(PADDING_K_FIRST)
#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
@@ -153,7 +147,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(PADDING_MN_FIRST)
#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
@@ -168,7 +161,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE)
#pragma message ("BANK_CONFLICT: XOR")
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr auto DataTypeSize = sizeof(BDataType);

View File

@@ -50,7 +50,7 @@ int main(int argc, char* argv[])
#if defined(KERNEL_A)
printf("*** KernelA test *** \n");
printf(" --> Using mfma_16x16x(8x2)\n");
printf(" --> Using mfma_32x32x(8x2)\n");
#elif defined(KERNEL_B)
printf("*** KernelB test *** \n");
printf(" --> Using mfma_16x16x16\n");
@@ -127,7 +127,6 @@ int main(int argc, char* argv[])
constexpr ck_tile::index_t kBlockSize = 256;
#ifdef ADJUST_BLOCK_TILE_SHAPE
#pragma message ("(Increase KperBlock, reduce MperBlock) -> increase Grid size")
constexpr ck_tile::index_t kGemmMPerBlock = 128;
constexpr ck_tile::index_t kGemmKPerBlock = 64;
#else

View File

@@ -87,7 +87,6 @@ struct Gemm
index_t N0)
{
#if defined(ENABLE_CACHE_AWARE_WG_SCH)
#pragma message ("Cache-aware work group sch")
return [=](index_t block_1d_id) {
constexpr index_t M01 = 4;
constexpr index_t GroupNum = 8;

View File

@@ -59,13 +59,11 @@ struct GridGemm
b_grid, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {iN, 0});
#ifndef ENABLE_INSTRUCTION_SCH
#pragma message ("disable instruction scheduling")
// Block GEMM pipeline w/o instruction scheduling
constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline<Problem>();
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()];
#else
#pragma message ("enable instruction scheduling")
// Block GEMM pipeline w/ instruction scheduling
static constexpr index_t M_Tile = 128;
static constexpr index_t N_Tile = 128;