diff --git a/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp b/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp index d56c24708a..a0d9c6fcdb 100644 --- a/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp +++ b/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp @@ -152,7 +152,8 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -#if 0 +#if 1 +static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t MNPerXDL = 32; static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); @@ -168,7 +169,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - 64, MPerBlock, 16, KPerBlock, + 256, MPerBlock, 128, KPerBlock, AK1, BK1, MNPerXDL, MNPerXDL, MXDLPerWave, 1, @@ -208,12 +209,12 @@ int main(int argc, char* argv[]) // GEMM shape ck::index_t N = 6144; ck::index_t K = 8192; - ck::index_t experts = 1; - ck::index_t sorted_tile_num = 1; + ck::index_t experts = 8; + ck::index_t sorted_tile_num = 8; ck::index_t sorted_tile_size = MPerBlock; ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; - // ck::index_t tokens = 128; - ck::index_t tokens = 16; + ck::index_t tokens = 128; + // ck::index_t tokens = 16; if(argc == 1) {