mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
updated the codes
This commit is contained in:
@@ -143,7 +143,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
@@ -151,15 +151,15 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 64,
|
||||
MPerBlock, 32, KPerBlock,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 128, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -170,14 +170,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 2;
|
||||
constexpr ck::index_t sorted_tile_num = 13;
|
||||
constexpr ck::index_t valid_tile_num = sorted_tile_num;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 2;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
@@ -418,7 +418,7 @@ int main(int argc, char* argv[])
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
printf("a0_t_k_k:\n");
|
||||
// for(int t = 0; t < tokens; ++t)
|
||||
// {
|
||||
@@ -671,7 +671,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
printf("e_t_n_device_result:\n");
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
|
||||
@@ -203,9 +203,6 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
? HotLoopInstList::B_LDS_Read_Inst_Num
|
||||
: HotLoopInstList::B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
|
||||
constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
|
||||
|
||||
@@ -243,29 +240,18 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
constexpr auto mfma_stages_more =
|
||||
num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
|
||||
|
||||
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
if constexpr(i < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma < num_dswrite_per_issue_a)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma < num_dswrite_per_issue_a)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
@@ -274,23 +260,15 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma < num_dswrite_per_issue_a)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma < num_dswrite_per_issue_b)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
@@ -392,14 +370,14 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
ABlockBuffer& a_block_bufs,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
// BBlockCopy
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
BBlockBuffer& b_block_bufs,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
CThreadBuffer& c_thread_buf,
|
||||
@@ -427,8 +405,8 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0));
|
||||
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
@@ -476,22 +454,11 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Local prefetch 1
|
||||
// Local prefetch 1, sync the async load
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
|
||||
[&](auto chunk) {
|
||||
@@ -503,7 +470,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
Number<m0 % MXdlPack>{},
|
||||
I0,
|
||||
Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_block_bufs(I0),
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
@@ -525,7 +492,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
Number<n0 % NXdlPack>{},
|
||||
I0,
|
||||
Number<b_k_step_chunk>{}),
|
||||
b_block_buf,
|
||||
b_block_bufs(I0),
|
||||
b_thread_desc_,
|
||||
make_tuple(Number<n0 / NXdlPack>{},
|
||||
I0,
|
||||
@@ -537,6 +504,13 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
});
|
||||
});
|
||||
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1));
|
||||
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I1));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -549,13 +523,13 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
a_blockwise_copy.Run(
|
||||
a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
|
||||
b_blockwise_copy.Run(
|
||||
b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(scale_comp_buf));
|
||||
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
@@ -652,22 +626,20 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
APackedSize>::type;
|
||||
using mfma_input_type_a = typename vector_type< //
|
||||
ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
BPackedSize>::type;
|
||||
using mfma_input_type_b = typename vector_type< //
|
||||
ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
using mfma_scale_input_type_a =
|
||||
typename vector_type<AScaleDataType,
|
||||
a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b =
|
||||
typename vector_type<BScaleDataType,
|
||||
b_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_a = typename vector_type< //
|
||||
AScaleDataType,
|
||||
a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b = typename vector_type< //
|
||||
BScaleDataType,
|
||||
b_scale_thread_vec_size>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(
|
||||
@@ -702,10 +674,11 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
// t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc.
|
||||
// t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc.
|
||||
// k = 0 k = 1
|
||||
block_sync_lds();
|
||||
// __builtin_amdgcn_s_waitcnt(3952);
|
||||
// block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0,
|
||||
xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
|
||||
@@ -719,7 +692,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
Number<m0 % MXdlPack>{},
|
||||
I0,
|
||||
Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_block_bufs(scale_mem_buf),
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
@@ -743,7 +716,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
Number<n0 % NXdlPack>{},
|
||||
I0,
|
||||
Number<b_k_step_chunk>{}),
|
||||
b_block_buf,
|
||||
b_block_bufs(scale_mem_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(Number<n0 / NXdlPack>{},
|
||||
I0,
|
||||
@@ -801,10 +774,6 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
@@ -848,22 +817,20 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
APackedSize>::type;
|
||||
using mfma_input_type_a = typename vector_type< //
|
||||
ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
BPackedSize>::type;
|
||||
using mfma_input_type_b = typename vector_type< //
|
||||
ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
using mfma_scale_input_type_a =
|
||||
typename vector_type<AScaleDataType,
|
||||
a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b =
|
||||
typename vector_type<BScaleDataType,
|
||||
b_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_a = typename vector_type< //
|
||||
AScaleDataType,
|
||||
a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b = typename vector_type< //
|
||||
BScaleDataType,
|
||||
b_scale_thread_vec_size>::type;
|
||||
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, n0, imxdl, inxdl, 0));
|
||||
@@ -885,11 +852,12 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
|
||||
[&](auto chunk) {
|
||||
@@ -902,7 +870,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
Number<m0 % MXdlPack>{},
|
||||
I0,
|
||||
Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_block_bufs(I1),
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
@@ -925,7 +893,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
Number<n0 % NXdlPack>{},
|
||||
I0,
|
||||
Number<b_k_step_chunk>{}),
|
||||
b_block_buf,
|
||||
b_block_bufs(I1),
|
||||
b_thread_desc_,
|
||||
make_tuple(Number<n0 / NXdlPack>{},
|
||||
I0,
|
||||
@@ -980,22 +948,20 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
APackedSize>::type;
|
||||
using mfma_input_type_a = typename vector_type< //
|
||||
ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
BPackedSize>::type;
|
||||
using mfma_input_type_b = typename vector_type< //
|
||||
ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
using mfma_scale_input_type_a =
|
||||
typename vector_type<AScaleDataType,
|
||||
a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b =
|
||||
typename vector_type<BScaleDataType,
|
||||
b_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_a = typename vector_type< //
|
||||
AScaleDataType,
|
||||
a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b = typename vector_type< //
|
||||
BScaleDataType,
|
||||
b_scale_thread_vec_size>::type;
|
||||
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, n0, imxdl, inxdl, 0));
|
||||
@@ -1062,22 +1028,20 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
APackedSize>::type;
|
||||
using mfma_input_type_a = typename vector_type< //
|
||||
ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
BPackedSize>::type;
|
||||
using mfma_input_type_b = typename vector_type< //
|
||||
ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
using mfma_scale_input_type_a =
|
||||
typename vector_type<AScaleDataType,
|
||||
a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b =
|
||||
typename vector_type<BScaleDataType,
|
||||
b_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_a = typename vector_type< //
|
||||
AScaleDataType,
|
||||
a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b = typename vector_type< //
|
||||
BScaleDataType,
|
||||
b_scale_thread_vec_size>::type;
|
||||
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, n0, imxdl, inxdl, 0));
|
||||
@@ -1092,69 +1056,6 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
|
||||
b_scale_thread_vec
|
||||
.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
#if 0
|
||||
printf(
|
||||
"blkIdx: %u, blkIdy: %u, tidx: %u, imxdl: %d, inxdl: "
|
||||
"%d, ikxdl: %d, a_thread_vec=<%.2f, %.2f, %.2f, %.2f>, "
|
||||
"b_thread_vec=<%.2f, %.2f, %.2f, %.2f>, a_scale=%08x, "
|
||||
"b_scale=%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
imxdl.value,
|
||||
inxdl.value,
|
||||
ikxdl.value,
|
||||
type_convert<float>(
|
||||
a_thread_vec
|
||||
.template AsType<ComputeTypeA>()[Number<0>{}]
|
||||
.unpack(Number<0>{})),
|
||||
type_convert<float>(
|
||||
a_thread_vec
|
||||
.template AsType<ComputeTypeA>()[Number<0>{}]
|
||||
.unpack(Number<1>{})),
|
||||
type_convert<float>(
|
||||
a_thread_vec
|
||||
.template AsType<ComputeTypeA>()[Number<1>{}]
|
||||
.unpack(Number<0>{})),
|
||||
type_convert<float>(
|
||||
a_thread_vec
|
||||
.template AsType<ComputeTypeA>()[Number<1>{}]
|
||||
.unpack(Number<1>{})),
|
||||
type_convert<float>(
|
||||
b_thread_vec
|
||||
.template AsType<ComputeTypeB>()[Number<0>{}]
|
||||
.unpack(Number<0>{})),
|
||||
type_convert<float>(
|
||||
b_thread_vec
|
||||
.template AsType<ComputeTypeB>()[Number<0>{}]
|
||||
.unpack(Number<1>{})),
|
||||
type_convert<float>(
|
||||
b_thread_vec
|
||||
.template AsType<ComputeTypeB>()[Number<1>{}]
|
||||
.unpack(Number<0>{})),
|
||||
type_convert<float>(
|
||||
b_thread_vec
|
||||
.template AsType<ComputeTypeB>()[Number<1>{}]
|
||||
.unpack(Number<1>{})),
|
||||
*(reinterpret_cast<const uint32_t*>(&(
|
||||
a_scale_thread_vec
|
||||
.template AsType<AScaleDataType>()[Number<0>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(
|
||||
b_scale_thread_vec
|
||||
.template AsType<BScaleDataType>()[Number<0>{}]))),
|
||||
type_convert<float>(
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<float>()[Number<0>{}]),
|
||||
type_convert<float>(
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<float>()[Number<1>{}]),
|
||||
type_convert<float>(
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<float>()[Number<2>{}]),
|
||||
type_convert<float>(
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<float>()[Number<3>{}]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -52,8 +52,7 @@ template <typename ThreadGroup,
|
||||
index_t DstVectorDim,
|
||||
index_t ScalarPerVector,
|
||||
typename IndexType,
|
||||
index_t GatherDim = 1,
|
||||
bool SrcXor = true>
|
||||
index_t GatherDim = 1>
|
||||
struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
@@ -67,31 +66,15 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
// static constexpr index_t AK0 = SrcDesc{}.GetLength(I0);
|
||||
// static constexpr index_t M = SrcDesc{}.GetLength(I1);
|
||||
// static constexpr index_t AK1 = SrcDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr auto block_slice_lengths = BlockSliceLengths{};
|
||||
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
|
||||
static constexpr auto wave_thread_cluster_lengths =
|
||||
Sequence<ThreadClusterLengths{}.At(I0),
|
||||
ThreadClusterLengths{}.At(I1) * 64 / ThreadGroup::GetNumOfThread(),
|
||||
1>{};
|
||||
static constexpr auto wave_cluster_lengths =
|
||||
Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{};
|
||||
|
||||
static constexpr auto thread_single_load_size = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
// CK_PRINT<decltype(thread_single_load_size)>();
|
||||
|
||||
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
|
||||
// It makes the whole wavefront load contiguous memory, what is required for direct loads.
|
||||
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto wave_single_load_size =
|
||||
wave_thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
|
||||
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
|
||||
|
||||
@@ -119,8 +102,12 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
// VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the
|
||||
// first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive
|
||||
// elements = 64 consecutive DWORDs.
|
||||
#if defined(__gfx950__)
|
||||
int num_contiguous_dwords = 4;
|
||||
bool is_contiguous = true;
|
||||
#else
|
||||
int num_contiguous_dwords = 1;
|
||||
#endif
|
||||
bool is_contiguous = true;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if(is_contiguous)
|
||||
{
|
||||
@@ -128,7 +115,6 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
}
|
||||
if(thread_slice_lengths[nDim - i - 1] > 1)
|
||||
{
|
||||
CK_PRINT<Number<thread_slice_lengths[nDim - i - 1]>>();
|
||||
is_contiguous = false;
|
||||
}
|
||||
});
|
||||
@@ -189,6 +175,25 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
constexpr auto wave_cluster_lengths = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3))
|
||||
{
|
||||
return Number<ThreadGroup::GetNumOfThread() / 64>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return I1;
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths;
|
||||
constexpr auto wave_single_load_size =
|
||||
wave_thread_cluster_lengths * thread_single_load_size;
|
||||
constexpr auto wave_cluster_desc_ =
|
||||
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
|
||||
|
||||
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId() / 64));
|
||||
|
||||
@@ -276,52 +281,6 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
|
||||
dst_buf, src_offset, dst_offset, true);
|
||||
|
||||
#if 0
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
printf("blkx: %u, blky: %u, tid: %u, red_id: %d src: %d (cal: %d, gather: %d), "
|
||||
"dst_offset: "
|
||||
"%d, a_dst_buffer=<0x%08x, 0x%08x, 0x%08x, 0x%08x>\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
static_cast<int>(ordered_dst_access_idx[Number<GatherDim>{}]),
|
||||
src_offset,
|
||||
src_coord_xor_.GetOffset(),
|
||||
gather_offset,
|
||||
dst_offset,
|
||||
// *(reinterpret_cast<const uint32_t*>(&(dst_buf[dst_offset + 0].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(dst_buf[dst_offset + 0 + 16 * threadIdx.x].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(dst_buf[dst_offset + 4 + 16 * threadIdx.x].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(dst_buf[dst_offset + 8 + 16 * threadIdx.x].data))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(dst_buf[dst_offset + 12 + 16 * threadIdx.x].data))));
|
||||
|
||||
#else
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
printf("blkx: %u, blky: %u, tid: %u, thread_slice_lengths=<%d, %d, %d>, "
|
||||
"src_coord_xor_=<%d, "
|
||||
"%d, %d>, read_id: %d "
|
||||
"src: %d (cal: %d, gather: %d)\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
thread_slice_lengths[0],
|
||||
thread_slice_lengths[1],
|
||||
thread_slice_lengths[2],
|
||||
src_coord_xor_.GetIndex().At(I0),
|
||||
src_coord_xor_.GetIndex().At(I1),
|
||||
src_coord_xor_.GetIndex().At(I2),
|
||||
static_cast<int>(ordered_dst_access_idx[Number<GatherDim>{}]),
|
||||
src_offset,
|
||||
src_coord_xor_.GetOffset(),
|
||||
gather_offset);
|
||||
#endif
|
||||
|
||||
constexpr auto move_src_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
@@ -432,8 +391,6 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
static constexpr auto wave_cluster_desc_ =
|
||||
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
|
||||
|
||||
SrcCoord src_coord_;
|
||||
SrcCoord src_coord_xor_;
|
||||
|
||||
@@ -256,31 +256,18 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
|
||||
constexpr auto MemoryDataOp =
|
||||
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
@@ -310,26 +297,15 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
|
||||
@@ -129,8 +129,8 @@ template <typename ALayout,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
index_t ScaleBlockSize,
|
||||
index_t BlockSize,
|
||||
index_t ScaleBlockSize, // Scaling block size
|
||||
index_t BlockSize, // Thread block size
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
@@ -193,25 +193,33 @@ struct GridwiseMoeGemmMXBNS
|
||||
static constexpr auto AK1Number = Number<AK1Value>{};
|
||||
static constexpr auto BK1Number = Number<BK1Value>{};
|
||||
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma = false;
|
||||
static constexpr auto is_scale_mfma = true;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto MXdlPack = 2;
|
||||
static constexpr auto NXdlPack = 2;
|
||||
static constexpr auto KXdlPack = 2;
|
||||
|
||||
//> KPack is at least the k_per_blk of selected mfma
|
||||
//
|
||||
// Should be a multiple of k_per_blk.
|
||||
// TODO: Move this to blockwise pipeline base
|
||||
// KPack in packed data types for pk A/B
|
||||
|
||||
static constexpr index_t APackedSize = packed_size_v<ADataType>;
|
||||
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
|
||||
|
||||
static constexpr bool is_single_rate_mfma = false;
|
||||
static constexpr auto is_scale_mfma = true;
|
||||
using mfma_selector = MfmaSelector<ComputeTypeA,
|
||||
using mfma_selector = MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeB,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>;
|
||||
static constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk / APackedSize);
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1, mfma_selector::selected_mfma.k_per_blk / APackedSize);
|
||||
|
||||
// static constexpr index_t NumTokens = 1;
|
||||
static constexpr index_t SortedTileSize = MPerBlock;
|
||||
@@ -362,12 +370,28 @@ struct GridwiseMoeGemmMXBNS
|
||||
// pad M, but not K
|
||||
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
|
||||
a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
|
||||
make_right_pad_transform(M, MPad - M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
const auto a_grid_desc_permuted = transform_tensor_descriptor(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_tuple(make_pass_through_transform(K / KPerBlock),
|
||||
make_xor_with_modulo_transform(make_tuple(MPad, AK0Number)),
|
||||
make_pass_through_transform(AK1Value)),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
|
||||
|
||||
const auto a_grid_desc = transform_tensor_descriptor(
|
||||
a_grid_desc_permuted,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)),
|
||||
make_pass_through_transform(MPad),
|
||||
make_pass_through_transform(AK1Value)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
return a_grid_desc;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
@@ -439,8 +463,9 @@ struct GridwiseMoeGemmMXBNS
|
||||
GemmSpec != GemmSpecialization::Default),
|
||||
"pk_i4_t does not support padding");
|
||||
static_assert(!(is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t> &&
|
||||
GemmSpec != GemmSpecialization::Default),
|
||||
"f4x2_pk_t does not support padding");
|
||||
(GemmSpec != GemmSpecialization::Default &&
|
||||
GemmSpec != GemmSpecialization::MPadding)),
|
||||
"f4x2_pk_t does not support K padding");
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
@@ -1368,6 +1393,10 @@ struct GridwiseMoeGemmMXBNS
|
||||
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
|
||||
"B scale pack data type too large!");
|
||||
|
||||
static_assert(is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>,
|
||||
"A/B ElementwiseOperation should be PassThrough as load_to_lds is used!");
|
||||
|
||||
#if 0
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
@@ -2266,20 +2295,6 @@ struct GridwiseMoeGemmMXBNS
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
|
||||
#if 0
|
||||
printf("blkx: %u, blky: %u, tidx: %u,AMThreads: %d, token_pos: %d, gather_offsets:<%d, %d, "
|
||||
"%d, %d>\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
AMThreads,
|
||||
token_pos,
|
||||
gather_offsets[Number<0>{}],
|
||||
gather_offsets[Number<1>{}],
|
||||
gather_offsets[Number<2>{}],
|
||||
gather_offsets[Number<3>{}]);
|
||||
#endif
|
||||
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
|
||||
Reference in New Issue
Block a user