mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Merge remote-tracking branch 'origin/moe_bs_stage1_dev' into moe_merge_v3_bs_for_aiter
This commit is contained in:
@@ -161,7 +161,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, int32_t, A0DataType>;
|
||||
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
@@ -169,7 +169,7 @@ static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tenso
|
||||
MPerBlock, 128, 128,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 4,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
|
||||
@@ -409,9 +409,13 @@ struct DeviceMoeGemmBlockScale
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage2_v3_128x128x128");
|
||||
}
|
||||
else if constexpr(MPerBlock == 64)
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage2_v3_64x128x128");
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Faild: v3 only support 128x128x1288.\n");
|
||||
printf("Faild: v3 only support 128x128x1288 or 64x128x1288.\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user