enable f8 new mfma for preshuffle gemm. found some vectorizer issue even with slp flag

This commit is contained in:
aska-0096
2025-04-23 15:53:15 +00:00
parent 25bb0d2fee
commit d6e2dd92fe
13 changed files with 212 additions and 126 deletions

View File

@@ -16,3 +16,8 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1)
endif()
endforeach()
set(GEMM_OPTIONS)
list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-16")
list(APPEND GEMM_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS})

View File

@@ -140,10 +140,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// clang-format off
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
128, 128, 128,
32, 128, 128,
16, 16,
32, 32,
2, 2,
1, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,

View File

@@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
using Base::B_K1;
using Base::I0;
using Base::I1;
using Base::KGroup;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
@@ -153,9 +154,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
@@ -280,12 +281,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
@@ -348,12 +351,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
@@ -411,12 +417,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
@@ -495,7 +503,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack>,
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,

View File

@@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
using Base::B_K1;
using Base::I0;
using Base::I1;
using Base::KGroup;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
@@ -152,9 +153,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
@@ -281,12 +282,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_bufs(I0));
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(I0));
});
});
});
@@ -320,12 +323,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_bufs(local_read_buf));
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(local_read_buf));
});
});
});
@@ -391,12 +397,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_reg),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_bufs(local_read_reg));
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(local_read_reg),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(local_read_reg));
});
});
});
@@ -445,12 +453,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_reg),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_bufs(local_read_reg));
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(local_read_reg),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(local_read_reg));
});
});
});
@@ -539,7 +549,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack>,
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,

View File

@@ -123,6 +123,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
using Base::I0;
using Base::I1;
using Base::I2;
using Base::KGroup;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
@@ -156,9 +157,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
@@ -538,12 +539,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
// Local prefetch A1
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(I0, I0, I0, k0, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(I0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(I0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(I0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
// Initialize C
@@ -618,39 +621,53 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
@@ -712,25 +729,41 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
@@ -767,14 +800,21 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
if constexpr(m0.value != (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<m0 + 1>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 1>{}, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
EpilogueScheduler_2();
@@ -816,12 +856,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
if constexpr(m0.value != (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<m0 + 1>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 1>{}, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
EpilogueScheduler_2();
@@ -841,7 +886,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack>,
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,

View File

@@ -57,6 +57,11 @@ struct BlockwiseGemmXdlops_pipeline_base
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t KPerInnerLoop = KPack;
static constexpr index_t KGroup =
((MPerXDL == 16 && MPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) ||
(MPerXDL == 32 && MPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
? 2
: 1;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

View File

@@ -205,7 +205,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -177,8 +177,8 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto mfma_cycle = NPerXDL == 16 ? 32 : 64;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -179,7 +179,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -178,7 +178,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -188,7 +188,7 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -167,9 +167,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
static constexpr index_t KLane =
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
@@ -209,7 +210,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
}
__host__ __device__ static auto CalculateBK0Shuffled(index_t K)
{
return math::integer_divide_ceil(K, KLane * KPack);
return math::integer_divide_ceil(K, KLane * KPack / KGroup);
}
__host__ __device__ static auto CalculateKPadded(index_t K)
@@ -351,7 +352,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
@@ -1228,7 +1229,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1668,7 +1669,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds

View File

@@ -90,6 +90,17 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
static constexpr index_t C_MFMA_Inst_Num =
MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
static constexpr index_t C_MFMA_Inst_Cycle = []() {
if constexpr(NPerXDL == 16)
{
return KPerXDL == 128 ? 32 : 16;
}
else if constexpr(NPerXDL == 32)
{
return KPerXDL == 64 ? 64 : 32;
}
}();
static constexpr auto Print()
{
printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
@@ -103,7 +114,7 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
KPerXDL);
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
"%d, %d\n C MFMA inst: %d\n"
"%d, %d\n C MFMA inst: %d C MFMA cycle: %d\n"
"A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
"%d/ %d\n",
A_Buffer_Load_Inst_Num,
@@ -113,6 +124,7 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
A_LDS_Read_Inst_Num,
B_LDS_Read_Inst_Num,
C_MFMA_Inst_Num,
C_MFMA_Inst_Cycle,
A_LDS_Read_Width,
B_LDS_Read_Width,
ALDSWriteWidth,