mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Add gemm2 v3 64x128x128
This commit is contained in:
BIN
example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v3_64x128x128.co
Executable file
BIN
example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v3_64x128x128.co
Executable file
Binary file not shown.
@@ -161,7 +161,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, int32_t, A0DataType>;
|
||||
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
@@ -169,7 +169,7 @@ static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tenso
|
||||
MPerBlock, 128, 128,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 4,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
|
||||
@@ -409,9 +409,13 @@ struct DeviceMoeGemmBlockScale
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage2_v3_128x128x128");
|
||||
}
|
||||
else if constexpr(MPerBlock == 64)
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage2_v3_64x128x128");
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Faild: v3 only support 128x128x1288.\n");
|
||||
printf("Faild: v3 only support 128x128x1288 or 64x128x1288.\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user