Merge remote-tracking branch 'origin/moe_bs_stage1_dev' into moe_merge_v3_bs_for_aiter

This commit is contained in:
OscarXu
2025-05-19 19:34:42 +08:00
2 changed files with 7 additions and 3 deletions

View File

@@ -161,7 +161,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, int32_t, A0DataType>;
#else
static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
@@ -169,7 +169,7 @@ static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tenso
MPerBlock, 128, 128,
16, 16,
16, 16,
4, 4,
4, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,

View File

@@ -409,9 +409,13 @@ struct DeviceMoeGemmBlockScale
{
hsa_name = std::string("moe_bs_stage2_v3_128x128x128");
}
else if constexpr(MPerBlock == 64)
{
hsa_name = std::string("moe_bs_stage2_v3_64x128x128");
}
else
{
printf("Faild: v3 only support 128x128x1288.\n");
printf("Faild: v3 only support 128x128x1288 or 64x128x1288.\n");
}
}
}