diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index e1a0e12f8f..e5471fc259 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -161,7 +161,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, int32_t, A0DataType>; #else -static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< +static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< Row, Col, DsLayout, ELayout, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, @@ -169,7 +169,7 @@ static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tenso MPerBlock, 128, 128, 16, 16, 16, 16, - 4, 4, + 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index 5682b38918..2f8e854514 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -409,9 +409,13 @@ struct DeviceMoeGemmBlockScale { hsa_name = std::string("moe_bs_stage2_v3_128x128x128"); } + else if constexpr(MPerBlock == 64) + { + hsa_name = std::string("moe_bs_stage2_v3_64x128x128"); + } else { - printf("Faild: v3 only support 128x128x1288.\n"); + printf("Faild: v3 only support 128x128x1288 or 64x128x1288.\n"); } } }