Updating tests on MI300

This commit is contained in:
kyle-256
2026-01-09 10:37:27 +00:00
parent 726ddd64ad
commit 2e00471b10
6 changed files with 5617 additions and 35 deletions

89
analyze_configs_v3.py Normal file
View File

@@ -0,0 +1,89 @@
#!/usr/bin/env python3
"""Summarize per-config TFLOPS from a grouped-GEMM benchmark log.

Reads a benchmark log (``config_comparison_new.log`` by default), splits it
into per-rank blocks, extracts the TFlops figure reported for the Forward,
grad_A and grad_B passes of every config, prints one comparison table per
rank, and finally prints how often each config was the fastest.
"""
import re

# Log file read when the script is run without overriding the path.
LOG_PATH = 'config_comparison_new.log'

# All config names, in the column order used by the printed tables.
all_configs = ['compute_v3', 'compute_v3_kb2', 'compute_v3_32x128', 'compute_v3_32x128_kb2',
               'compute_v3_128x128', 'compute_v3_128x128_kb2', 'memory_intrawave', 'memory_intrawave_kb2']

# Pass names used as table rows and as keys of the win counters.
PASSES = ['Forward', 'grad_A', 'grad_B']

# Section marker text in the log -> pass name used in the tables.
_SECTION_TO_PASS = {
    'Forward': 'Forward',
    'Backward grad_A': 'grad_A',
    'Backward grad_B': 'grad_B',
}
_SECTION_RE = re.compile(r'\[(Forward|Backward grad_A|Backward grad_B)\]')
_PERF_RE = re.compile(r'Perf:\s+([\d.]+) ms, ([\d.]+) TFlops')
_HEADER_RE = re.compile(r'(\d+): (.+?) \(TestID=\d+\)\s+B=(\d+), M=(\d+), N=(\d+), K=(\d+)')
_CONFIG_RE = re.compile(r'--- Config: (\S+) ---')


def _parse_config_section(config_content):
    """Return ``{pass_name: tflops}`` for the output of one config.

    The text is cut at the ``[Forward]`` / ``[Backward grad_A]`` /
    ``[Backward grad_B]`` markers and ``Perf:`` is searched only inside each
    pass's own section. A pass that printed no ``Perf:`` line is therefore
    reported as missing instead of silently picking up the figure of the
    next pass.
    """
    parsed = {}
    # With one capture group, re.split yields [prefix, marker, body, marker, body, ...].
    pieces = _SECTION_RE.split(config_content)
    for idx in range(1, len(pieces), 2):
        pass_name = _SECTION_TO_PASS[pieces[idx]]
        if pass_name in parsed:
            continue  # keep the first measurement found for a pass
        perf = _PERF_RE.search(pieces[idx + 1])
        if perf:
            parsed[pass_name] = float(perf.group(2))
    return parsed


def _parse_rank_block(block):
    """Parse one rank block.

    Returns ``(header_fields, results)`` where ``header_fields`` is the tuple
    ``(rank, case, B, M, N, K)`` of strings and ``results`` maps config name
    to ``{pass_name: tflops}``. Returns ``None`` if the block has no
    recognizable header.
    """
    header = _HEADER_RE.search(block)
    if not header:
        return None
    # Split by config: [prefix, name, body, name, body, ...].
    config_blocks = _CONFIG_RE.split(block)
    results = {}
    for idx in range(1, len(config_blocks), 2):
        results[config_blocks[idx]] = _parse_config_section(config_blocks[idx + 1])
    return header.groups(), results


def _print_rank_table(header_fields, results, wins):
    """Print the comparison table of one rank and update the ``wins`` counters."""
    rank, case, B, M, N, K = header_fields
    print(f"\nRank {rank}: {case} (B={B}, M={M}, N={N}, K={K})")
    print("-" * 180)
    print(f"{'Pass':<8} | {'v3':>7} {'v3_k2':>7} | {'32x128':>7} {'32_k2':>7} | {'128x128':>8} {'128_k2':>8} | {'intra':>7} {'intra_k2':>8} | {'Best':>18}")
    print("-" * 180)
    for pass_name in PASSES:
        # A missing measurement counts as 0 and can never win.
        vals = {cfg: results.get(cfg, {}).get(pass_name, 0) for cfg in all_configs}
        best_val = max(vals.values(), default=0)
        if best_val > 0:
            # On a tie the first config in all_configs order wins.
            best_cfg = next(cfg for cfg, val in vals.items() if val == best_val)
            wins[pass_name][best_cfg] += 1
        else:
            best_cfg = 'N/A'
        v3 = vals['compute_v3']
        v3_k2 = vals['compute_v3_kb2']
        c32 = vals['compute_v3_32x128']
        c32_k2 = vals['compute_v3_32x128_kb2']
        c128 = vals['compute_v3_128x128']
        c128_k2 = vals['compute_v3_128x128_kb2']
        intra = vals['memory_intrawave']
        intra_k2 = vals['memory_intrawave_kb2']
        short_best = best_cfg.replace('compute_v3_', '').replace('memory_', '')
        print(f"{pass_name:<8} | {v3:>7.1f} {v3_k2:>7.1f} | {c32:>7.1f} {c32_k2:>7.1f} | {c128:>8.1f} {c128_k2:>8.1f} | {intra:>7.1f} {intra_k2:>8.1f} | {short_best:>18}")


def _print_win_summary(wins):
    """Print how many times each config was the fastest, per pass and in total."""
    print("\n" + "=" * 120)
    # NOTE: the "30 cases" in this heading is a fixed label, not a computed count.
    print("胜率统计 (30 cases)")
    print("=" * 120)
    print(f"{'Pass':<10} | {'v3':>6} {'v3_k2':>6} | {'32x128':>7} {'32_k2':>6} | {'128x128':>8} {'128_k2':>7} | {'intra':>6} {'intra_k2':>8}")
    print("-" * 120)
    for pass_name in PASSES:
        w = wins[pass_name]
        print(f"{pass_name:<10} | {w['compute_v3']:>6} {w['compute_v3_kb2']:>6} | {w['compute_v3_32x128']:>7} {w['compute_v3_32x128_kb2']:>6} | {w['compute_v3_128x128']:>8} {w['compute_v3_128x128_kb2']:>7} | {w['memory_intrawave']:>6} {w['memory_intrawave_kb2']:>8}")
    total = {c: sum(wins[p][c] for p in wins) for c in all_configs}
    print("-" * 120)
    print(f"{'Total':<10} | {total['compute_v3']:>6} {total['compute_v3_kb2']:>6} | {total['compute_v3_32x128']:>7} {total['compute_v3_32x128_kb2']:>6} | {total['compute_v3_128x128']:>8} {total['compute_v3_128x128_kb2']:>7} | {total['memory_intrawave']:>6} {total['memory_intrawave_kb2']:>8}")
    print("=" * 120)


def main(log_path=LOG_PATH):
    """Read ``log_path`` and print the per-rank tables followed by the win summary."""
    with open(log_path, 'r') as f:
        content = f.read()

    # Split by rank; the text before the first "Rank" header is dropped.
    rank_blocks = re.split(r'={80,}\nRank', content)[1:]

    print("=" * 220)
    print("8 种 Config 性能对比 (TFLOPS) - 含 kbatch=2")
    print("=" * 220)

    # Win counters: wins[pass_name][config] -> number of ranks where config was fastest.
    wins = {p: {c: 0 for c in all_configs} for p in PASSES}

    for block in rank_blocks:
        parsed = _parse_rank_block(block)
        if parsed is None:
            continue
        header_fields, results = parsed
        _print_rank_table(header_fields, results, wins)

    _print_win_summary(wins)


if __name__ == '__main__':
    main()

1838
config_comparison_kbatch.log Normal file

File diff suppressed because it is too large Load Diff

3518
config_comparison_new.log Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -263,7 +263,7 @@ int run_grouped_gemm_example(int argc, char* argv[])
}
}
// Determine appropriate tile config based on N alignment
// Determine appropriate tile config based on N alignment and config selection
int run_grouped_gemm_example_with_n_check(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -276,11 +276,13 @@ int run_grouped_gemm_example_with_n_check(int argc, char* argv[])
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
const int group_count = arg_parser.get_int("group_count");
const std::string config = arg_parser.get_str("config");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
// Check N alignment for all groups
bool all_n_mod_256 = true;
bool all_n_mod_128 = true;
bool all_n_mod_32 = true;
if(Ns.size() == static_cast<size_t>(group_count))
{
@@ -290,26 +292,69 @@ int run_grouped_gemm_example_with_n_check(int argc, char* argv[])
all_n_mod_256 = false;
if(n % 128 != 0)
all_n_mod_128 = false;
if(n % 32 != 0)
all_n_mod_32 = false;
}
}
if(data_type == "bf16")
{
if(all_n_mod_256)
// Allow manual config selection via -config parameter
if(config == "memory_interwave")
{
std::cout << "[Config] Using 256x256 tile (N % 256 == 0)" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_2<ck_tile::bf16_t>, ck_tile::bf16_t>(
if(!all_n_mod_32)
throw std::runtime_error("N must be multiple of 32 for memory_interwave config");
std::cout << "[Config] Using 128x32 tile (Memory Interwave)" << std::endl;
return run_gemm_example_prec_type<GemmConfigMemoryInterwave<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(all_n_mod_128)
else if(config == "memory_intrawave")
{
std::cout << "[Config] Using 256x128 tile (N % 128 == 0, N % 256 != 0)" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_256x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
if(!all_n_mod_32)
throw std::runtime_error("N must be multiple of 32 for memory_intrawave config");
std::cout << "[Config] Using 128x32 tile (Memory Intrawave)" << std::endl;
return run_gemm_example_prec_type<GemmConfigMemoryIntrawave<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(config == "compute_v3_32x128")
{
if(!all_n_mod_128)
throw std::runtime_error("N must be multiple of 128 for compute_v3_32x128 config");
std::cout << "[Config] Using 32x128 tile (Compute V3)" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_32x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(config == "compute_v3_128x128")
{
if(!all_n_mod_128)
throw std::runtime_error("N must be multiple of 128 for compute_v3_128x128 config");
std::cout << "[Config] Using 128x128 tile (Compute V3, kBlockPerCu=2)" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_128x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(config == "compute_v3" || config == "")
{
// Default: auto-select based on N alignment
if(all_n_mod_256)
{
std::cout << "[Config] Using 256x256 tile (N % 256 == 0)" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_2<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(all_n_mod_128)
{
std::cout << "[Config] Using 256x128 tile (N % 128 == 0, N % 256 != 0)" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_256x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported N alignment for compute_v3 config.");
}
}
else
{
throw std::runtime_error("Unsupported error.");
throw std::runtime_error("Unknown config: " + config + ". Use: compute_v3, compute_v3_32x128, compute_v3_128x128, memory_interwave, memory_intrawave");
}
}
else

View File

@@ -149,6 +149,52 @@ struct GemmConfigComputeV3_256x128 : public GemmConfigBase
static constexpr int kBlockPerCu = 1;
};
// Compute-V3 pipeline configuration with a 32x128 (M x N) block tile.
// Per the original note: meant for small M, taken from FBGEMM.
template <typename PrecType>
struct GemmConfigComputeV3_32x128 : public GemmConfigBase
{
    // Pipeline selection and related flags.
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
    static constexpr bool DoubleSmemBuffer          = false;
    static constexpr int kBlockPerCu                = 1;

    // Block tile: 32 x 128, with K covering 256 bytes' worth of PrecType elements.
    static constexpr ck_tile::index_t M_Tile = 32;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);

    // Warp layout along M / N / K.
    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    // Per-warp tile; the K extent is derived from the precision and M_Warp_Tile.
    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};
// Compute-V3 pipeline configuration with a 128x128 (M x N) block tile and
// kBlockPerCu = 2. Per the original note: taken from FBGEMM.
template <typename PrecType>
struct GemmConfigComputeV3_128x128 : public GemmConfigBase
{
    // Pipeline selection and related flags.
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
    static constexpr bool DoubleSmemBuffer          = false;
    static constexpr int kBlockPerCu                = 2;

    // Block tile: 128 x 128, with K covering 128 bytes' worth of PrecType elements.
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

    // Warp layout along M / N / K.
    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    // Per-warp tile; the K extent is derived from the precision and M_Warp_Tile.
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
{
@@ -355,7 +401,8 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
.insert("group_count", "8", "group count.")
.insert("kbatch", "1", "kbatch for SplitK")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "grouped_gemm.json", "json file name to dump results");
.insert("jsonfile", "grouped_gemm.json", "json file name to dump results")
.insert("config", "", "Tile config: compute_v3 (default), memory_interwave, memory_intrawave");
bool result = arg_parser.parse(argc, argv);
return std::make_pair(result, arg_parser);

View File

@@ -1,7 +1,7 @@
#!/bin/bash
# Benchmark script for the 30 worst-performing BF16 cases (Forward + Backward)
# Cases are sorted by average TFLOPS, lowest first
# Tests four configs: compute_v3, compute_v3_32x128, compute_v3_128x128, memory_intrawave
BINARY="./build/bin/tile_example_grouped_gemm"
@@ -12,8 +12,15 @@ if [ ! -f "$BINARY" ]; then
exit 1
fi
# Optional argument: the config to test (default: test all of them)
# Usage: ./run_worst_30_bf16_cases.sh [config]
# config: compute_v3, compute_v3_32x128, compute_v3_128x128, memory_intrawave, all (default)
# Every config is tested with both kbatch=1 and kbatch=2
TEST_CONFIG=${1:-all}
echo "========================================================================================================"
echo "Running BF16 Worst 30 Cases - Grouped GEMM Benchmark (Forward + Backward)"
echo "Config: $TEST_CONFIG"
echo "========================================================================================================"
echo ""
@@ -32,7 +39,40 @@ repeat_param() {
echo "$result"
}
# 运行单个测试 (Forward + Backward)
# Run a single grouped-GEMM benchmark and print only its "Config" / "Perf" output lines.
# Usage: run_gemm config kbatch a_layout b_layout Ms Ns Ks strides B label
#   config   - tile config name; "compute_v3" is the default and is passed without a -config flag
#   kbatch   - value for -kbatch; "2" makes the printed config name end in "_kb2"
#   a_layout - value for -a_layout
#   b_layout - value for -b_layout
#   Ms/Ns/Ks - values for -Ms / -Ns / -Ks
#   strides  - value used for -stride_As, -stride_Bs and -stride_Cs alike
#   B        - value for -group_count
#   label    - free-form tag echoed next to the config name
# Reads the global BINARY. The return status is that of the final grep.
run_gemm() {
    local config=$1
    local kbatch=$2
    local a_layout=$3
    local b_layout=$4
    local Ms=$5
    local Ns=$6
    local Ks=$7
    local strides=$8
    local B=$9
    local label=${10}

    # Collect optional flags in an array so that an absent -config flag does not
    # depend on an empty unquoted variable vanishing through word splitting.
    local -a extra_args=()
    if [ "$config" != "compute_v3" ]; then
        extra_args+=("-config=$config")
    fi
    extra_args+=("-kbatch=$kbatch")

    local config_name="$config"
    if [ "$kbatch" = "2" ]; then
        config_name="${config}_kb2"
    fi

    echo " [$config_name] $label"
    # Every expansion is quoted so each flag reaches the binary as exactly one argument.
    "$BINARY" "-Ms=$Ms" "-Ns=$Ns" "-Ks=$Ks" \
        "-stride_As=$strides" "-stride_Bs=$strides" "-stride_Cs=$strides" \
        "-group_count=$B" -prec=bf16 -validate=0 \
        "-a_layout=$a_layout" "-b_layout=$b_layout" \
        "${extra_args[@]}" 2>&1 | grep -E "Config|Perf"
}
# 运行单个测试 (Forward + Backward) 对比三种 config
# 用法: run_test rank testid case B M N K
run_test() {
local rank=$1
@@ -48,43 +88,48 @@ run_test() {
echo " B=$B, M=$M, N=$N, K=$K"
echo "========================================================================================================"
# ==================== Forward ====================
# Forward: (M, K) @ (K, N) = (M, N)
local fwd_Ms=$(repeat_param $M $B)
local fwd_Ns=$(repeat_param $N $B)
local fwd_Ks=$(repeat_param $K $B)
local strides=$(repeat_param 0 $B)
echo ""
echo "[Forward] GEMM: M=$M, N=$N, K=$K"
echo " Command: $BINARY -Ms=$fwd_Ms -Ns=$fwd_Ns -Ks=$fwd_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1"
$BINARY -Ms=$fwd_Ms -Ns=$fwd_Ns -Ks=$fwd_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1
# ==================== Backward grad_A ====================
# grad_A = grad_Y @ W^T
# (M, N) @ (N, K) = (M, K)
# GEMM: M=M, N=K, K=N
local bwd_a_Ms=$(repeat_param $M $B)
local bwd_a_Ns=$(repeat_param $K $B)
local bwd_a_Ks=$(repeat_param $N $B)
echo ""
echo "[Backward grad_A] GEMM: M=$M, N=$K, K=$N"
echo " Command: $BINARY -Ms=$bwd_a_Ms -Ns=$bwd_a_Ns -Ks=$bwd_a_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1"
$BINARY -Ms=$bwd_a_Ms -Ns=$bwd_a_Ns -Ks=$bwd_a_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1
# ==================== Backward grad_B ====================
# grad_B = X^T @ grad_Y
# (K, M) @ (M, N) = (K, N)
# GEMM: M=K, N=N, K=M
local bwd_b_Ms=$(repeat_param $K $B)
local bwd_b_Ns=$(repeat_param $N $B)
local bwd_b_Ks=$(repeat_param $M $B)
echo ""
echo "[Backward grad_B] GEMM: M=$K, N=$N, K=$M"
echo " Command: $BINARY -Ms=$bwd_b_Ms -Ns=$bwd_b_Ns -Ks=$bwd_b_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1"
$BINARY -Ms=$bwd_b_Ms -Ns=$bwd_b_Ns -Ks=$bwd_b_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1
# 确定要测试的 configs
local configs=""
if [ "$TEST_CONFIG" = "all" ]; then
configs="compute_v3 compute_v3_32x128 compute_v3_128x128 memory_intrawave"
else
configs="$TEST_CONFIG"
fi
# 测试每个 config 的 kbatch=1 和 kbatch=2
for cfg in $configs; do
for kbatch in 1 2; do
local cfg_name="$cfg"
if [ "$kbatch" = "2" ]; then
cfg_name="${cfg}_kb2"
fi
echo ""
echo "--- Config: $cfg_name ---"
echo "[Forward] M=$M, N=$N, K=$K (a_layout=R, b_layout=C)"
run_gemm $cfg $kbatch R C "$fwd_Ms" "$fwd_Ns" "$fwd_Ks" "$strides" $B "Forward"
echo "[Backward grad_A] M=$M, N=$K, K=$N (a_layout=R, b_layout=R)"
run_gemm $cfg $kbatch R R "$bwd_a_Ms" "$bwd_a_Ns" "$bwd_a_Ks" "$strides" $B "grad_A"
echo "[Backward grad_B] M=$K, N=$N, K=$M (a_layout=C, b_layout=R)"
run_gemm $cfg $kbatch C R "$bwd_b_Ms" "$bwd_b_Ns" "$bwd_b_Ks" "$strides" $B "grad_B"
done
done
echo ""
}