mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
update MX moe GEMM1 hotloopscheduling
This commit is contained in:
@@ -159,7 +159,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, BlockSize,
|
||||
MPerBlock, NPerBlock, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
16, 16,
|
||||
2, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
@@ -237,6 +237,12 @@ int main(int argc, char* argv[])
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({sorted_tile_num + 1}));
|
||||
max_token_id.mData[0] = valid_size;
|
||||
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
|
||||
|
||||
@@ -201,18 +201,18 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3<BlockGemmPipelineSche
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
|
||||
? HotLoopInstList::B_LDS_Read_Inst_Num
|
||||
: HotLoopInstList::B_LDS_Read_Inst_Num / 2;
|
||||
: HotLoopInstList::B_LDS_Read_Inst_Num / 2 * 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
|
||||
constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
|
||||
constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num * 2;
|
||||
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
|
||||
|
||||
constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
|
||||
constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack;
|
||||
constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack * 2;
|
||||
|
||||
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize;
|
||||
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize * 2;
|
||||
|
||||
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
|
||||
Reference in New Issue
Block a user