Addressing (Post Merge) code review comments for PR 1845 (#1883)

* Addressing code review comments.

* Addressing code review comments.

* Reorganized code for better readability.

* add ck_tile gemms for new types in CI

* fix jenkins syntax

* fix script syntax

* Add the test cases back

* Address the review comments

* Address review comments

* clang format

* Solve the merging issues

* Addressed the comments

* clang format

---------

Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
kylasa
2025-03-06 11:40:30 -08:00
committed by GitHub
parent c12fb0a624
commit 66c5f5b0b6
32 changed files with 511 additions and 245 deletions

View File

@@ -33,8 +33,21 @@ struct BaseGemmPipelineAgBgCrCompV3
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
if(BlockHasHotloop(num_loop))
{
return TailNumber::Full;
}
else
{
if(num_loop == 1)
{
return TailNumber::Odd;
}
else
{
return TailNumber::Even;
}
}
}
};
@@ -470,6 +483,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
@@ -478,12 +492,43 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
{
// Leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
else
{
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
// __builtin_amdgcn_sched_barrier(0);
return c_block_tile;
}

View File

@@ -143,7 +143,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
@@ -442,6 +442,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
Base::LocalPrefill(
b_copy_lds_window1, b_global_load_tile, b_element_func);
}
block_sync_lds();
Base::GlobalPrefetch(
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);

View File

@@ -26,6 +26,10 @@ struct GemmPipelineAGmemBGmemCRegV1
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
@@ -81,11 +85,21 @@ struct GemmPipelineAGmemBGmemCRegV1
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr bool is_a_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);