From 4fdde500ebb0046bcdb771b0f7eeaf835f2d2dd7 Mon Sep 17 00:00:00 2001 From: felix Date: Sun, 26 Oct 2025 08:24:24 +0000 Subject: [PATCH] change tilem --- .../moe_gemm2_xdl_mx_fp4_bns.cpp | 54 +++++++------------ 1 file changed, 19 insertions(+), 35 deletions(-) diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp index 2670468c4b..09df3381da 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -144,7 +144,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent constexpr ck::index_t ScaleBlockSize = 32; // scaling block size constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 -static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t MPerBlock = 32; static constexpr bool MulRoutedWeight = true; // clang-format off @@ -156,10 +156,10 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic MPerBlock, 128, KPerBlock, 16, 16, 16, 16, - 4, 4, + 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>, + 2, 2, S<1, 4, 1, 64>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on @@ -171,16 +171,11 @@ int main(int argc, char* argv[]) // per expert: // GEMM shape - constexpr ck::index_t sorted_tile_num = 13; - constexpr ck::index_t valid_tile_num = sorted_tile_num; - ck::index_t sorted_size = sorted_tile_num * MPerBlock; - ck::index_t valid_size = valid_tile_num * MPerBlock; - - ck::index_t N = 6144; - ck::index_t K = 4096; - ck::index_t experts = 8; - ck::index_t tokens = 832; - ck::index_t topk = 2; + ck::index_t N = 256; + ck::index_t K = 7168; + ck::index_t experts = 256; + ck::index_t tokens = 64; + ck::index_t topk = 8; if(argc == 1) { @@ -216,6 +211,10 @@ int main(int argc, char* argv[]) throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); }; + ck::index_t sorted_tile_num = experts > tokens * topk ? experts : tokens * topk; + ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t StrideA = K; ck::index_t StrideB = K; ck::index_t StrideE = N; @@ -231,34 +230,18 @@ int main(int argc, char* argv[]) Tensor max_token_id(HostTensorDescriptor({1})); max_token_id.mData[0] = valid_size; // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; - int eids[sorted_tile_num]{}; for(int i = 0; i < sorted_tile_num; i++) { - if(i < valid_tile_num) - { - eids[i] = (i * experts) / valid_tile_num; - } - else - { - eids[i] = 3; - } + expert_ids.mData[i] = i / (valid_tile_num / experts); } - for(int i = 0; i < sorted_tile_num; i++) - { - expert_ids.mData[i] = eids[i]; - } - if(tokens * topk > valid_size) - { - printf("err config, tokens * topk > valid_size\n"); - exit(-1); - } - int token_per_tile = tokens * topk / valid_tile_num; + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; + for(int i = 0; i < sorted_size; i++) { int tile_off = i % MPerBlock; - if(tile_off < token_per_tile) + if(tile_off < token_per_tile && tokenid < tokens * topk) { sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); tokenid++; @@ -457,10 +440,11 @@ int main(int argc, char* argv[]) std::size_t flop = std::size_t(2) * tokens * topk * N * K + std::size_t(2) * tokens * topk * N * K / ScaleBlockSize; + int valid_expert = tokens * topk < experts? tokens * topk : experts; std::size_t num_btype = - sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts + + sizeof(A0DataType) * tokens * K * topk/ 2 + sizeof(B0DataType) * K * N * valid_expert / 2 + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + - sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N; + sizeof(XDataType) * K / ScaleBlockSize * N * valid_expert + sizeof(EDataType) * tokens * N; float tflops = static_cast(flop) / 1.E9 / ave_time;