mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Review changes
This commit is contained in:
@@ -73,6 +73,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
|
||||
static_assert(!std::is_same_v<ADataType, pk_int4_t>, "Not implemented");
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
@@ -103,9 +104,6 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
// TODO check KRepeat
|
||||
static constexpr index_t KRepeat = KPerBlock / GetSmemPackA();
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
@@ -190,19 +188,14 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
|
||||
constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
|
||||
constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat;
|
||||
constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat;
|
||||
|
||||
constexpr auto num_dsread_stage1_a_mfma =
|
||||
(num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
constexpr auto num_dsread_stage1_b_mfma =
|
||||
(num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
constexpr auto num_dsread_stage3_a_mfma =
|
||||
(num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
constexpr auto num_dsread_stage3_b_mfma =
|
||||
(num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
constexpr auto num_mfma_stage2 = C_MFMA_Inst_Num -
|
||||
num_ds_read_inst_a / ds_read_a_mfma_rate -
|
||||
@@ -215,7 +208,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
// stage 1
|
||||
static_for<0, num_dsread_stage1_a_mfma, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
@@ -224,14 +217,14 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate,
|
||||
num_ds_read_inst_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, num_dsread_stage1_b_mfma, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
@@ -240,7 +233,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate,
|
||||
num_ds_read_inst_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
@@ -273,7 +266,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
// stage 3
|
||||
static_for<0, num_dsread_stage3_a_mfma, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
@@ -282,14 +275,14 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate,
|
||||
num_ds_read_inst_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, num_dsread_stage3_b_mfma, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
@@ -298,7 +291,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate,
|
||||
num_ds_read_inst_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
@@ -397,16 +390,16 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
// A LDS tile for block GEMM
|
||||
acopy_dram_type& a_copy_dram_window = aWindows.at(I0);
|
||||
a_copy_lds_window_type& a_copy_lds_window = aWindows.at(I1);
|
||||
a_lds_load_tile_distr_type& a_lds_gemm_window = aWindows.at(I2);
|
||||
auto& a_copy_dram_window = aWindows.at(I0);
|
||||
auto& a_copy_lds_window = aWindows.at(I1);
|
||||
auto& a_lds_gemm_window = aWindows.at(I2);
|
||||
|
||||
// B DRAM tile window for load
|
||||
// B LDS tile window for store
|
||||
// B LDS tile for block GEMM
|
||||
bcopy_dram_type& b_copy_dram_window = bWindows.at(I0);
|
||||
b_copy_lds_window_type& b_copy_lds_window = bWindows.at(I1);
|
||||
b_lds_load_tile_distr_type& b_lds_gemm_window = bWindows.at(I2);
|
||||
auto& b_copy_dram_window = bWindows.at(I0);
|
||||
auto& b_copy_lds_window = bWindows.at(I1);
|
||||
auto& b_lds_gemm_window = bWindows.at(I2);
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
Reference in New Issue
Block a user