Review changes

This commit is contained in:
Aleksander Dudek
2025-06-23 05:52:13 -05:00
parent 654ff5a320
commit cc4122c813

View File

@@ -73,6 +73,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
static_assert(!std::is_same_v<ADataType, pk_int4_t>, "Not implemented");
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
@@ -103,9 +104,6 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
// TODO check KRepeat
static constexpr index_t KRepeat = KPerBlock / GetSmemPackA();
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
@@ -190,19 +188,14 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat;
constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat;
constexpr auto num_dsread_stage1_a_mfma =
(num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_stage1_b_mfma =
(num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
constexpr auto num_dsread_stage3_a_mfma =
(num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_stage3_b_mfma =
(num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
constexpr auto num_mfma_stage2 = C_MFMA_Inst_Num -
num_ds_read_inst_a / ds_read_a_mfma_rate -
@@ -215,7 +208,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
// stage 1
static_for<0, num_dsread_stage1_a_mfma, 1>{}([&](auto i) {
ignore = i;
if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >=
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
@@ -224,14 +217,14 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate,
num_ds_read_inst_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, num_dsread_stage1_b_mfma, 1>{}([&](auto i) {
ignore = i;
if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >=
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
@@ -240,7 +233,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate,
num_ds_read_inst_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
@@ -273,7 +266,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
// stage 3
static_for<0, num_dsread_stage3_a_mfma, 1>{}([&](auto i) {
ignore = i;
if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >=
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
@@ -282,14 +275,14 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate,
num_ds_read_inst_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, num_dsread_stage3_b_mfma, 1>{}([&](auto i) {
ignore = i;
if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >=
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
@@ -298,7 +291,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate,
num_ds_read_inst_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
@@ -397,16 +390,16 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
acopy_dram_type& a_copy_dram_window = aWindows.at(I0);
a_copy_lds_window_type& a_copy_lds_window = aWindows.at(I1);
a_lds_load_tile_distr_type& a_lds_gemm_window = aWindows.at(I2);
auto& a_copy_dram_window = aWindows.at(I0);
auto& a_copy_lds_window = aWindows.at(I1);
auto& a_lds_gemm_window = aWindows.at(I2);
// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
bcopy_dram_type& b_copy_dram_window = bWindows.at(I0);
b_copy_lds_window_type& b_copy_lds_window = bWindows.at(I1);
b_lds_load_tile_distr_type& b_lds_gemm_window = bWindows.at(I2);
auto& b_copy_dram_window = bWindows.at(I0);
auto& b_copy_lds_window = bWindows.at(I1);
auto& b_lds_gemm_window = bWindows.at(I2);
// Block GEMM
auto block_gemm = BlockGemm();