diff --git a/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp b/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp
index 5288acaef0..87aa1e06d1 100644
--- a/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp
+++ b/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp
@@ -458,7 +458,8 @@ int main(int argc, char* argv[])
 
         std::size_t flop = std::size_t(2) * tokens * topk * N * K;
         std::size_t num_btype =
-            sizeof(A0DataType) * valid_tile_num * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * valid_tile_num * N;
+            // Halve the B term last: sizeof(B0DataType) / 2 truncates to 0 for a one-byte type.
+            sizeof(A0DataType) * valid_tile_num * K + sizeof(B0DataType) * K * N * experts / 2 + sizeof(EDataType) * valid_tile_num * N;
 
         float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
 
diff --git a/example/65_gemm_multiply_multiply/moe_pk_i4_gemm2.cpp b/example/65_gemm_multiply_multiply/moe_pk_i4_gemm2.cpp
index 3102cdf51b..0a3e8dca1c 100644
--- a/example/65_gemm_multiply_multiply/moe_pk_i4_gemm2.cpp
+++ b/example/65_gemm_multiply_multiply/moe_pk_i4_gemm2.cpp
@@ -422,7 +422,8 @@ int main(int argc, char* argv[])
 
         std::size_t flop = std::size_t(2) * tokens * topk * N * K;
         std::size_t num_btype =
-            sizeof(A0DataType) * tokens * K * topk + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * tokens * N;
+            // Halve the B term last: sizeof(B0DataType) / 2 truncates to 0 for a one-byte type.
+            sizeof(A0DataType) * tokens * K * topk + sizeof(B0DataType) * K * N * experts / 2 + sizeof(EDataType) * tokens * N;
 
         float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
 