diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 159fb21851..f0850dd514 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -158,7 +158,7 @@ using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t NPerBlock = 256; static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1); static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index b3289dc58a..3188ba142c 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -125,9 +125,9 @@ using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 256; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t MXDLPerWave = 4; +static constexpr ck::index_t MXDLPerWave = 16; static constexpr ck::index_t NXDLPerWave = 4; -static constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t NPerBlock = 256; static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); @@ -168,7 +168,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;