#!/usr/bin/env python3
"""Compare CK grouped-GEMM results with the Primus-Turbo BF16 baseline.

Extracts the CK "Perf: ... TFlops" figures from a validation log (by default
``validation_results.log``), pairs them with the 30 worst BF16 cases recorded
for Primus-Turbo in ``PT_DATA``, and prints a per-case table plus a summary.
"""

import re
import sys

# Primus-Turbo baseline, keyed by rank.
# Tuple layout: (case name, B, M, N, K, forward TFLOPS, backward TFLOPS).
PT_DATA = {
    1: ("DeepSeek-V2-Lite-Down", 2, 512, 2048, 1408, 88.31, 77.83),
    2: ("DeepSeek-V2-Lite-GateUP", 2, 512, 2816, 2048, 128.56, 111.88),
    3: ("DeepSeek-V2-Lite-Down", 4, 512, 2048, 1408, 169.94, 153.80),
    4: ("DeepSeek-V2-Lite-Down", 2, 1024, 2048, 1408, 171.98, 155.41),
    5: ("Mixtral-8x7B-Down", 1, 512, 4096, 14336, 109.59, 235.06),
    6: ("Qwen3-30B-A3B-Down", 4, 512, 2048, 2048, 180.66, 167.65),
    7: ("DeepSeek-V2-Lite-GateUP", 4, 512, 2816, 2048, 235.38, 165.31),
    8: ("Mixtral-8x22B-Down", 1, 512, 6144, 16384, 162.22, 245.13),
    9: ("DeepSeek-V2-Lite-GateUP", 2, 1024, 2816, 2048, 240.26, 185.67),
    10: ("Grok-2-Down", 1, 512, 8192, 16384, 213.48, 249.60),
    11: ("Qwen3-30B-A3B-GateUP", 4, 512, 4096, 2048, 302.33, 181.72),
    12: ("DeepSeek-V2-Lite-Down", 8, 512, 2048, 1408, 274.33, 212.22),
    13: ("Qwen3-30B-A3B-Down", 8, 512, 2048, 2048, 289.45, 219.50),
    14: ("DeepSeek-V2-Lite-Down", 4, 1024, 2048, 1408, 282.12, 232.32),
    15: ("Mixtral-8x7B-Down", 1, 1024, 4096, 14336, 212.73, 337.73),
    16: ("Qwen3-30B-A3B-Down", 4, 1024, 2048, 2048, 297.60, 253.24),
    17: ("DeepSeek-V2-Lite-Down", 2, 2048, 2048, 1408, 293.07, 262.01),
    18: ("DeepSeek-V2-Down", 5, 512, 5120, 1536, 378.06, 180.62),
    19: ("Qwen3-235B-A22B-Down", 4, 512, 4096, 4096, 330.68, 239.15),
    20: ("DeepSeek-V2-Lite-GateUP", 8, 512, 2816, 2048, 350.89, 223.60),
    21: ("DeepSeek-V2-GateUP", 5, 512, 3072, 5120, 310.56, 265.50),
    22: ("DeepSeek-V2-Down", 10, 512, 5120, 1536, 354.81, 238.12),
    23: ("Mixtral-8x7B-GateUP", 1, 512, 28672, 4096, 449.17, 144.20),
    24: ("DeepSeek-V2-Lite-GateUP", 4, 1024, 2816, 2048, 364.11, 241.49),
    25: ("Mixtral-8x22B-GateUP", 1, 512, 32768, 6144, 457.38, 179.06),
    26: ("Mixtral-8x22B-Down", 1, 1024, 6144, 16384, 292.92, 346.51),
    27: ("DeepSeek-V2-Lite-Down", 8, 1024, 2048, 1408, 395.28, 245.58),
    28: ("DeepSeek-V2-Down", 5, 1024, 5120, 1536, 367.96, 276.79),
    29: ("DeepSeek-V2-Lite-GateUP", 2, 2048, 2816, 2048, 376.12, 270.01),
    30: ("DeepSeek-V2-Lite-GateUP", 8, 1024, 2816, 2048, 337.25, 310.81),
}

# Every case is expected to contribute three consecutive "Perf:" figures:
# forward, backward grad_A, backward grad_B.
_VALUES_PER_CASE = 3

# Compiled once; captures the number that directly precedes " TFlops" on a
# "Perf:" line ('.' does not cross newlines, so matching stays per line).
_TFLOPS_RE = re.compile(r'Perf:.*?(\d+\.?\d*) TFlops')


def parse_log(log_file):
    """Extract the CK TFLOPS figures and a few counters from a validation log.

    Args:
        log_file: Path of the log to read.

    Returns:
        A 5-tuple ``(ck_data, correct_count, fail_count, tile_256, tile_128)``.
        ``ck_data`` maps rank -> ``(forward, backward_grad_A, backward_grad_B)``
        and only contains ranks for which all three values are present. The
        remaining items are plain substring counts of "correct", "fail",
        "256x256 tile" and "256x128 tile" in the log text.

    Raises:
        FileNotFoundError: If ``log_file`` does not exist.
    """
    # Explicit encoding keeps the result independent of the locale, and
    # errors="replace" stops a stray byte in the log from aborting parsing.
    with open(log_file, 'r', encoding='utf-8', errors='replace') as f:
        content = f.read()

    tflops_values = [float(x) for x in _TFLOPS_RE.findall(content)]

    # Values are assigned to ranks purely by position.
    # NOTE(review): a run that emits fewer than three "Perf:" lines shifts all
    # later ranks; the log has no per-case marker this parser could key on.
    ck_data = {}
    for rank in sorted(PT_DATA):
        start = (rank - 1) * _VALUES_PER_CASE
        triple = tflops_values[start:start + _VALUES_PER_CASE]
        if len(triple) == _VALUES_PER_CASE:
            ck_data[rank] = tuple(triple)

    correct_count = content.count("correct")
    fail_count = content.count("fail")
    tile_256 = content.count("256x256 tile")
    tile_128 = content.count("256x128 tile")

    return ck_data, correct_count, fail_count, tile_256, tile_128


def harmonic_mean(a, b):
    """Return the harmonic mean ``2 / (1/a + 1/b)`` of two TFLOPS values.

    This is the right way to merge two throughput figures when both passes
    perform the same number of FLOPs:

        combined = total_flops / total_time
                 = 2 * flops / (flops / a + flops / b)
                 = 2 / (1 / a + 1 / b)

    Returns 0.0 when either input is not positive.
    """
    if a <= 0 or b <= 0:
        return 0.0
    return 2.0 / (1.0/a + 1.0/b)


def print_comparison(ck_data):
    """Print the per-case comparison table and return the column averages.

    Args:
        ck_data: Mapping rank -> ``(forward, backward_grad_A, backward_grad_B)``
            as produced by ``parse_log``.

    Returns:
        ``(avg_ck_fwd, avg_ck_bwd, avg_pt_fwd, avg_pt_bwd)``. The averages
        cover only the ranks present in ``ck_data`` so that both sides are
        averaged over the same cases; all four are 0.0 when no rank has data.
    """
    sep = "=" * 195
    line = "-" * 195

    print(sep)
    print("BF16 最差 30 Case: Primus-Turbo vs CK Grouped GEMM 完整对比")
    print(sep)
    print(f"{'Rank':<5} {'Case':<28} {'B':<3} {'M':<5} {'N':<6} {'K':<6} | {'CK_Fwd':>8} {'CK_BwdA':>8} {'CK_BwdB':>8} {'CK_Bwd':>8} | {'PT_Fwd':>8} {'PT_Bwd':>8} | {'Δ Fwd':>8} {'Δ Bwd':>8}")
    print(line)

    total_ck_fwd = total_ck_bwd = 0.0
    total_pt_fwd = total_pt_bwd = 0.0
    measured = 0
    missing = []

    for rank in sorted(PT_DATA):
        case, B, M, N, K, pt_fwd, pt_bwd = PT_DATA[rank]

        has_data = rank in ck_data
        if has_data:
            ck_fwd, ck_bwd_a, ck_bwd_b = ck_data[rank]
        else:
            ck_fwd = ck_bwd_a = ck_bwd_b = 0.0

        # grad_A and grad_B are merged with the harmonic mean (see
        # harmonic_mean for why an arithmetic mean would be wrong).
        ck_bwd_combined = harmonic_mean(ck_bwd_a, ck_bwd_b)

        delta_fwd = ck_fwd - pt_fwd
        delta_bwd = ck_bwd_combined - pt_bwd

        if has_data:
            # Only measured cases feed the averages; counting a missing case
            # as 0 TFLOPS would deflate the CK side of the comparison.
            measured += 1
            total_ck_fwd += ck_fwd
            total_ck_bwd += ck_bwd_combined
            total_pt_fwd += pt_fwd
            total_pt_bwd += pt_bwd
        else:
            missing.append(rank)

        print(f"{rank:<5} {case:<28} {B:<3} {M:<5} {N:<6} {K:<6} | {ck_fwd:>8.2f} {ck_bwd_a:>8.2f} {ck_bwd_b:>8.2f} {ck_bwd_combined:>8.2f} | {pt_fwd:>8.2f} {pt_bwd:>8.2f} | {delta_fwd:>+8.2f} {delta_bwd:>+8.2f}")

    print(line)

    if measured:
        avg_ck_fwd = total_ck_fwd / measured
        avg_ck_bwd = total_ck_bwd / measured
        avg_pt_fwd = total_pt_fwd / measured
        avg_pt_bwd = total_pt_bwd / measured
    else:
        avg_ck_fwd = avg_ck_bwd = avg_pt_fwd = avg_pt_bwd = 0.0

    print(f"{'平均':<38} | {avg_ck_fwd:>8.2f} {avg_ck_bwd:>26.2f} | {avg_pt_fwd:>8.2f} {avg_pt_bwd:>8.2f} | {avg_ck_fwd-avg_pt_fwd:>+8.2f} {avg_ck_bwd-avg_pt_bwd:>+8.2f}")
    print(sep)

    if missing:
        ranks = ", ".join(str(r) for r in missing)
        print(f"警告: 日志中缺少 {len(missing)} 个 case 的 CK 数据 (rank: {ranks}), 平均值仅统计有数据的 {measured} 个 case")

    return avg_ck_fwd, avg_ck_bwd, avg_pt_fwd, avg_pt_bwd


def _percent_change(new, base):
    """Format the relative change of ``new`` over ``base``, e.g. ``+12.3%``."""
    if base == 0:
        # No baseline to compare against (e.g. the log held no usable data).
        return "n/a"
    return f"{(new / base - 1) * 100:>+.1f}%"


def print_summary(avg_ck_fwd, avg_ck_bwd, avg_pt_fwd, avg_pt_bwd,
                  correct_count, fail_count, tile_256, tile_128):
    """Print the averaged PT-vs-CK comparison and the validation counters.

    The relative change is shown as ``n/a`` whenever the Primus-Turbo baseline
    average is zero, instead of raising ``ZeroDivisionError``.
    """
    sep = "=" * 100

    print()
    print(sep)
    print("性能对比总结")
    print(sep)
    print(f"Forward 平均: PT = {avg_pt_fwd:.2f} TFLOPS → CK = {avg_ck_fwd:.2f} TFLOPS ({_percent_change(avg_ck_fwd, avg_pt_fwd)})")
    print(f"Backward 平均: PT = {avg_pt_bwd:.2f} TFLOPS → CK = {avg_ck_bwd:.2f} TFLOPS ({_percent_change(avg_ck_bwd, avg_pt_bwd)})")

    avg_pt = (avg_pt_fwd + avg_pt_bwd) / 2
    avg_ck = (avg_ck_fwd + avg_ck_bwd) / 2
    print(f"综合平均: PT = {avg_pt:.2f} TFLOPS → CK = {avg_ck:.2f} TFLOPS ({_percent_change(avg_ck, avg_pt)})")
    print(sep)

    print()
    print(f"精度验证: {correct_count} 通过, {fail_count} 失败")
    print(f"Tile 配置: 256x256 使用 {tile_256} 次, 256x128 使用 {tile_128} 次")


def main():
    """Parse the log named on the command line and print the full report."""
    log_file = sys.argv[1] if len(sys.argv) > 1 else "validation_results.log"

    try:
        ck_data, correct_count, fail_count, tile_256, tile_128 = parse_log(log_file)
    except FileNotFoundError:
        print(f"错误: 找不到文件 {log_file}")
        print("用法: python3 compare_results.py [validation_results.log]")
        sys.exit(1)

    avg_ck_fwd, avg_ck_bwd, avg_pt_fwd, avg_pt_bwd = print_comparison(ck_data)
    print_summary(avg_ck_fwd, avg_ck_bwd, avg_pt_fwd, avg_pt_bwd,
                  correct_count, fail_count, tile_256, tile_128)


if __name__ == "__main__":
    main()
if(!result) + { + return -1; + } + + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string data_type = arg_parser.get_str("prec"); + const int group_count = arg_parser.get_int("group_count"); + std::vector Ns = arg_parser.get_int_vec("Ns"); + + // Check N alignment for all groups + bool all_n_mod_256 = true; + bool all_n_mod_128 = true; + + if(Ns.size() == static_cast(group_count)) + { + for(const auto& n : Ns) + { + if(n % 256 != 0) + all_n_mod_256 = false; + if(n % 128 != 0) + all_n_mod_128 = false; + } + } + + if(data_type == "bf16") + { + if(all_n_mod_256) + { + std::cout << "[Config] Using 256x256 tile (N % 256 == 0)" << std::endl; + return run_gemm_example_prec_type, ck_tile::bf16_t>( + a_layout, b_layout, argc, argv); + } + else if(all_n_mod_128) + { + std::cout << "[Config] Using 256x128 tile (N % 128 == 0, N % 256 != 0)" << std::endl; + return run_gemm_example_prec_type, ck_tile::bf16_t>( + a_layout, b_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported error."); + } } else { @@ -283,11 +320,5 @@ int run_grouped_gemm_example(int argc, char* argv[]) int main(int argc, char* argv[]) { -#if CK_TILE_USE_WMMA - return !run_grouped_gemm_example(argc, argv); -#else - return !run_grouped_gemm_example(argc, argv) || - !run_grouped_gemm_example(argc, argv) || - !run_grouped_gemm_example(argc, argv); -#endif + return !run_grouped_gemm_example_with_n_check(argc, argv); } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 462f11e405..6ddb87ad0f 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -65,9 +65,72 @@ struct GemmConfigBase }; template -struct GemmConfigComputeV3_2 : public GemmConfigBase +struct GemmConfigMemoryInterwave : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr 
ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigMemoryIntrawave : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); + + static 
constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + + static constexpr int kBlockPerCu = 1; +}; + +// 256x128 tile config for N % 128 == 0 but N % 256 != 0 +template +struct GemmConfigComputeV3_256x128 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); @@ -286,7 +349,7 @@ std::pair create_args(int argc, char* argv[]) .insert("b_layout", "C", "B tensor data layout - Row by default.") .insert("c_layout", "R", "C tensor data layout - Row by default.") .insert("validate", "1", "0. No validation, 1. Validation on CPU.") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("prec", "bf16", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "10", "number of iterations before benchmark the kernel.") .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") diff --git a/run_worst_30_bf16_cases.sh b/run_worst_30_bf16_cases.sh new file mode 100755 index 0000000000..f74569a96e --- /dev/null +++ b/run_worst_30_bf16_cases.sh @@ -0,0 +1,126 @@ +#!/bin/bash + +# BF16 表现最差的前 30 个 Case 测试脚本 (Forward + Backward) +# 按平均 TFLOPS 从低到高排序 + +BINARY="./build/bin/tile_example_grouped_gemm" + +# 检查二进制文件是否存在 +if [ ! -f "$BINARY" ]; then + echo "Error: $BINARY not found!" + echo "Please build the example first with: make tile_example_grouped_gemm -j\$(nproc)" + exit 1 +fi + +echo "========================================================================================================" +echo "Running BF16 Worst 30 Cases - Grouped GEMM Benchmark (Forward + Backward)" +echo "========================================================================================================" +echo "" + +# 生成重复参数的函数 +repeat_param() { + local val=$1 + local count=$2 + local result="" + for ((i=0; i