mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
[GEMM] Fix MFMA condition checks
This commit is contained in:
@@ -17,8 +17,7 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
#if defined(USING_MFMA_32x32x_8x2)
|
||||
#pragma message ("mfma m32 n32 k16")
|
||||
#if defined(NAIVE_IMPLEMENTATION)
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
@@ -31,8 +30,7 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
#elif defined(NAIVE_IMPLEMENTATION)
|
||||
#pragma message ("mfma m32 n32 k8")
|
||||
#elif defined(USING_MFMA_32x32x_8x2)
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
@@ -47,7 +45,6 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
}
|
||||
|
||||
#elif defined(USING_MFMA_16x16x16)
|
||||
#pragma message("mfma m16 n16 k16")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
@@ -61,7 +58,6 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
#elif defined(USING_MFMA_16x16x_16x2)
|
||||
#pragma message("mfma m16 n16 k32")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
|
||||
@@ -115,7 +115,6 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
|
||||
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
#pragma message ("prefetch")
|
||||
// prefetch
|
||||
// global read 0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
@@ -189,7 +188,6 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
#else
|
||||
#pragma message ("non-prefetch")
|
||||
// non-prefetch
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
@@ -24,7 +24,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
#if defined(NAIVE_IMPLEMENTATION)
|
||||
#pragma message ("BANK_CONFLICT: K_FIRST")
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
|
||||
@@ -39,7 +38,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif defined(PADDING_K_FIRST)
|
||||
#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
@@ -54,7 +52,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif defined(PADDING_MN_FIRST)
|
||||
#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
@@ -69,7 +66,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE)
|
||||
#pragma message ("BANK_CONFLICT: XOR")
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
@@ -123,7 +119,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
#if defined(PADDING_K_FIRST) || defined(NAIVE_IMPLEMENTATION)
|
||||
#pragma message ("BANK_CONFLICT: K_FIRST")
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
|
||||
@@ -138,7 +133,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif defined(PADDING_K_FIRST)
|
||||
#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
@@ -153,7 +147,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif defined(PADDING_MN_FIRST)
|
||||
#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
@@ -168,7 +161,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE)
|
||||
#pragma message ("BANK_CONFLICT: XOR")
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
|
||||
@@ -50,7 +50,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
#if defined(KERNEL_A)
|
||||
printf("*** KernelA test *** \n");
|
||||
printf(" --> Using mfma_16x16x(8x2)\n");
|
||||
printf(" --> Using mfma_32x32x(8x2)\n");
|
||||
#elif defined(KERNEL_B)
|
||||
printf("*** KernelB test *** \n");
|
||||
printf(" --> Using mfma_16x16x16\n");
|
||||
@@ -127,7 +127,6 @@ int main(int argc, char* argv[])
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
|
||||
#ifdef ADJUST_BLOCK_TILE_SHAPE
|
||||
#pragma message ("(Increase KperBlock, reduce MperBlock) -> increase Grid size")
|
||||
constexpr ck_tile::index_t kGemmMPerBlock = 128;
|
||||
constexpr ck_tile::index_t kGemmKPerBlock = 64;
|
||||
#else
|
||||
|
||||
@@ -87,7 +87,6 @@ struct Gemm
|
||||
index_t N0)
|
||||
{
|
||||
#if defined(ENABLE_CACHE_AWARE_WG_SCH)
|
||||
#pragma message ("Cache-aware work group sch")
|
||||
return [=](index_t block_1d_id) {
|
||||
constexpr index_t M01 = 4;
|
||||
constexpr index_t GroupNum = 8;
|
||||
|
||||
@@ -59,13 +59,11 @@ struct GridGemm
|
||||
b_grid, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {iN, 0});
|
||||
|
||||
#ifndef ENABLE_INSTRUCTION_SCH
|
||||
#pragma message ("disable instruction scheduling")
|
||||
// Block GEMM pipeline w/o instruction scheduling
|
||||
constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline<Problem>();
|
||||
|
||||
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()];
|
||||
#else
|
||||
#pragma message ("enable instruction scheduling")
|
||||
// Block GEMM pipeline w/ instruction scheduling
|
||||
static constexpr index_t M_Tile = 128;
|
||||
static constexpr index_t N_Tile = 128;
|
||||
|
||||
Reference in New Issue
Block a user