mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
update test config
This commit is contained in:
175
compare_results.py
Executable file
175
compare_results.py
Executable file
@@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
BF16 最差 30 Case: Primus-Turbo vs CK Grouped GEMM 对比脚本
|
||||
从 validation_results.log 提取 CK 测试结果,与 Primus-Turbo 原始数据对比
|
||||
"""
|
||||
|
||||
import re
|
||||
import sys
|
||||
|
||||
# Primus-Turbo baseline measurements for the BF16 worst-30 cases.
# Keys are 1-based ranks (assigned by position below); each value is the tuple
#   (case name, B, M, N, K, PT forward TFLOPS, PT backward TFLOPS)
# in exactly the order print_comparison() unpacks it.
PT_DATA = dict(enumerate((
    ("DeepSeek-V2-Lite-Down", 2, 512, 2048, 1408, 88.31, 77.83),       # 1
    ("DeepSeek-V2-Lite-GateUP", 2, 512, 2816, 2048, 128.56, 111.88),   # 2
    ("DeepSeek-V2-Lite-Down", 4, 512, 2048, 1408, 169.94, 153.80),     # 3
    ("DeepSeek-V2-Lite-Down", 2, 1024, 2048, 1408, 171.98, 155.41),    # 4
    ("Mixtral-8x7B-Down", 1, 512, 4096, 14336, 109.59, 235.06),        # 5
    ("Qwen3-30B-A3B-Down", 4, 512, 2048, 2048, 180.66, 167.65),        # 6
    ("DeepSeek-V2-Lite-GateUP", 4, 512, 2816, 2048, 235.38, 165.31),   # 7
    ("Mixtral-8x22B-Down", 1, 512, 6144, 16384, 162.22, 245.13),       # 8
    ("DeepSeek-V2-Lite-GateUP", 2, 1024, 2816, 2048, 240.26, 185.67),  # 9
    ("Grok-2-Down", 1, 512, 8192, 16384, 213.48, 249.60),              # 10
    ("Qwen3-30B-A3B-GateUP", 4, 512, 4096, 2048, 302.33, 181.72),      # 11
    ("DeepSeek-V2-Lite-Down", 8, 512, 2048, 1408, 274.33, 212.22),     # 12
    ("Qwen3-30B-A3B-Down", 8, 512, 2048, 2048, 289.45, 219.50),        # 13
    ("DeepSeek-V2-Lite-Down", 4, 1024, 2048, 1408, 282.12, 232.32),    # 14
    ("Mixtral-8x7B-Down", 1, 1024, 4096, 14336, 212.73, 337.73),       # 15
    ("Qwen3-30B-A3B-Down", 4, 1024, 2048, 2048, 297.60, 253.24),       # 16
    ("DeepSeek-V2-Lite-Down", 2, 2048, 2048, 1408, 293.07, 262.01),    # 17
    ("DeepSeek-V2-Down", 5, 512, 5120, 1536, 378.06, 180.62),          # 18
    ("Qwen3-235B-A22B-Down", 4, 512, 4096, 4096, 330.68, 239.15),      # 19
    ("DeepSeek-V2-Lite-GateUP", 8, 512, 2816, 2048, 350.89, 223.60),   # 20
    ("DeepSeek-V2-GateUP", 5, 512, 3072, 5120, 310.56, 265.50),        # 21
    ("DeepSeek-V2-Down", 10, 512, 5120, 1536, 354.81, 238.12),         # 22
    ("Mixtral-8x7B-GateUP", 1, 512, 28672, 4096, 449.17, 144.20),      # 23
    ("DeepSeek-V2-Lite-GateUP", 4, 1024, 2816, 2048, 364.11, 241.49),  # 24
    ("Mixtral-8x22B-GateUP", 1, 512, 32768, 6144, 457.38, 179.06),     # 25
    ("Mixtral-8x22B-Down", 1, 1024, 6144, 16384, 292.92, 346.51),      # 26
    ("DeepSeek-V2-Lite-Down", 8, 1024, 2048, 1408, 395.28, 245.58),    # 27
    ("DeepSeek-V2-Down", 5, 1024, 5120, 1536, 367.96, 276.79),         # 28
    ("DeepSeek-V2-Lite-GateUP", 2, 2048, 2816, 2048, 376.12, 270.01),  # 29
    ("DeepSeek-V2-Lite-GateUP", 8, 1024, 2816, 2048, 337.25, 310.81),  # 30
), start=1))
|
||||
|
||||
|
||||
def parse_log(log_file):
    """Extract CK throughput figures and validation statistics from a log.

    Every ``Perf: ... <number> TFlops`` occurrence in the file is collected in
    order of appearance. The values are then assigned positionally: three
    consecutive values per rank (forward, backward grad_A, backward grad_B),
    starting with rank 1 and covering at most 30 ranks.

    NOTE(review): the positional mapping assumes every run emitted exactly one
    ``Perf:`` line. A run that produced none shifts every later rank; the log
    carries no marker this parser uses to detect that.

    Args:
        log_file: Path of the log file to read.

    Returns:
        A 5-tuple ``(ck_data, correct_count, fail_count, tile_256, tile_128)``:
        ``ck_data`` maps rank -> ``(forward, backward_grad_a, backward_grad_b)``
        TFLOPS for every rank with a complete triple; ``correct_count`` is the
        number of "correct" occurrences that are not part of "incorrect";
        ``fail_count`` is the number of "fail" substrings; ``tile_256`` and
        ``tile_128`` count the "256x256 tile" and "256x128 tile" substrings.

    Raises:
        FileNotFoundError: If ``log_file`` does not exist.
    """
    # Decode explicitly and tolerate stray bytes: benchmark logs can contain
    # arbitrary tool output, and one bad byte should not abort the report.
    with open(log_file, 'r', encoding='utf-8', errors='replace') as f:
        content = f.read()

    # All TFlops figures, in the order they appear in the log.
    tflops_pattern = r'Perf:.*?(\d+\.?\d*) TFlops'
    tflops_values = [float(x) for x in re.findall(tflops_pattern, content)]

    num_ranks = 30
    runs_per_rank = 3  # forward, backward grad_A, backward grad_B

    ck_data = {}
    for rank in range(1, num_ranks + 1):
        start = (rank - 1) * runs_per_rank
        triple = tflops_values[start:start + runs_per_rank]
        if len(triple) < runs_per_rank:
            # Values are consumed in order, so once one rank is incomplete
            # every later rank is incomplete as well.
            break
        ck_data[rank] = tuple(triple)

    # A plain substring count of "correct" would also match "incorrect" and
    # report failed validations as passes; the lookbehind excludes those.
    correct_count = len(re.findall(r'(?<!in)correct', content))
    fail_count = content.count("fail")
    tile_256 = content.count("256x256 tile")
    tile_128 = content.count("256x128 tile")

    return ck_data, correct_count, fail_count, tile_256, tile_128
|
||||
|
||||
|
||||
def harmonic_mean(a, b):
    """Return the harmonic mean of two throughput values: 2 / (1/a + 1/b).

    This is the appropriate way to merge two TFLOPS figures, because

        combined = total_flops / total_time
                 = (flops_a + flops_b) / (time_a + time_b)
                 = 2 * flops / (flops / a + flops / b)   (when flops_a == flops_b)
                 = 2 / (1 / a + 1 / b)

    Returns 0.0 when either input is zero or negative.
    """
    for value in (a, b):
        if value <= 0:
            return 0.0
    reciprocal_sum = 1.0 / a + 1.0 / b
    return 2.0 / reciprocal_sum
|
||||
|
||||
|
||||
def print_comparison(ck_data):
    """Print the per-rank comparison table and return the column averages.

    One row is printed for every rank in ``PT_DATA``, in ascending rank order.
    The CK backward figure in each row is the harmonic mean of the grad_A and
    grad_B values. A rank that is missing from ``ck_data`` is shown with CK
    values of 0.0, and those zeros are included in the CK averages.

    Args:
        ck_data: Mapping rank -> ``(forward, backward_grad_a, backward_grad_b)``
            CK TFLOPS, as produced by ``parse_log``.

    Returns:
        ``(avg_ck_fwd, avg_ck_bwd, avg_pt_fwd, avg_pt_bwd)``, each averaged
        over the number of entries in ``PT_DATA``.
    """
    sep = "=" * 195
    line = "-" * 195

    print(sep)
    print("BF16 最差 30 Case: Primus-Turbo vs CK Grouped GEMM 完整对比")
    print(sep)
    print(f"{'Rank':<5} {'Case':<28} {'B':<3} {'M':<5} {'N':<6} {'K':<6} | {'CK_Fwd':>8} {'CK_BwdA':>8} {'CK_BwdB':>8} {'CK_Bwd':>8} | {'PT_Fwd':>8} {'PT_Bwd':>8} | {'Δ Fwd':>8} {'Δ Bwd':>8}")
    print(line)

    total_ck_fwd = total_ck_bwd = 0
    total_pt_fwd = total_pt_bwd = 0

    # Drive the loop from the reference table itself so that the rows and the
    # averaging divisor below can never disagree with its contents.
    for rank in sorted(PT_DATA):
        case, B, M, N, K, pt_fwd, pt_bwd = PT_DATA[rank]

        # Ranks with no parsed CK result are reported as zero throughput.
        ck_fwd, ck_bwd_a, ck_bwd_b = ck_data.get(rank, (0.0, 0.0, 0.0))

        # Merge the two backward GEMMs with the harmonic mean.
        ck_bwd_combined = harmonic_mean(ck_bwd_a, ck_bwd_b)

        delta_fwd = ck_fwd - pt_fwd
        delta_bwd = ck_bwd_combined - pt_bwd

        total_ck_fwd += ck_fwd
        total_ck_bwd += ck_bwd_combined
        total_pt_fwd += pt_fwd
        total_pt_bwd += pt_bwd

        print(f"{rank:<5} {case:<28} {B:<3} {M:<5} {N:<6} {K:<6} | {ck_fwd:>8.2f} {ck_bwd_a:>8.2f} {ck_bwd_b:>8.2f} {ck_bwd_combined:>8.2f} | {pt_fwd:>8.2f} {pt_bwd:>8.2f} | {delta_fwd:>+8.2f} {delta_bwd:>+8.2f}")

    print(line)

    # max(..., 1) keeps an empty reference table from dividing by zero.
    case_count = max(len(PT_DATA), 1)
    avg_ck_fwd = total_ck_fwd / case_count
    avg_ck_bwd = total_ck_bwd / case_count
    avg_pt_fwd = total_pt_fwd / case_count
    avg_pt_bwd = total_pt_bwd / case_count

    print(f"{'平均':<38} | {avg_ck_fwd:>8.2f} {avg_ck_bwd:>26.2f} | {avg_pt_fwd:>8.2f} {avg_pt_bwd:>8.2f} | {avg_ck_fwd-avg_pt_fwd:>+8.2f} {avg_ck_bwd-avg_pt_bwd:>+8.2f}")
    print(sep)

    return avg_ck_fwd, avg_ck_bwd, avg_pt_fwd, avg_pt_bwd
|
||||
|
||||
|
||||
def print_summary(avg_ck_fwd, avg_ck_bwd, avg_pt_fwd, avg_pt_bwd,
                  correct_count, fail_count, tile_256, tile_128):
    """Print the closing summary of the PT vs CK comparison.

    Emits the forward, backward and combined average TFLOPS for both
    implementations, each with CK's relative change over PT in percent,
    followed by the validation pass/fail counts and the tile usage counts.
    The combined average is the arithmetic mean of the forward and backward
    averages.

    Args:
        avg_ck_fwd: Average CK forward TFLOPS.
        avg_ck_bwd: Average CK backward TFLOPS.
        avg_pt_fwd: Average PT forward TFLOPS (used as a divisor).
        avg_pt_bwd: Average PT backward TFLOPS (used as a divisor).
        correct_count: Number of passed validations to report.
        fail_count: Number of failed validations to report.
        tile_256: Number of times the 256x256 tile was used.
        tile_128: Number of times the 256x128 tile was used.
    """
    rule = "=" * 100

    print()
    print(rule)
    print("性能对比总结")
    print(rule)

    # Relative change of CK over PT, in percent.
    fwd_gain = (avg_ck_fwd / avg_pt_fwd - 1) * 100
    print(f"Forward 平均: PT = {avg_pt_fwd:.2f} TFLOPS → CK = {avg_ck_fwd:.2f} TFLOPS ({fwd_gain:>+.1f}%)")

    bwd_gain = (avg_ck_bwd / avg_pt_bwd - 1) * 100
    print(f"Backward 平均: PT = {avg_pt_bwd:.2f} TFLOPS → CK = {avg_ck_bwd:.2f} TFLOPS ({bwd_gain:>+.1f}%)")

    avg_pt = (avg_pt_fwd + avg_pt_bwd) / 2
    avg_ck = (avg_ck_fwd + avg_ck_bwd) / 2
    overall_gain = (avg_ck / avg_pt - 1) * 100
    print(f"综合平均: PT = {avg_pt:.2f} TFLOPS → CK = {avg_ck:.2f} TFLOPS ({overall_gain:>+.1f}%)")
    print(rule)

    print()
    print(f"精度验证: {correct_count} 通过, {fail_count} 失败")
    print(f"Tile 配置: 256x256 使用 {tile_256} 次, 256x128 使用 {tile_128} 次")
|
||||
|
||||
|
||||
def main():
    """Entry point: parse the log, print the table, then print the summary.

    The log path is taken from the first command-line argument and falls back
    to "validation_results.log". If the file does not exist, an error and a
    usage hint are printed and the process exits with status 1.
    """
    cli_args = sys.argv[1:]
    log_file = cli_args[0] if cli_args else "validation_results.log"

    try:
        ck_data, correct_count, fail_count, tile_256, tile_128 = parse_log(log_file)
    except FileNotFoundError:
        print(f"错误: 找不到文件 {log_file}")
        print("用法: python3 compare_results.py [validation_results.log]")
        sys.exit(1)

    # print_comparison returns the four averages in the order that
    # print_summary expects as its leading positional arguments.
    averages = print_comparison(ck_data)
    avg_ck_fwd, avg_ck_bwd, avg_pt_fwd, avg_pt_bwd = averages
    print_summary(avg_ck_fwd, avg_ck_bwd, avg_pt_fwd, avg_pt_bwd,
                  correct_count, fail_count, tile_256, tile_128)


if __name__ == "__main__":
    main()
|
||||
|
||||
@@ -233,14 +233,6 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A and B tensors!");
|
||||
@@ -260,20 +252,65 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
else
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
throw std::runtime_error("Unsupported data type configuration.");
|
||||
}
|
||||
}
|
||||
|
||||
// Determine appropriate tile config based on N alignment
|
||||
int run_grouped_gemm_example_with_n_check(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
|
||||
// Check N alignment for all groups
|
||||
bool all_n_mod_256 = true;
|
||||
bool all_n_mod_128 = true;
|
||||
|
||||
if(Ns.size() == static_cast<size_t>(group_count))
|
||||
{
|
||||
for(const auto& n : Ns)
|
||||
{
|
||||
if(n % 256 != 0)
|
||||
all_n_mod_256 = false;
|
||||
if(n % 128 != 0)
|
||||
all_n_mod_128 = false;
|
||||
}
|
||||
}
|
||||
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
if(all_n_mod_256)
|
||||
{
|
||||
std::cout << "[Config] Using 256x256 tile (N % 256 == 0)" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_2<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(all_n_mod_128)
|
||||
{
|
||||
std::cout << "[Config] Using 256x128 tile (N % 128 == 0, N % 256 != 0)" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_256x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported error.");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -283,11 +320,5 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_gemm_example<GemmConfigComputeV4_Wmma>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv) ||
|
||||
!run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv) ||
|
||||
!run_grouped_gemm_example<GemmConfigComputeV4_V2>(argc, argv);
|
||||
#endif
|
||||
return !run_grouped_gemm_example_with_n_check(argc, argv);
|
||||
}
|
||||
|
||||
@@ -65,9 +65,72 @@ struct GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
};
|
||||
|
||||
// 256x128 tile config for N % 128 == 0 but N % 256 != 0
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_256x128 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
@@ -286,7 +349,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("prec", "bf16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
|
||||
126
run_worst_30_bf16_cases.sh
Executable file
126
run_worst_30_bf16_cases.sh
Executable file
@@ -0,0 +1,126 @@
|
||||
#!/bin/bash

# Benchmark driver for the 30 worst-performing BF16 cases (forward + backward).
# Cases are ordered by average TFLOPS, lowest first.

BINARY="./build/bin/tile_example_grouped_gemm"

# Stop right away when the example binary has not been built.
[[ -f "$BINARY" ]] || {
    echo "Error: $BINARY not found!"
    echo "Please build the example first with: make tile_example_grouped_gemm -j\$(nproc)"
    exit 1
}

# Opening banner, followed by one blank line.
printf '%s\n' \
    "========================================================================================================" \
    "Running BF16 Worst 30 Cases - Grouped GEMM Benchmark (Forward + Backward)" \
    "========================================================================================================" \
    ""
|
||||
|
||||
# Build a comma-separated list that repeats one value a given number of times.
# Usage:  repeat_param VALUE COUNT
# Output: VALUE,VALUE,...,VALUE (COUNT entries) on stdout; an empty line when
#         COUNT is 0.
repeat_param() {
    local val=$1
    local count=$2
    local result=""
    local i  # keep the loop counter out of the caller's scope

    for ((i = 0; i < count; i++)); do
        # ${result:+,} expands to a comma only once result is non-empty, so
        # no leading separator is produced.
        result+="${result:+,}$val"
    done

    # printf rather than echo, so a value such as "-n" is printed literally.
    printf '%s\n' "$result"
}
|
||||
|
||||
# Print and execute one grouped-GEMM benchmark invocation.
# Usage: _run_gemm LABEL GROUPS M N K
# Every group gets the same M, N and K, and all strides are passed as 0.
# Returns the exit status of the benchmark binary.
_run_gemm() {
    local label=$1
    local groups=$2
    local m=$3
    local n=$4
    local k=$5

    # Declared separately from the assignments so that "local" does not mask
    # the exit status of the command substitutions.
    local ms ns ks strides
    ms=$(repeat_param "$m" "$groups")
    ns=$(repeat_param "$n" "$groups")
    ks=$(repeat_param "$k" "$groups")
    strides=$(repeat_param 0 "$groups")

    local -a args=(
        "-Ms=$ms" "-Ns=$ns" "-Ks=$ks"
        "-stride_As=$strides" "-stride_Bs=$strides" "-stride_Cs=$strides"
        "-group_count=$groups" "-prec=bf16" "-validate=1"
    )

    echo ""
    echo "[$label] GEMM: M=$m, N=$n, K=$k"
    echo " Command: $BINARY ${args[*]}"
    "$BINARY" "${args[@]}"
}

# Run one test case: the forward GEMM plus both backward GEMMs.
# Usage: run_test rank testid case B M N K
# Returns 0 when all three benchmark runs exit successfully, 1 otherwise.
run_test() {
    local rank=$1
    local testid=$2
    local case_name=$3
    local B=$4
    local M=$5
    local N=$6
    local K=$7
    local status=0

    echo "========================================================================================================"
    echo "Rank $rank: $case_name (TestID=$testid)"
    echo " B=$B, M=$M, N=$N, K=$K"
    echo "========================================================================================================"

    # Forward: (M, K) @ (K, N) = (M, N)
    _run_gemm "Forward" "$B" "$M" "$N" "$K" || status=1

    # Backward grad_A = grad_Y @ W^T: (M, N) @ (N, K) = (M, K)
    # As a GEMM: M=M, N=K, K=N
    _run_gemm "Backward grad_A" "$B" "$M" "$K" "$N" || status=1

    # Backward grad_B = X^T @ grad_Y: (K, M) @ (M, N) = (K, N)
    # As a GEMM: M=K, N=N, K=M
    _run_gemm "Backward grad_B" "$B" "$K" "$N" "$M" || status=1

    echo ""
    return "$status"
}
|
||||
|
||||
# Run all 30 test cases.
# Row format: rank testid case B M N K
worst_cases=(
    "1 62 DeepSeek-V2-Lite-Down 2 512 2048 1408"
    "2 61 DeepSeek-V2-Lite-GateUP 2 512 2816 2048"
    "3 72 DeepSeek-V2-Lite-Down 4 512 2048 1408"
    "4 64 DeepSeek-V2-Lite-Down 2 1024 2048 1408"
    "5 162 Mixtral-8x7B-Down 1 512 4096 14336"
    "6 102 Qwen3-30B-A3B-Down 4 512 2048 2048"
    "7 71 DeepSeek-V2-Lite-GateUP 4 512 2816 2048"
    "8 172 Mixtral-8x22B-Down 1 512 6144 16384"
    "9 63 DeepSeek-V2-Lite-GateUP 2 1024 2816 2048"
    "10 92 Grok-2-Down 1 512 8192 16384"
    "11 101 Qwen3-30B-A3B-GateUP 4 512 4096 2048"
    "12 82 DeepSeek-V2-Lite-Down 8 512 2048 1408"
    "13 112 Qwen3-30B-A3B-Down 8 512 2048 2048"
    "14 74 DeepSeek-V2-Lite-Down 4 1024 2048 1408"
    "15 164 Mixtral-8x7B-Down 1 1024 4096 14336"
    "16 104 Qwen3-30B-A3B-Down 4 1024 2048 2048"
    "17 66 DeepSeek-V2-Lite-Down 2 2048 2048 1408"
    "18 32 DeepSeek-V2-Down 5 512 5120 1536"
    "19 132 Qwen3-235B-A22B-Down 4 512 4096 4096"
    "20 81 DeepSeek-V2-Lite-GateUP 8 512 2816 2048"
    "21 31 DeepSeek-V2-GateUP 5 512 3072 5120"
    "22 42 DeepSeek-V2-Down 10 512 5120 1536"
    "23 161 Mixtral-8x7B-GateUP 1 512 28672 4096"
    "24 73 DeepSeek-V2-Lite-GateUP 4 1024 2816 2048"
    "25 171 Mixtral-8x22B-GateUP 1 512 32768 6144"
    "26 174 Mixtral-8x22B-Down 1 1024 6144 16384"
    "27 84 DeepSeek-V2-Lite-Down 8 1024 2048 1408"
    "28 34 DeepSeek-V2-Down 5 1024 5120 1536"
    "29 65 DeepSeek-V2-Lite-GateUP 2 2048 2816 2048"
    "30 83 DeepSeek-V2-Lite-GateUP 8 1024 2816 2048"
)

for case_row in "${worst_cases[@]}"; do
    # Split the row into its seven fields before calling run_test, so the
    # benchmark itself still runs with the script's own stdin.
    read -r -a case_fields <<< "$case_row"
    run_test "${case_fields[@]}"
done

# Closing banner.
printf '%s\n' \
    "========================================================================================================" \
    "All tests completed!" \
    "========================================================================================================"
|
||||
Reference in New Issue
Block a user