updated the codes

This commit is contained in:
mtgu0705
2025-06-04 02:04:28 -05:00
parent 5117e99822
commit 40ed20a30d
5 changed files with 176 additions and 327 deletions

View File

@@ -143,7 +143,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t MPerBlock = 128;
static constexpr bool MulRoutedWeight = true;
// clang-format off
@@ -151,15 +151,15 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 64,
MPerBlock, 32, KPerBlock,
ScaleBlockSize, 256,
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
@@ -170,14 +170,14 @@ int main(int argc, char* argv[])
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 2;
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 2;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
@@ -418,7 +418,7 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
#if 1
#if 0
printf("a0_t_k_k:\n");
// for(int t = 0; t < tokens; ++t)
// {
@@ -671,7 +671,7 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
#if 1
#if 0
printf("e_t_n_device_result:\n");
for(int t = 0; t < tokens; ++t)
{

View File

@@ -203,9 +203,6 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
? HotLoopInstList::B_LDS_Read_Inst_Num
: HotLoopInstList::B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
@@ -243,29 +240,18 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
constexpr auto mfma_stages_more =
num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
if constexpr(i < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma < num_dswrite_per_issue_a)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma < num_dswrite_per_issue_a)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
@@ -274,23 +260,15 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma < num_dswrite_per_issue_a)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma < num_dswrite_per_issue_b)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
@@ -392,14 +370,14 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
ABlockBuffer& a_block_bufs,
const ABlockTransferStep& a_block_copy_step,
// BBlockCopy
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
BBlockBuffer& b_block_bufs,
const BBlockTransferStep& b_block_copy_step,
// CThread
CThreadBuffer& c_thread_buf,
@@ -427,8 +405,8 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0));
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
@@ -476,22 +454,11 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
b_scale_grid_desc,
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Global prefetch 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Local prefetch 1
// Local prefetch 1, sync the async load
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
@@ -503,7 +470,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_block_bufs(I0),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
@@ -525,7 +492,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_block_bufs(I0),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
@@ -537,6 +504,13 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
});
});
// Global prefetch 2
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1));
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I1));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
__builtin_amdgcn_sched_barrier(0);
@@ -549,13 +523,13 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
do
{
auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(scale_comp_buf));
// Prefetch a_scales
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
@@ -652,22 +626,20 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType,
b_scale_thread_vec_size>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(
@@ -702,10 +674,11 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
// t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc.
// t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc.
// k = 0 k = 1
block_sync_lds();
// __builtin_amdgcn_s_waitcnt(3952);
// block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0,
xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
@@ -719,7 +692,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_block_bufs(scale_mem_buf),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
@@ -743,7 +716,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_block_bufs(scale_mem_buf),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
@@ -801,10 +774,6 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
});
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
@@ -848,22 +817,20 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType,
b_scale_thread_vec_size>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
@@ -885,11 +852,12 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
});
});
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
@@ -902,7 +870,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_block_bufs(I1),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
@@ -925,7 +893,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_block_bufs(I1),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
@@ -980,22 +948,20 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType,
b_scale_thread_vec_size>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
@@ -1062,22 +1028,20 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType,
b_scale_thread_vec_size>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
@@ -1092,69 +1056,6 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
#if 0
printf(
"blkIdx: %u, blkIdy: %u, tidx: %u, imxdl: %d, inxdl: "
"%d, ikxdl: %d, a_thread_vec=<%.2f, %.2f, %.2f, %.2f>, "
"b_thread_vec=<%.2f, %.2f, %.2f, %.2f>, a_scale=%08x, "
"b_scale=%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
imxdl.value,
inxdl.value,
ikxdl.value,
type_convert<float>(
a_thread_vec
.template AsType<ComputeTypeA>()[Number<0>{}]
.unpack(Number<0>{})),
type_convert<float>(
a_thread_vec
.template AsType<ComputeTypeA>()[Number<0>{}]
.unpack(Number<1>{})),
type_convert<float>(
a_thread_vec
.template AsType<ComputeTypeA>()[Number<1>{}]
.unpack(Number<0>{})),
type_convert<float>(
a_thread_vec
.template AsType<ComputeTypeA>()[Number<1>{}]
.unpack(Number<1>{})),
type_convert<float>(
b_thread_vec
.template AsType<ComputeTypeB>()[Number<0>{}]
.unpack(Number<0>{})),
type_convert<float>(
b_thread_vec
.template AsType<ComputeTypeB>()[Number<0>{}]
.unpack(Number<1>{})),
type_convert<float>(
b_thread_vec
.template AsType<ComputeTypeB>()[Number<1>{}]
.unpack(Number<0>{})),
type_convert<float>(
b_thread_vec
.template AsType<ComputeTypeB>()[Number<1>{}]
.unpack(Number<1>{})),
*(reinterpret_cast<const uint32_t*>(&(
a_scale_thread_vec
.template AsType<AScaleDataType>()[Number<0>{}]))),
*(reinterpret_cast<const uint32_t*>(&(
b_scale_thread_vec
.template AsType<BScaleDataType>()[Number<0>{}]))),
type_convert<float>(
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<float>()[Number<0>{}]),
type_convert<float>(
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<float>()[Number<1>{}]),
type_convert<float>(
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<float>()[Number<2>{}]),
type_convert<float>(
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<float>()[Number<3>{}]));
#endif
});
});
});

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -52,8 +52,7 @@ template <typename ThreadGroup,
index_t DstVectorDim,
index_t ScalarPerVector,
typename IndexType,
index_t GatherDim = 1,
bool SrcXor = true>
index_t GatherDim = 1>
struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
@@ -67,31 +66,15 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
// static constexpr index_t AK0 = SrcDesc{}.GetLength(I0);
// static constexpr index_t M = SrcDesc{}.GetLength(I1);
// static constexpr index_t AK1 = SrcDesc{}.GetLength(I2);
static constexpr auto block_slice_lengths = BlockSliceLengths{};
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
static constexpr auto wave_thread_cluster_lengths =
Sequence<ThreadClusterLengths{}.At(I0),
ThreadClusterLengths{}.At(I1) * 64 / ThreadGroup::GetNumOfThread(),
1>{};
static constexpr auto wave_cluster_lengths =
Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{};
static constexpr auto thread_single_load_size = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
// CK_PRINT<decltype(thread_single_load_size)>();
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
// This makes the whole wavefront load contiguous memory, which is required for direct loads.
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
static constexpr auto wave_single_load_size =
wave_thread_cluster_lengths * thread_single_load_size;
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
@@ -119,8 +102,12 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
// VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the
// first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive
// elements = 64 consecutive DWORDs.
#if defined(__gfx950__)
int num_contiguous_dwords = 4;
bool is_contiguous = true;
#else
int num_contiguous_dwords = 1;
#endif
bool is_contiguous = true;
static_for<0, nDim, 1>{}([&](auto i) {
if(is_contiguous)
{
@@ -128,7 +115,6 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
}
if(thread_slice_lengths[nDim - i - 1] > 1)
{
CK_PRINT<Number<thread_slice_lengths[nDim - i - 1]>>();
is_contiguous = false;
}
});
@@ -189,6 +175,25 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
const auto thread_cluster_idx =
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
constexpr auto wave_cluster_lengths = generate_sequence_v2(
[&](auto i) {
if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3))
{
return Number<ThreadGroup::GetNumOfThread() / 64>{};
}
else
{
return I1;
}
},
Number<nDim>{});
constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths;
constexpr auto wave_single_load_size =
wave_thread_cluster_lengths * thread_single_load_size;
constexpr auto wave_cluster_desc_ =
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId() / 64));
@@ -276,52 +281,6 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
dst_buf, src_offset, dst_offset, true);
#if 0
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
printf("blkx: %u, blky: %u, tid: %u, red_id: %d src: %d (cal: %d, gather: %d), "
"dst_offset: "
"%d, a_dst_buffer=<0x%08x, 0x%08x, 0x%08x, 0x%08x>\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
static_cast<int>(ordered_dst_access_idx[Number<GatherDim>{}]),
src_offset,
src_coord_xor_.GetOffset(),
gather_offset,
dst_offset,
// *(reinterpret_cast<const uint32_t*>(&(dst_buf[dst_offset + 0].data))),
*(reinterpret_cast<const uint32_t*>(
&(dst_buf[dst_offset + 0 + 16 * threadIdx.x].data))),
*(reinterpret_cast<const uint32_t*>(
&(dst_buf[dst_offset + 4 + 16 * threadIdx.x].data))),
*(reinterpret_cast<const uint32_t*>(
&(dst_buf[dst_offset + 8 + 16 * threadIdx.x].data))),
*(reinterpret_cast<const uint32_t*>(
&(dst_buf[dst_offset + 12 + 16 * threadIdx.x].data))));
#else
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
printf("blkx: %u, blky: %u, tid: %u, thread_slice_lengths=<%d, %d, %d>, "
"src_coord_xor_=<%d, "
"%d, %d>, read_id: %d "
"src: %d (cal: %d, gather: %d)\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
thread_slice_lengths[0],
thread_slice_lengths[1],
thread_slice_lengths[2],
src_coord_xor_.GetIndex().At(I0),
src_coord_xor_.GetIndex().At(I1),
src_coord_xor_.GetIndex().At(I2),
static_cast<int>(ordered_dst_access_idx[Number<GatherDim>{}]),
src_offset,
src_coord_xor_.GetOffset(),
gather_offset);
#endif
constexpr auto move_src_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
@@ -432,8 +391,6 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
static constexpr auto wave_cluster_desc_ =
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
SrcCoord src_coord_;
SrcCoord src_coord_xor_;

View File

@@ -256,31 +256,18 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
constexpr auto MemoryDataOp =
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Full>;
RunKernel(kernel);
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
@@ -310,26 +297,15 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
}
else
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Full>;
RunKernel(kernel);
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{

View File

@@ -129,8 +129,8 @@ template <typename ALayout,
typename BElementwiseOperation,
typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t ScaleBlockSize,
index_t BlockSize,
index_t ScaleBlockSize, // Scaling block size
index_t BlockSize, // Thread block size
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
@@ -193,25 +193,33 @@ struct GridwiseMoeGemmMXBNS
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma = false;
static constexpr auto is_scale_mfma = true;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto MXdlPack = 2;
static constexpr auto NXdlPack = 2;
static constexpr auto KXdlPack = 2;
//> KPack is at least the k_per_blk of selected mfma
//
// Should be a multiple of k_per_blk.
// TODO: Move this to blockwise pipeline base
// KPack in packed data types for pk A/B
static constexpr index_t APackedSize = packed_size_v<ADataType>;
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
static constexpr bool is_single_rate_mfma = false;
static constexpr auto is_scale_mfma = true;
using mfma_selector = MfmaSelector<ComputeTypeA,
using mfma_selector = MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>;
static constexpr index_t KPack = math::max(
math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk / APackedSize);
static constexpr index_t KPack =
math::max(lcm_AK1_BK1, mfma_selector::selected_mfma.k_per_blk / APackedSize);
// static constexpr index_t NumTokens = 1;
static constexpr index_t SortedTileSize = MPerBlock;
@@ -362,12 +370,28 @@ struct GridwiseMoeGemmMXBNS
// pad M, but not K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
return a_grid_desc_ak0_m_ak1;
const auto a_grid_desc_permuted = transform_tensor_descriptor(
a_grid_desc_ak0_m_ak1,
make_tuple(make_pass_through_transform(K / KPerBlock),
make_xor_with_modulo_transform(make_tuple(MPad, AK0Number)),
make_pass_through_transform(AK1Value)),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
const auto a_grid_desc = transform_tensor_descriptor(
a_grid_desc_permuted,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)),
make_pass_through_transform(MPad),
make_pass_through_transform(AK1Value)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_grid_desc;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
@@ -439,8 +463,9 @@ struct GridwiseMoeGemmMXBNS
GemmSpec != GemmSpecialization::Default),
"pk_i4_t does not support padding");
static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
GemmSpec != GemmSpecialization::Default),
"f4x2_pk_t does not support padding");
(GemmSpec != GemmSpecialization::Default &&
GemmSpec != GemmSpecialization::MPadding)),
"f4x2_pk_t does not support K padding");
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
@@ -1368,6 +1393,10 @@ struct GridwiseMoeGemmMXBNS
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
"B scale pack data type too large!");
static_assert(is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>,
"A/B ElementwiseOperation should be PassThrough as load_to_lds is used!");
#if 0
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
@@ -2266,20 +2295,6 @@ struct GridwiseMoeGemmMXBNS
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
#if 0
printf("blkx: %u, blky: %u, tidx: %u,AMThreads: %d, token_pos: %d, gather_offsets:<%d, %d, "
"%d, %d>\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
AMThreads,
token_pos,
gather_offsets[Number<0>{}],
gather_offsets[Number<1>{}],
gather_offsets[Number<2>{}],
gather_offsets[Number<3>{}]);
#endif
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(