diff --git a/kt-kernel/operators/moe-tp.hpp b/kt-kernel/operators/moe-tp.hpp index 7130715..764e8c9 100644 --- a/kt-kernel/operators/moe-tp.hpp +++ b/kt-kernel/operators/moe-tp.hpp @@ -184,17 +184,25 @@ class TP_MOE_Common : public MoE_Interface { #ifdef FORWARD_TIME_REPORT auto end = std::chrono::high_resolution_clock::now(); auto forward_time = std::chrono::duration_cast(end - start).count(); - auto band_width = (1.0 * config.routed_expert_num * config.hidden_size * config.intermediate_size * 3 / 1e9) / - (1.0 * forward_time / 1e6); + int unique_experts = 0; + { + std::unordered_set expert_set; + for (int i = 0; i < qlen * config.num_experts_per_tok; i++) { + expert_set.insert(expert_ids[i]); + } + unique_experts = expert_set.size(); + } + auto band_width = + (1.0 * unique_experts * config.hidden_size * config.intermediate_size * 3 / 1e9) / (1.0 * forward_time / 1e6); auto GFLOPS = - (1.0 * config.hidden_size * config.intermediate_size * qlen * 3 * config.routed_expert_num * 2 / 1e9) / + (1.0 * config.hidden_size * config.intermediate_size * qlen * 3 * config.num_experts_per_tok * 2 / 1e9) / (1.0 * forward_time / 1e6); if (qlen <= 10) { forward_time_sum_ns += forward_time; forward_count++; } auto average_bandwidth = - (1.0 * forward_count * config.routed_expert_num * config.hidden_size * config.intermediate_size * 3 / 1e9) / + (1.0 * forward_count * unique_experts * config.hidden_size * config.intermediate_size * 3 / 1e9) / (1.0 * forward_time_sum_ns / 1e6); printf( "forward time %ld, time stamp:%ld, band width %f GElement/s, ave bandwidth %f GElement/s (only "