mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
Addressing (Post Merge) code review comments for PR 1845 (#1883)
* Addressing code review comments. * Addressing code review comments. * Reorganized code for better readability. * add ck_tile gemms for new types in CI * fix jenkins syntax * fix script syntax * Add the test cases back * Address the review comments * Address review comments * clang format * Solve the merging issues * Addressed the comments * clang format --------- Co-authored-by: illsilin <Illia.Silin@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com> Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -167,7 +167,7 @@ struct GemmKernel
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
|
||||
{
|
||||
if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
|
||||
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
@@ -275,7 +275,7 @@ struct GemmKernel
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
|
||||
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
@@ -295,7 +295,7 @@ struct GemmKernel
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
|
||||
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
@@ -407,7 +407,7 @@ struct GemmKernel
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
@@ -671,7 +671,7 @@ struct GemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
|
||||
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
|
||||
@@ -694,7 +694,7 @@ struct GemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
|
||||
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm<memory_operation_enum::atomic_add>(
|
||||
|
||||
@@ -33,8 +33,21 @@ struct BaseGemmPipelineAgBgCrCompV3
|
||||
|
||||
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
ignore = num_loop;
|
||||
return TailNumber::Full;
|
||||
if(BlockHasHotloop(num_loop))
|
||||
{
|
||||
return TailNumber::Full;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(num_loop == 1)
|
||||
{
|
||||
return TailNumber::Odd;
|
||||
}
|
||||
else
|
||||
{
|
||||
return TailNumber::Even;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -470,6 +483,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -478,12 +492,43 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
|
||||
{
|
||||
// Leak last MFMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) /
|
||||
@@ -442,6 +442,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
Base::LocalPrefill(
|
||||
b_copy_lds_window1, b_global_load_tile, b_element_func);
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
Base::GlobalPrefetch(
|
||||
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
@@ -26,6 +26,10 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
@@ -81,11 +85,21 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
constexpr bool is_a_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user