diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp index 1a838ced24..ed98e49bd5 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp @@ -433,11 +433,12 @@ int main(int argc, char* argv[]) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K + - std::size_t(2) * tokens * topk * N * K / ScaleBlockSize; - std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + - sizeof(B0DataType) / 2 * K * N * 2 * experts + - sizeof(EDataType) * valid_tile_num * N; + std::size_t flop = + std::size_t(2) * tokens * N * 2 * K + std::size_t(2) * tokens * N * K / ScaleBlockSize; + std::size_t num_btype = + sizeof(A0DataType) / 2 * tokens * K + sizeof(B0DataType) / 2 * K * N * 2 * experts + + sizeof(XDataType) * tokens * K + sizeof(XDataType) * K * N * 2 * experts + + sizeof(EDataType) * tokens * topk * N; float tflops = static_cast(flop) / 1.E9 / ave_time; diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp index 4e111d7d88..08addd7862 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp @@ -560,12 +560,13 @@ int main(int argc, char* argv[]) // not result correct here because output buf not setzero float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K + - std::size_t(2) * tokens * topk * N * K / ScaleBlockSize; + std::size_t flop = + std::size_t(2) * tokens * N * 2 * K + std::size_t(2) * tokens * N * K / ScaleBlockSize; - std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * K * topk + - sizeof(B0DataType) / 2 * K * N * 2 * experts + - sizeof(EDataType) * tokens * N; + std::size_t num_btype = + sizeof(A0DataType) / 2 * tokens * K + sizeof(B0DataType) / 2 * K * N * 2 * experts + + sizeof(XDataType) / 2 * tokens * K + sizeof(XDataType) / 2 * K * N * 2 * experts + + sizeof(EDataType) * tokens * N; float tflops = static_cast(flop) / 1.E9 / ave_time;