align v3 gufusion pipeline

This commit is contained in:
lalala-sh
2025-04-30 01:35:49 +00:00
17 changed files with 874 additions and 770 deletions

View File

@@ -22,3 +22,10 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1)
endif()
endforeach()
set(GEMM_OPTIONS)
list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
list(APPEND GEMM_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS})
target_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
target_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS})

View File

@@ -140,14 +140,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// clang-format off
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
128, 128, 128,
256, 256, 128,
16, 16,
32, 32,
2, 2,
16, 16,
16, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
2, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// clang-format on
int main(int argc, char* argv[])

View File

@@ -157,11 +157,11 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MXDLPerWave = 8;
static constexpr ck::index_t NXDLPerWave = 1;
static constexpr ck::index_t MPerBlock = 256;
static constexpr ck::index_t MXDLPerWave = 16;
static constexpr ck::index_t NXDLPerWave = 4;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t NPerBlock = 64;
static constexpr ck::index_t NPerBlock = 256;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t Nswizzle = false;
@@ -183,14 +183,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceM
// mn_perxdl
MNPerXDL, MNPerXDL,
// mn_xdlperwave
MXDLPerWave, NXDLPerWave,
MXDLPerWave, NXDLPerWave,
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
2, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
2, 2, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>;
// clang-format on
@@ -268,6 +268,7 @@ int main(int argc, char* argv[])
{
expert_ids.mData[i] = i / (valid_tile_num / experts);
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;

View File

@@ -123,11 +123,11 @@ using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 256;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t MXDLPerWave = 4;
static constexpr ck::index_t MXDLPerWave = 16;
static constexpr ck::index_t NXDLPerWave = 4;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t NPerBlock = 256;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
@@ -136,11 +136,12 @@ static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 2;
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
static constexpr bool MulRoutedWeight = true;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// TODO: Epilogue performance issue. AtomicAdd lose 15~20% performance compare with Set.
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
static constexpr bool MulRoutedWeight = true;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
@@ -164,7 +165,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
@@ -186,8 +187,8 @@ int main(int argc, char* argv[])
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 13;
ck::index_t sorted_tile_num = 133;
ck::index_t valid_tile_num = 128;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t tokens = 128;
@@ -247,11 +248,11 @@ int main(int argc, char* argv[])
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
}
if(tokens * topk > valid_size)
{

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -59,25 +59,25 @@ template <index_t BlockSize,
// ,bool TransposeC //disable transposec right now...
>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
@@ -123,6 +123,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
using Base::I0;
using Base::I1;
using Base::I2;
using Base::KGroup;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
@@ -157,9 +158,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
@@ -185,163 +186,145 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
template <typename Stage>
__device__ static constexpr auto HotLoopScheduler(Stage stage)
__device__ static constexpr auto HotLoopScheduler()
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
? HotLoopInstList::A_LDS_Read_Inst_Num
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2;
static_assert(num_buffer_load_inst_a == num_ds_write_inst_a);
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
constexpr auto staged_num_mfma = num_mfma / MRepeat;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * 2;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle);
if constexpr(stage.value == 0)
{
constexpr auto staged_num_buffer_load_b_per_ds_read_a =
num_buffer_load_inst_b / staged_num_ds_read_inst_a;
constexpr auto staged_num_mfma_per_buffer_load_b =
staged_num_mfma / num_buffer_load_inst_b;
// B global
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
ignore = ibuf_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
constexpr auto num_total_stages = MRepeat;
// Group num_mfma_perstage num_ds_read_a_perstage
// since we want to reuse a local register buffer
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
constexpr auto num_ds_read_a_mfma_perstage =
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
constexpr auto num_ds_read_a_prefetch_stages = 2;
constexpr auto buffer_load_perstage_more = math::integer_divide_ceil(
(num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
constexpr auto buffer_load_perstage_less = math::integer_divide_floor(
(num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
constexpr auto buffer_load_stages_more =
(num_buffer_load_inst_a + num_buffer_load_inst_b) -
math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b),
(num_total_stages - 2)) *
((num_total_stages - 2));
constexpr auto buffer_load_b_stages =
buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b
? num_buffer_load_inst_b / buffer_load_perstage_more
: (buffer_load_stages_more +
(num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) /
buffer_load_perstage_less);
constexpr auto buffer_load_a_stages =
num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages;
constexpr auto buffer_load_issue_point_b = 0;
constexpr auto buffer_load_issue_point_interval_more =
num_mfma_perstage / buffer_load_perstage_more;
constexpr auto buffer_load_issue_point_interval_less =
num_mfma_perstage / buffer_load_perstage_less;
constexpr auto ds_write_issue_point = 0;
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
// B global read
static_for<0, buffer_load_b_stages, 1>{}([&](auto i) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 1)
{
constexpr auto staged_num_mfma_per_ds_write_a =
math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
constexpr auto stage_more_mfma =
staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
// A local write
static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
if constexpr(i_inst.value < stage_more_mfma)
if constexpr(((i < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_b)) ||
((i >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_b)))
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 2)
{
constexpr auto staged_num_mfma_per_buffer_load_a =
math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a);
constexpr auto stage_more_mfma =
staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a;
// A global
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) {
if constexpr(i_inst.value < stage_more_mfma)
// A global read + A local write
static_for<0, buffer_load_a_stages, 1>{}([&](auto i) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
ds_write_issue_point)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
ds_write_issue_point)))
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
else
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_a)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_a)))
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a - 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
__builtin_amdgcn_sched_barrier(0);
}
else
{
// A local Read
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
// lds synchronization, prefetch next loop local A
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
__builtin_amdgcn_sched_barrier(0);
}
});
}
template <typename Stage>
__device__ static constexpr auto EpilogueScheduler_1(Stage stage)
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_b =
MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2;
@@ -527,12 +510,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0));
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I0));
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I0));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
@@ -571,32 +554,23 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
do
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
if constexpr(m0.value == 0)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(local_read_buf));
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
}
else if constexpr(m0.value == 1)
{
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
}
else if constexpr(m0.value == 2)
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
}
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(local_read_buf));
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
@@ -617,11 +591,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
@@ -635,7 +609,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
@@ -643,48 +617,87 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
});
});
if constexpr(m0.value == MRepeat - 1)
if constexpr(m0.value == MRepeat - 2)
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else if constexpr(m0.value == (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
HotLoopScheduler(m0);
HotLoopScheduler();
});
};
@@ -697,26 +710,19 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
// tail
if constexpr(TailNum == TailNumber::Even)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
if constexpr(m0.value == 0)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1));
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I1));
}
else if constexpr(m0.value == MRepeat - 1)
{
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
}
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1));
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I1));
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
@@ -730,7 +736,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
@@ -745,43 +751,78 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
if constexpr(m0.value == MRepeat - 1)
if constexpr(m0.value == (MRepeat - 2))
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else if constexpr(m0.value == MRepeat - 1)
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
EpilogueScheduler_1(m0);
});
HotLoopScheduler();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -810,7 +851,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
@@ -820,22 +861,28 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
if constexpr(m0.value != (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<m0 + 1>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
EpilogueScheduler_2();
}
});
HotLoopScheduler();
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
// __builtin_amdgcn_sched_barrier(0);
}
else if constexpr(TailNum == TailNumber::Odd)
{
@@ -873,18 +920,21 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
});
});
if constexpr(m0.value != (MRepeat - 1))
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<m0 + 1>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
EpilogueScheduler_2();
}
});
}
@@ -913,7 +963,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
};
} // namespace ck

View File

@@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
using Base::B_K1;
using Base::I0;
using Base::I1;
using Base::KGroup;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
@@ -153,9 +154,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
@@ -280,12 +281,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
@@ -348,12 +351,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
@@ -411,12 +417,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
@@ -495,7 +503,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack>,
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,

View File

@@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
using Base::B_K1;
using Base::I0;
using Base::I1;
using Base::KGroup;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
@@ -152,9 +153,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
@@ -281,12 +282,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_bufs(I0));
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(I0));
});
});
});
@@ -320,12 +323,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_bufs(local_read_buf));
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(local_read_buf));
});
});
});
@@ -391,12 +397,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_reg),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_bufs(local_read_reg));
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(local_read_reg),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(local_read_reg));
});
});
});
@@ -445,12 +453,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_reg),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_bufs(local_read_reg));
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(local_read_reg),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(local_read_reg));
});
});
});
@@ -539,7 +549,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack>,
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,

View File

@@ -123,6 +123,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
using Base::I0;
using Base::I1;
using Base::I2;
using Base::KGroup;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
@@ -156,9 +157,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
@@ -184,298 +185,230 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
template <typename Stage>
__device__ static constexpr auto HotLoopScheduler(Stage stage)
__device__ static constexpr auto HotLoopScheduler()
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
constexpr auto staged_num_mfma = num_mfma / MRepeat;
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
if constexpr(stage.value == 0)
{
constexpr auto staged_num_buffer_load_b_per_ds_read_a =
num_buffer_load_inst_b / staged_num_ds_read_inst_a;
constexpr auto staged_num_mfma_per_buffer_load_b =
staged_num_mfma / num_buffer_load_inst_b;
// B global
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
ignore = ibuf_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 1)
{
constexpr auto staged_num_mfma_per_ds_write_a =
math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
constexpr auto stage_more_mfma =
staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
// A local write
static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
if constexpr(i_inst.value < stage_more_mfma)
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
}
else
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
}
});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 2)
{
constexpr auto staged_num_mfma_per_buffer_load_a =
math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a);
constexpr auto stage_more_mfma =
staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a;
// A global
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) {
if constexpr(i_inst.value < stage_more_mfma)
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
}
else
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a - 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
}
});
__builtin_amdgcn_sched_barrier(0);
}
else
{
// A local Read
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_barrier(0);
}
}
template <typename Stage>
__device__ static constexpr auto EpilogueScheduler_1(Stage stage)
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
constexpr auto staged_num_mfma = num_mfma / MRepeat;
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
if constexpr(stage.value == 0)
{
constexpr auto staged_num_buffer_load_b_per_ds_read_a =
num_buffer_load_inst_b / staged_num_ds_read_inst_a;
constexpr auto staged_num_mfma_per_buffer_load_b =
staged_num_mfma / num_buffer_load_inst_b;
// B global
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) {
ignore = ibuf_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 1)
{
#if 0
constexpr auto staged_num_ds_write_a_per_ds_read_a =
num_ds_write_inst_a / staged_num_ds_read_inst_a;
constexpr auto staged_num_mfma_per_ds_write_a = staged_num_mfma / num_ds_write_inst_a;
// A local write
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
? HotLoopInstList::A_LDS_Read_Inst_Num
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) {
ignore = idswrite_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
});
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_ds_write_a_per_ds_read_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle);
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_stages = MRepeat;
// Group num_mfma_perstage num_ds_read_a_perstage
// since we want to reuse a local register buffer
constexpr auto num_mfma_perstage = num_mfma_inst / num_stages;
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_stages;
constexpr auto num_ds_read_a_mfma_perstage =
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
constexpr auto num_mfma_per_issue_more = math::integer_divide_ceil(
num_mfma_inst, num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_mfma_per_issue_less = math::integer_divide_floor(
num_mfma_inst, num_buffer_load_inst_a + num_buffer_load_inst_b);
// Insert more mfmas between bufferloads
constexpr auto num_stage1_bufferloads =
num_mfma_inst -
(num_buffer_load_inst_a + num_buffer_load_inst_b) * num_mfma_per_issue_less;
constexpr auto num_stage1_mfma = num_mfma_per_issue_more * num_stage1_bufferloads;
// Insert less mfmas between bufferloads
// constexpr auto num_stage2_mfma = num_mfma_inst - num_stage1_mfma;
constexpr auto buffer_load_issue_point = 0;
constexpr auto ds_write_issue_point_stage1 = num_mfma_per_issue_more >= 3 ? 1 : 0;
constexpr auto ds_write_issue_point_stage2 = num_mfma_per_issue_less >= 3 ? 1 : 0;
static_for<0, num_mfma_inst, 1>{}([&](auto i) {
constexpr auto current_buffer_load_issue =
i < num_stage1_mfma
? (i / num_mfma_per_issue_more)
: (num_stage1_bufferloads + (i - num_stage1_mfma) / num_mfma_per_issue_less);
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// Group num_mfma_perstage num_ds_read_a_perstage
// Hide A lds rd issue latency at begining of each stage
if constexpr((i % num_mfma_perstage) >=
(num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
// Schedule VMEM access instruction distributed evenly in the loop
// Hide B/A global rd issue latency
if constexpr(((i < num_stage1_mfma) &&
(i % num_mfma_per_issue_more == buffer_load_issue_point)) ||
((i >= num_stage1_mfma) &&
((i - num_stage1_mfma) % num_mfma_per_issue_less ==
buffer_load_issue_point)))
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
// Hide A lds wr issue latency
if constexpr((current_buffer_load_issue >= num_buffer_load_inst_b) &&
((((i < num_stage1_mfma) &&
(i % num_mfma_per_issue_more == ds_write_issue_point_stage1)) ||
((i >= num_stage1_mfma) &&
((i - num_stage1_mfma) % num_mfma_per_issue_less ==
ds_write_issue_point_stage2))) &&
(((i < num_stage1_mfma) &&
((i / num_mfma_per_issue_more - num_buffer_load_inst_b) < num_ds_write_inst_a)) ||
((i >= num_stage1_mfma) &&
((i - num_stage1_mfma) / num_mfma_per_issue_less +
num_stage1_bufferloads - num_buffer_load_inst_b) < num_ds_write_inst_a))))
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
});
#elif 1
constexpr auto staged_num_mfma_per_ds_write_a =
math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
? HotLoopInstList::A_LDS_Read_Inst_Num
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
constexpr auto stage_more_mfma =
staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
// A local write
static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
if constexpr(i_inst.value < stage_more_mfma)
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
static_assert(num_buffer_load_inst_a == num_ds_write_inst_a);
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle);
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_total_stages = MRepeat;
// Group num_mfma_perstage num_ds_read_a_perstage
// since we want to reuse a local register buffer
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
constexpr auto num_ds_read_a_mfma_perstage =
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
constexpr auto num_ds_read_a_prefetch_stages = 2;
constexpr auto buffer_load_perstage_more = math::integer_divide_ceil(
(num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
constexpr auto buffer_load_perstage_less = math::integer_divide_floor(
(num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
constexpr auto buffer_load_stages_more =
(num_buffer_load_inst_a + num_buffer_load_inst_b) -
math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b),
(num_total_stages - 2)) *
((num_total_stages - 2));
constexpr auto buffer_load_b_stages =
buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b
? num_buffer_load_inst_b / buffer_load_perstage_more
: (buffer_load_stages_more +
(num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) /
buffer_load_perstage_less);
constexpr auto buffer_load_a_stages =
num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages;
constexpr auto buffer_load_issue_point_b = 0;
constexpr auto buffer_load_issue_point_interval_more =
num_mfma_perstage / buffer_load_perstage_more;
constexpr auto buffer_load_issue_point_interval_less =
num_mfma_perstage / buffer_load_perstage_less;
constexpr auto ds_write_issue_point = 0;
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
// B global read
static_for<0, buffer_load_b_stages, 1>{}([&](auto i) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(((i < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_b)) ||
((i >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_b)))
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
#endif
__builtin_amdgcn_sched_barrier(0);
}
else
{
// A local Read
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_barrier(0);
}
}
__device__ static constexpr auto EpilogueScheduler_2()
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
constexpr auto staged_num_mfma = num_mfma / MRepeat;
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
// A local Read
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
__builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_barrier(0);
// A global read + A local write
static_for<0, buffer_load_a_stages, 1>{}([&](auto i) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
ds_write_issue_point)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
ds_write_issue_point)))
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_a)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_a)))
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// lds synchronization, prefetch next loop local A
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
#endif
}
template <bool HasMainLoop,
@@ -537,13 +470,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
// Local prefetch A1
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(I0, I0, I0, k0, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(I0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, 2, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
// Initialize C
@@ -558,26 +495,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
do
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
if constexpr(m0.value == 0)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
}
else if constexpr(m0.value == 1)
{
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
}
else if constexpr(m0.value == 2)
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
}
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
@@ -613,49 +542,88 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
});
});
if constexpr(m0.value == MRepeat - 1)
if constexpr(m0.value == (MRepeat - 2))
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else if constexpr(m0.value == (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
HotLoopScheduler(m0);
});
HotLoopScheduler();
};
LoopFunc(I0, I1);
@@ -667,20 +635,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
// tail
if constexpr(TailNum == TailNumber::Even)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
if constexpr(m0.value == 0)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1));
}
else if constexpr(m0.value == MRepeat - 1)
{
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
}
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1));
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
@@ -707,36 +669,72 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
});
});
if constexpr(m0.value == MRepeat - 1)
if constexpr(m0.value == (MRepeat - 2))
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else if constexpr(m0.value == (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
EpilogueScheduler_1(m0);
});
HotLoopScheduler();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -764,25 +762,31 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
});
});
if constexpr(m0.value != (MRepeat - 1))
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<m0 + 1>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
EpilogueScheduler_2();
}
});
HotLoopScheduler();
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
// __builtin_amdgcn_sched_barrier(0);
}
else if constexpr(TailNum == TailNumber::Odd)
{
@@ -813,18 +817,21 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
});
});
if constexpr(m0.value != (MRepeat - 1))
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<m0 + 1>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
EpilogueScheduler_2();
}
});
}
@@ -841,7 +848,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack>,
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,

View File

@@ -58,6 +58,11 @@ struct BlockwiseGemmXdlops_pipeline_base
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t KPerInnerLoop = KPack;
static constexpr index_t KGroup =
((MPerXDL == 16 && MPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) ||
(MPerXDL == 32 && MPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
? 2
: 1;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

View File

@@ -205,7 +205,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -177,8 +177,8 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto mfma_cycle = NPerXDL == 16 ? 32 : 64;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -179,7 +179,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -178,7 +178,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -188,7 +188,7 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -167,9 +167,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
static constexpr index_t KLane =
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
@@ -209,7 +210,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
}
__host__ __device__ static auto CalculateBK0Shuffled(index_t K)
{
return math::integer_divide_ceil(K, KLane * KPack);
return math::integer_divide_ceil(K, KLane * KPack / KGroup);
}
__host__ __device__ static auto CalculateKPadded(index_t K)
@@ -351,7 +352,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
@@ -1228,7 +1229,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1668,7 +1669,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds

View File

@@ -188,7 +188,9 @@ struct GridwiseMoeGemm
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KLane =
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
// static constexpr index_t NumTokens = 1;
@@ -249,7 +251,7 @@ struct GridwiseMoeGemm
}
__host__ __device__ static auto CalculateBK0Shuffled(index_t K)
{
return math::integer_divide_ceil(K, KLane * KPack);
return math::integer_divide_ceil(K, KLane * KPack / KGroup);
}
__host__ __device__ static auto CalculateKPadded(index_t K)
@@ -391,7 +393,7 @@ struct GridwiseMoeGemm
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
@@ -1301,7 +1303,7 @@ struct GridwiseMoeGemm
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -2012,7 +2014,7 @@ struct GridwiseMoeGemm
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -2081,23 +2083,24 @@ struct GridwiseMoeGemm
c_thread_buf_up,
num_k_block_main_loop);
}
else
{
else
{
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_blockwise_copy,
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
}
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_blockwise_copy,
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
}
// shuffle C and write out
{

View File

@@ -90,6 +90,17 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
static constexpr index_t C_MFMA_Inst_Num =
MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
static constexpr index_t C_MFMA_Inst_Cycle = []() {
if constexpr(NPerXDL == 16)
{
return KPerXDL == 128 ? 32 : 16;
}
else if constexpr(NPerXDL == 32)
{
return KPerXDL == 64 ? 64 : 32;
}
}();
static constexpr auto Print()
{
printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
@@ -103,7 +114,7 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
KPerXDL);
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
"%d, %d\n C MFMA inst: %d\n"
"%d, %d\n C MFMA inst: %d C MFMA cycle: %d\n"
"A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
"%d/ %d\n",
A_Buffer_Load_Inst_Num,
@@ -113,6 +124,7 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
A_LDS_Read_Inst_Num,
B_LDS_Read_Inst_Num,
C_MFMA_Inst_Num,
C_MFMA_Inst_Cycle,
A_LDS_Read_Width,
B_LDS_Read_Width,
ALDSWriteWidth,