update mx moe gemm1 gemm2 TF and BW calculation

This commit is contained in:
mtgu0705
2025-05-23 05:29:39 -05:00
parent d6bfdc9d7d
commit 2216ff0521
2 changed files with 12 additions and 10 deletions

View File

@@ -433,11 +433,12 @@ int main(int argc, char* argv[])
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K +
std::size_t(2) * tokens * topk * N * K / ScaleBlockSize;
std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K +
sizeof(B0DataType) / 2 * K * N * 2 * experts +
sizeof(EDataType) * valid_tile_num * N;
std::size_t flop =
std::size_t(2) * tokens * N * 2 * K + std::size_t(2) * tokens * N * K / ScaleBlockSize;
std::size_t num_btype =
sizeof(A0DataType) / 2 * tokens * K + sizeof(B0DataType) / 2 * K * N * 2 * experts +
sizeof(XDataType) * tokens * K + sizeof(XDataType) * K * N * 2 * experts +
sizeof(EDataType) * tokens * topk * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

View File

@@ -560,12 +560,13 @@ int main(int argc, char* argv[])
// not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K +
std::size_t(2) * tokens * topk * N * K / ScaleBlockSize;
std::size_t flop =
std::size_t(2) * tokens * N * 2 * K + std::size_t(2) * tokens * N * K / ScaleBlockSize;
std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * K * topk +
sizeof(B0DataType) / 2 * K * N * 2 * experts +
sizeof(EDataType) * tokens * N;
std::size_t num_btype =
sizeof(A0DataType) / 2 * tokens * K + sizeof(B0DataType) / 2 * K * N * 2 * experts +
sizeof(XDataType) / 2 * tokens * K + sizeof(XDataType) / 2 * K * N * 2 * experts +
sizeof(EDataType) * tokens * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;