update test config

This commit is contained in:
kyle-256
2026-01-09 08:24:15 +00:00
parent a12d808505
commit 726ddd64ad
4 changed files with 421 additions and 26 deletions

175
compare_results.py Executable file
View File

@@ -0,0 +1,175 @@
#!/usr/bin/env python3
"""
BF16 最差 30 Case: Primus-Turbo vs CK Grouped GEMM 对比脚本
从 validation_results.log 提取 CK 测试结果,与 Primus-Turbo 原始数据对比
"""
import re
import sys
# Primus-Turbo 原始基准数据
PT_DATA = {
1: ("DeepSeek-V2-Lite-Down", 2, 512, 2048, 1408, 88.31, 77.83),
2: ("DeepSeek-V2-Lite-GateUP", 2, 512, 2816, 2048, 128.56, 111.88),
3: ("DeepSeek-V2-Lite-Down", 4, 512, 2048, 1408, 169.94, 153.80),
4: ("DeepSeek-V2-Lite-Down", 2, 1024, 2048, 1408, 171.98, 155.41),
5: ("Mixtral-8x7B-Down", 1, 512, 4096, 14336, 109.59, 235.06),
6: ("Qwen3-30B-A3B-Down", 4, 512, 2048, 2048, 180.66, 167.65),
7: ("DeepSeek-V2-Lite-GateUP", 4, 512, 2816, 2048, 235.38, 165.31),
8: ("Mixtral-8x22B-Down", 1, 512, 6144, 16384, 162.22, 245.13),
9: ("DeepSeek-V2-Lite-GateUP", 2, 1024, 2816, 2048, 240.26, 185.67),
10: ("Grok-2-Down", 1, 512, 8192, 16384, 213.48, 249.60),
11: ("Qwen3-30B-A3B-GateUP", 4, 512, 4096, 2048, 302.33, 181.72),
12: ("DeepSeek-V2-Lite-Down", 8, 512, 2048, 1408, 274.33, 212.22),
13: ("Qwen3-30B-A3B-Down", 8, 512, 2048, 2048, 289.45, 219.50),
14: ("DeepSeek-V2-Lite-Down", 4, 1024, 2048, 1408, 282.12, 232.32),
15: ("Mixtral-8x7B-Down", 1, 1024, 4096, 14336, 212.73, 337.73),
16: ("Qwen3-30B-A3B-Down", 4, 1024, 2048, 2048, 297.60, 253.24),
17: ("DeepSeek-V2-Lite-Down", 2, 2048, 2048, 1408, 293.07, 262.01),
18: ("DeepSeek-V2-Down", 5, 512, 5120, 1536, 378.06, 180.62),
19: ("Qwen3-235B-A22B-Down", 4, 512, 4096, 4096, 330.68, 239.15),
20: ("DeepSeek-V2-Lite-GateUP", 8, 512, 2816, 2048, 350.89, 223.60),
21: ("DeepSeek-V2-GateUP", 5, 512, 3072, 5120, 310.56, 265.50),
22: ("DeepSeek-V2-Down", 10, 512, 5120, 1536, 354.81, 238.12),
23: ("Mixtral-8x7B-GateUP", 1, 512, 28672, 4096, 449.17, 144.20),
24: ("DeepSeek-V2-Lite-GateUP", 4, 1024, 2816, 2048, 364.11, 241.49),
25: ("Mixtral-8x22B-GateUP", 1, 512, 32768, 6144, 457.38, 179.06),
26: ("Mixtral-8x22B-Down", 1, 1024, 6144, 16384, 292.92, 346.51),
27: ("DeepSeek-V2-Lite-Down", 8, 1024, 2048, 1408, 395.28, 245.58),
28: ("DeepSeek-V2-Down", 5, 1024, 5120, 1536, 367.96, 276.79),
29: ("DeepSeek-V2-Lite-GateUP", 2, 2048, 2816, 2048, 376.12, 270.01),
30: ("DeepSeek-V2-Lite-GateUP", 8, 1024, 2816, 2048, 337.25, 310.81),
}
def parse_log(log_file):
"""解析 validation_results.log 提取 CK TFLOPS 数据"""
with open(log_file, 'r') as f:
content = f.read()
# 提取所有 TFlops 值
tflops_pattern = r'Perf:.*?(\d+\.?\d*) TFlops'
tflops_values = [float(x) for x in re.findall(tflops_pattern, content)]
# 每个 rank 有 3 个值: Forward, Backward_A, Backward_B
ck_data = {}
for rank in range(1, 31):
idx = (rank - 1) * 3
if idx + 2 < len(tflops_values):
ck_data[rank] = (
tflops_values[idx], # Forward
tflops_values[idx + 1], # Backward grad_A
tflops_values[idx + 2], # Backward grad_B
)
# 统计信息
correct_count = content.count("correct")
fail_count = content.count("fail")
tile_256 = content.count("256x256 tile")
tile_128 = content.count("256x128 tile")
return ck_data, correct_count, fail_count, tile_256, tile_128
def harmonic_mean(a, b):
"""计算调和平均: 2 / (1/a + 1/b)
这是正确的方式来合并两个 TFLOPS 值,因为:
Combined_TFLOPS = Total_FLOPs / Total_Time
= (FLOPs_A + FLOPs_B) / (Time_A + Time_B)
= 2*FLOPs / (FLOPs/TFLOPS_A + FLOPs/TFLOPS_B) (当 FLOPs_A = FLOPs_B 时)
= 2 / (1/TFLOPS_A + 1/TFLOPS_B)
"""
if a <= 0 or b <= 0:
return 0.0
return 2.0 / (1.0/a + 1.0/b)
def print_comparison(ck_data):
"""打印对比表格"""
sep = "=" * 195
line = "-" * 195
print(sep)
print("BF16 最差 30 Case: Primus-Turbo vs CK Grouped GEMM 完整对比")
print(sep)
print(f"{'Rank':<5} {'Case':<28} {'B':<3} {'M':<5} {'N':<6} {'K':<6} | {'CK_Fwd':>8} {'CK_BwdA':>8} {'CK_BwdB':>8} {'CK_Bwd':>8} | {'PT_Fwd':>8} {'PT_Bwd':>8} | {'Δ Fwd':>8} {'Δ Bwd':>8}")
print(line)
total_ck_fwd = total_ck_bwd = 0
total_pt_fwd = total_pt_bwd = 0
for rank in range(1, 31):
case, B, M, N, K, pt_fwd, pt_bwd = PT_DATA[rank]
if rank in ck_data:
ck_fwd, ck_bwd_a, ck_bwd_b = ck_data[rank]
else:
ck_fwd = ck_bwd_a = ck_bwd_b = 0.0
# 使用调和平均计算综合 backward TFLOPS (正确的合并方式)
ck_bwd_combined = harmonic_mean(ck_bwd_a, ck_bwd_b)
delta_fwd = ck_fwd - pt_fwd
delta_bwd = ck_bwd_combined - pt_bwd
total_ck_fwd += ck_fwd
total_ck_bwd += ck_bwd_combined
total_pt_fwd += pt_fwd
total_pt_bwd += pt_bwd
print(f"{rank:<5} {case:<28} {B:<3} {M:<5} {N:<6} {K:<6} | {ck_fwd:>8.2f} {ck_bwd_a:>8.2f} {ck_bwd_b:>8.2f} {ck_bwd_combined:>8.2f} | {pt_fwd:>8.2f} {pt_bwd:>8.2f} | {delta_fwd:>+8.2f} {delta_bwd:>+8.2f}")
print(line)
avg_ck_fwd = total_ck_fwd / 30
avg_ck_bwd = total_ck_bwd / 30
avg_pt_fwd = total_pt_fwd / 30
avg_pt_bwd = total_pt_bwd / 30
print(f"{'平均':<38} | {avg_ck_fwd:>8.2f} {avg_ck_bwd:>26.2f} | {avg_pt_fwd:>8.2f} {avg_pt_bwd:>8.2f} | {avg_ck_fwd-avg_pt_fwd:>+8.2f} {avg_ck_bwd-avg_pt_bwd:>+8.2f}")
print(sep)
return avg_ck_fwd, avg_ck_bwd, avg_pt_fwd, avg_pt_bwd
def print_summary(avg_ck_fwd, avg_ck_bwd, avg_pt_fwd, avg_pt_bwd,
correct_count, fail_count, tile_256, tile_128):
"""打印总结"""
sep = "=" * 100
print()
print(sep)
print("性能对比总结")
print(sep)
print(f"Forward 平均: PT = {avg_pt_fwd:.2f} TFLOPS → CK = {avg_ck_fwd:.2f} TFLOPS ({(avg_ck_fwd/avg_pt_fwd-1)*100:>+.1f}%)")
print(f"Backward 平均: PT = {avg_pt_bwd:.2f} TFLOPS → CK = {avg_ck_bwd:.2f} TFLOPS ({(avg_ck_bwd/avg_pt_bwd-1)*100:>+.1f}%)")
avg_pt = (avg_pt_fwd + avg_pt_bwd) / 2
avg_ck = (avg_ck_fwd + avg_ck_bwd) / 2
print(f"综合平均: PT = {avg_pt:.2f} TFLOPS → CK = {avg_ck:.2f} TFLOPS ({(avg_ck/avg_pt-1)*100:>+.1f}%)")
print(sep)
print()
print(f"精度验证: {correct_count} 通过, {fail_count} 失败")
print(f"Tile 配置: 256x256 使用 {tile_256} 次, 256x128 使用 {tile_128}")
def main():
log_file = sys.argv[1] if len(sys.argv) > 1 else "validation_results.log"
try:
ck_data, correct_count, fail_count, tile_256, tile_128 = parse_log(log_file)
except FileNotFoundError:
print(f"错误: 找不到文件 {log_file}")
print("用法: python3 compare_results.py [validation_results.log]")
sys.exit(1)
avg_ck_fwd, avg_ck_bwd, avg_pt_fwd, avg_pt_bwd = print_comparison(ck_data)
print_summary(avg_ck_fwd, avg_ck_bwd, avg_pt_fwd, avg_pt_bwd,
correct_count, fail_count, tile_256, tile_128)
if __name__ == "__main__":
main()

View File

@@ -233,14 +233,6 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
CDataType,
AccDataType>(argc, argv, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A and B tensors!");
@@ -260,20 +252,65 @@ int run_grouped_gemm_example(int argc, char* argv[])
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf16")
if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
a_layout, b_layout, argc, argv);
throw std::runtime_error("Unsupported data type configuration.");
}
}
// Determine appropriate tile config based on N alignment
int run_grouped_gemm_example_with_n_check(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
const int group_count = arg_parser.get_int("group_count");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
// Check N alignment for all groups
bool all_n_mod_256 = true;
bool all_n_mod_128 = true;
if(Ns.size() == static_cast<size_t>(group_count))
{
for(const auto& n : Ns)
{
if(n % 256 != 0)
all_n_mod_256 = false;
if(n % 128 != 0)
all_n_mod_128 = false;
}
}
if(data_type == "bf16")
{
if(all_n_mod_256)
{
std::cout << "[Config] Using 256x256 tile (N % 256 == 0)" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_2<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(all_n_mod_128)
{
std::cout << "[Config] Using 256x128 tile (N % 128 == 0, N % 256 != 0)" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_256x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported error.");
}
}
else
{
@@ -283,11 +320,5 @@ int run_grouped_gemm_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_gemm_example<GemmConfigComputeV4_Wmma>(argc, argv);
#else
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv) ||
!run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv) ||
!run_grouped_gemm_example<GemmConfigComputeV4_V2>(argc, argv);
#endif
return !run_grouped_gemm_example_with_n_check(argc, argv);
}

View File

@@ -65,9 +65,72 @@ struct GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
struct GemmConfigMemoryInterwave : public GemmConfigBase
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
template <typename PrecType>
struct GemmConfigMemoryIntrawave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr int kBlockPerCu = 1;
};
// 256x128 tile config for N % 128 == 0 but N % 256 != 0
template <typename PrecType>
struct GemmConfigComputeV3_256x128 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
@@ -286,7 +349,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
.insert("b_layout", "C", "B tensor data layout - Row by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("prec", "bf16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "8", "group count.")

126
run_worst_30_bf16_cases.sh Executable file
View File

@@ -0,0 +1,126 @@
#!/bin/bash
# BF16 表现最差的前 30 个 Case 测试脚本 (Forward + Backward)
# 按平均 TFLOPS 从低到高排序
BINARY="./build/bin/tile_example_grouped_gemm"
# 检查二进制文件是否存在
if [ ! -f "$BINARY" ]; then
echo "Error: $BINARY not found!"
echo "Please build the example first with: make tile_example_grouped_gemm -j\$(nproc)"
exit 1
fi
echo "========================================================================================================"
echo "Running BF16 Worst 30 Cases - Grouped GEMM Benchmark (Forward + Backward)"
echo "========================================================================================================"
echo ""
# 生成重复参数的函数
repeat_param() {
local val=$1
local count=$2
local result=""
for ((i=0; i<count; i++)); do
if [ -z "$result" ]; then
result="$val"
else
result="$result,$val"
fi
done
echo "$result"
}
# 运行单个测试 (Forward + Backward)
# 用法: run_test rank testid case B M N K
run_test() {
local rank=$1
local testid=$2
local case_name=$3
local B=$4
local M=$5
local N=$6
local K=$7
echo "========================================================================================================"
echo "Rank $rank: $case_name (TestID=$testid)"
echo " B=$B, M=$M, N=$N, K=$K"
echo "========================================================================================================"
# ==================== Forward ====================
# Forward: (M, K) @ (K, N) = (M, N)
local fwd_Ms=$(repeat_param $M $B)
local fwd_Ns=$(repeat_param $N $B)
local fwd_Ks=$(repeat_param $K $B)
local strides=$(repeat_param 0 $B)
echo ""
echo "[Forward] GEMM: M=$M, N=$N, K=$K"
echo " Command: $BINARY -Ms=$fwd_Ms -Ns=$fwd_Ns -Ks=$fwd_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1"
$BINARY -Ms=$fwd_Ms -Ns=$fwd_Ns -Ks=$fwd_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1
# ==================== Backward grad_A ====================
# grad_A = grad_Y @ W^T
# (M, N) @ (N, K) = (M, K)
# GEMM: M=M, N=K, K=N
local bwd_a_Ms=$(repeat_param $M $B)
local bwd_a_Ns=$(repeat_param $K $B)
local bwd_a_Ks=$(repeat_param $N $B)
echo ""
echo "[Backward grad_A] GEMM: M=$M, N=$K, K=$N"
echo " Command: $BINARY -Ms=$bwd_a_Ms -Ns=$bwd_a_Ns -Ks=$bwd_a_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1"
$BINARY -Ms=$bwd_a_Ms -Ns=$bwd_a_Ns -Ks=$bwd_a_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1
# ==================== Backward grad_B ====================
# grad_B = X^T @ grad_Y
# (K, M) @ (M, N) = (K, N)
# GEMM: M=K, N=N, K=M
local bwd_b_Ms=$(repeat_param $K $B)
local bwd_b_Ns=$(repeat_param $N $B)
local bwd_b_Ks=$(repeat_param $M $B)
echo ""
echo "[Backward grad_B] GEMM: M=$K, N=$N, K=$M"
echo " Command: $BINARY -Ms=$bwd_b_Ms -Ns=$bwd_b_Ns -Ks=$bwd_b_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1"
$BINARY -Ms=$bwd_b_Ms -Ns=$bwd_b_Ns -Ks=$bwd_b_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1
echo ""
}
# 运行所有30个测试用例
run_test 1 62 "DeepSeek-V2-Lite-Down" 2 512 2048 1408
run_test 2 61 "DeepSeek-V2-Lite-GateUP" 2 512 2816 2048
run_test 3 72 "DeepSeek-V2-Lite-Down" 4 512 2048 1408
run_test 4 64 "DeepSeek-V2-Lite-Down" 2 1024 2048 1408
run_test 5 162 "Mixtral-8x7B-Down" 1 512 4096 14336
run_test 6 102 "Qwen3-30B-A3B-Down" 4 512 2048 2048
run_test 7 71 "DeepSeek-V2-Lite-GateUP" 4 512 2816 2048
run_test 8 172 "Mixtral-8x22B-Down" 1 512 6144 16384
run_test 9 63 "DeepSeek-V2-Lite-GateUP" 2 1024 2816 2048
run_test 10 92 "Grok-2-Down" 1 512 8192 16384
run_test 11 101 "Qwen3-30B-A3B-GateUP" 4 512 4096 2048
run_test 12 82 "DeepSeek-V2-Lite-Down" 8 512 2048 1408
run_test 13 112 "Qwen3-30B-A3B-Down" 8 512 2048 2048
run_test 14 74 "DeepSeek-V2-Lite-Down" 4 1024 2048 1408
run_test 15 164 "Mixtral-8x7B-Down" 1 1024 4096 14336
run_test 16 104 "Qwen3-30B-A3B-Down" 4 1024 2048 2048
run_test 17 66 "DeepSeek-V2-Lite-Down" 2 2048 2048 1408
run_test 18 32 "DeepSeek-V2-Down" 5 512 5120 1536
run_test 19 132 "Qwen3-235B-A22B-Down" 4 512 4096 4096
run_test 20 81 "DeepSeek-V2-Lite-GateUP" 8 512 2816 2048
run_test 21 31 "DeepSeek-V2-GateUP" 5 512 3072 5120
run_test 22 42 "DeepSeek-V2-Down" 10 512 5120 1536
run_test 23 161 "Mixtral-8x7B-GateUP" 1 512 28672 4096
run_test 24 73 "DeepSeek-V2-Lite-GateUP" 4 1024 2816 2048
run_test 25 171 "Mixtral-8x22B-GateUP" 1 512 32768 6144
run_test 26 174 "Mixtral-8x22B-Down" 1 1024 6144 16384
run_test 27 84 "DeepSeek-V2-Lite-Down" 8 1024 2048 1408
run_test 28 34 "DeepSeek-V2-Down" 5 1024 5120 1536
run_test 29 65 "DeepSeek-V2-Lite-GateUP" 2 2048 2816 2048
run_test 30 83 "DeepSeek-V2-Lite-GateUP" 8 1024 2816 2048
echo "========================================================================================================"
echo "All tests completed!"
echo "========================================================================================================"