mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
udpating tests on mi300
This commit is contained in:
89
analyze_configs_v3.py
Normal file
89
analyze_configs_v3.py
Normal file
@@ -0,0 +1,89 @@
|
||||
#!/usr/bin/env python3
|
||||
import re
|
||||
|
||||
with open('config_comparison_new.log', 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
# 按 Rank 分割
|
||||
rank_blocks = re.split(r'={80,}\nRank', content)[1:]
|
||||
|
||||
# 所有 config 名称
|
||||
all_configs = ['compute_v3', 'compute_v3_kb2', 'compute_v3_32x128', 'compute_v3_32x128_kb2',
|
||||
'compute_v3_128x128', 'compute_v3_128x128_kb2', 'memory_intrawave', 'memory_intrawave_kb2']
|
||||
|
||||
print("=" * 220)
|
||||
print("8 种 Config 性能对比 (TFLOPS) - 含 kbatch=2")
|
||||
print("=" * 220)
|
||||
|
||||
# 统计
|
||||
wins = {p: {c: 0 for c in all_configs} for p in ['Forward', 'grad_A', 'grad_B']}
|
||||
|
||||
for block in rank_blocks:
|
||||
header = re.search(r'(\d+): (.+?) \(TestID=\d+\)\s+B=(\d+), M=(\d+), N=(\d+), K=(\d+)', block)
|
||||
if not header:
|
||||
continue
|
||||
rank, case, B, M, N, K = header.groups()
|
||||
|
||||
# 按 config 分割
|
||||
config_blocks = re.split(r'--- Config: (\S+) ---', block)
|
||||
results = {}
|
||||
|
||||
for i in range(1, len(config_blocks), 2):
|
||||
config_name = config_blocks[i]
|
||||
config_content = config_blocks[i+1] if i+1 < len(config_blocks) else ""
|
||||
results[config_name] = {}
|
||||
|
||||
fwd_match = re.search(r'\[Forward\].*?Perf:\s+([\d.]+) ms, ([\d.]+) TFlops', config_content, re.DOTALL)
|
||||
if fwd_match:
|
||||
results[config_name]['Forward'] = float(fwd_match.group(2))
|
||||
|
||||
grada_match = re.search(r'\[Backward grad_A\].*?Perf:\s+([\d.]+) ms, ([\d.]+) TFlops', config_content, re.DOTALL)
|
||||
if grada_match:
|
||||
results[config_name]['grad_A'] = float(grada_match.group(2))
|
||||
|
||||
gradb_match = re.search(r'\[Backward grad_B\].*?Perf:\s+([\d.]+) ms, ([\d.]+) TFlops', config_content, re.DOTALL)
|
||||
if gradb_match:
|
||||
results[config_name]['grad_B'] = float(gradb_match.group(2))
|
||||
|
||||
# 打印每个 rank 的结果
|
||||
print(f"\nRank {rank}: {case} (B={B}, M={M}, N={N}, K={K})")
|
||||
print("-" * 180)
|
||||
print(f"{'Pass':<8} | {'v3':>7} {'v3_k2':>7} | {'32x128':>7} {'32_k2':>7} | {'128x128':>8} {'128_k2':>8} | {'intra':>7} {'intra_k2':>8} | {'Best':>18}")
|
||||
print("-" * 180)
|
||||
|
||||
for pass_name in ['Forward', 'grad_A', 'grad_B']:
|
||||
vals = {}
|
||||
for cfg in all_configs:
|
||||
vals[cfg] = results.get(cfg, {}).get(pass_name, 0)
|
||||
|
||||
best_val = max(vals.values()) if vals.values() else 0
|
||||
best_cfg = [k for k, v in vals.items() if v == best_val][0] if best_val > 0 else 'N/A'
|
||||
|
||||
if best_val > 0:
|
||||
wins[pass_name][best_cfg] += 1
|
||||
|
||||
v3 = vals.get('compute_v3', 0)
|
||||
v3_k2 = vals.get('compute_v3_kb2', 0)
|
||||
c32 = vals.get('compute_v3_32x128', 0)
|
||||
c32_k2 = vals.get('compute_v3_32x128_kb2', 0)
|
||||
c128 = vals.get('compute_v3_128x128', 0)
|
||||
c128_k2 = vals.get('compute_v3_128x128_kb2', 0)
|
||||
intra = vals.get('memory_intrawave', 0)
|
||||
intra_k2 = vals.get('memory_intrawave_kb2', 0)
|
||||
|
||||
short_best = best_cfg.replace('compute_v3_', '').replace('memory_', '')
|
||||
print(f"{pass_name:<8} | {v3:>7.1f} {v3_k2:>7.1f} | {c32:>7.1f} {c32_k2:>7.1f} | {c128:>8.1f} {c128_k2:>8.1f} | {intra:>7.1f} {intra_k2:>8.1f} | {short_best:>18}")
|
||||
|
||||
print("\n" + "=" * 120)
|
||||
print("胜率统计 (30 cases)")
|
||||
print("=" * 120)
|
||||
print(f"{'Pass':<10} | {'v3':>6} {'v3_k2':>6} | {'32x128':>7} {'32_k2':>6} | {'128x128':>8} {'128_k2':>7} | {'intra':>6} {'intra_k2':>8}")
|
||||
print("-" * 120)
|
||||
for pass_name in ['Forward', 'grad_A', 'grad_B']:
|
||||
w = wins[pass_name]
|
||||
print(f"{pass_name:<10} | {w['compute_v3']:>6} {w['compute_v3_kb2']:>6} | {w['compute_v3_32x128']:>7} {w['compute_v3_32x128_kb2']:>6} | {w['compute_v3_128x128']:>8} {w['compute_v3_128x128_kb2']:>7} | {w['memory_intrawave']:>6} {w['memory_intrawave_kb2']:>8}")
|
||||
|
||||
total = {c: sum(wins[p][c] for p in wins) for c in all_configs}
|
||||
print("-" * 120)
|
||||
print(f"{'Total':<10} | {total['compute_v3']:>6} {total['compute_v3_kb2']:>6} | {total['compute_v3_32x128']:>7} {total['compute_v3_32x128_kb2']:>6} | {total['compute_v3_128x128']:>8} {total['compute_v3_128x128_kb2']:>7} | {total['memory_intrawave']:>6} {total['memory_intrawave_kb2']:>8}")
|
||||
print("=" * 120)
|
||||
1838
config_comparison_kbatch.log
Normal file
1838
config_comparison_kbatch.log
Normal file
File diff suppressed because it is too large
Load Diff
3518
config_comparison_new.log
Normal file
3518
config_comparison_new.log
Normal file
File diff suppressed because it is too large
Load Diff
@@ -263,7 +263,7 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
// Determine appropriate tile config based on N alignment
|
||||
// Determine appropriate tile config based on N alignment and config selection
|
||||
int run_grouped_gemm_example_with_n_check(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -276,11 +276,13 @@ int run_grouped_gemm_example_with_n_check(int argc, char* argv[])
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const std::string config = arg_parser.get_str("config");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
|
||||
// Check N alignment for all groups
|
||||
bool all_n_mod_256 = true;
|
||||
bool all_n_mod_128 = true;
|
||||
bool all_n_mod_32 = true;
|
||||
|
||||
if(Ns.size() == static_cast<size_t>(group_count))
|
||||
{
|
||||
@@ -290,26 +292,69 @@ int run_grouped_gemm_example_with_n_check(int argc, char* argv[])
|
||||
all_n_mod_256 = false;
|
||||
if(n % 128 != 0)
|
||||
all_n_mod_128 = false;
|
||||
if(n % 32 != 0)
|
||||
all_n_mod_32 = false;
|
||||
}
|
||||
}
|
||||
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
if(all_n_mod_256)
|
||||
// Allow manual config selection via -config parameter
|
||||
if(config == "memory_interwave")
|
||||
{
|
||||
std::cout << "[Config] Using 256x256 tile (N % 256 == 0)" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_2<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
if(!all_n_mod_32)
|
||||
throw std::runtime_error("N must be multiple of 32 for memory_interwave config");
|
||||
std::cout << "[Config] Using 128x32 tile (Memory Interwave)" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigMemoryInterwave<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(all_n_mod_128)
|
||||
else if(config == "memory_intrawave")
|
||||
{
|
||||
std::cout << "[Config] Using 256x128 tile (N % 128 == 0, N % 256 != 0)" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_256x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
if(!all_n_mod_32)
|
||||
throw std::runtime_error("N must be multiple of 32 for memory_intrawave config");
|
||||
std::cout << "[Config] Using 128x32 tile (Memory Intrawave)" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigMemoryIntrawave<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(config == "compute_v3_32x128")
|
||||
{
|
||||
if(!all_n_mod_128)
|
||||
throw std::runtime_error("N must be multiple of 128 for compute_v3_32x128 config");
|
||||
std::cout << "[Config] Using 32x128 tile (Compute V3)" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_32x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(config == "compute_v3_128x128")
|
||||
{
|
||||
if(!all_n_mod_128)
|
||||
throw std::runtime_error("N must be multiple of 128 for compute_v3_128x128 config");
|
||||
std::cout << "[Config] Using 128x128 tile (Compute V3, kBlockPerCu=2)" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_128x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(config == "compute_v3" || config == "")
|
||||
{
|
||||
// Default: auto-select based on N alignment
|
||||
if(all_n_mod_256)
|
||||
{
|
||||
std::cout << "[Config] Using 256x256 tile (N % 256 == 0)" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_2<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(all_n_mod_128)
|
||||
{
|
||||
std::cout << "[Config] Using 256x128 tile (N % 128 == 0, N % 256 != 0)" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_256x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported N alignment for compute_v3 config.");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported error.");
|
||||
throw std::runtime_error("Unknown config: " + config + ". Use: compute_v3, compute_v3_32x128, compute_v3_128x128, memory_interwave, memory_intrawave");
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
@@ -149,6 +149,52 @@ struct GemmConfigComputeV3_256x128 : public GemmConfigBase
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
};
|
||||
|
||||
// 32x128 tile config (small M, from FBGEMM)
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_32x128 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
};
|
||||
|
||||
// 128x128 tile config with kBlockPerCu=2 (from FBGEMM)
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_128x128 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
{
|
||||
@@ -355,7 +401,8 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "grouped_gemm.json", "json file name to dump results");
|
||||
.insert("jsonfile", "grouped_gemm.json", "json file name to dump results")
|
||||
.insert("config", "", "Tile config: compute_v3 (default), memory_interwave, memory_intrawave");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_pair(result, arg_parser);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
# BF16 表现最差的前 30 个 Case 测试脚本 (Forward + Backward)
|
||||
# 按平均 TFLOPS 从低到高排序
|
||||
# 测试三种 config: compute_v3, memory_interwave, memory_intrawave
|
||||
|
||||
BINARY="./build/bin/tile_example_grouped_gemm"
|
||||
|
||||
@@ -12,8 +12,15 @@ if [ ! -f "$BINARY" ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 可选参数: 指定要测试的 config (默认全部测试)
|
||||
# 用法: ./run_worst_30_bf16_cases.sh [config]
|
||||
# config: compute_v3, compute_v3_32x128, compute_v3_128x128, memory_intrawave, all (默认)
|
||||
# 所有 config 都会测试 kbatch=1 和 kbatch=2
|
||||
TEST_CONFIG=${1:-all}
|
||||
|
||||
echo "========================================================================================================"
|
||||
echo "Running BF16 Worst 30 Cases - Grouped GEMM Benchmark (Forward + Backward)"
|
||||
echo "Config: $TEST_CONFIG"
|
||||
echo "========================================================================================================"
|
||||
echo ""
|
||||
|
||||
@@ -32,7 +39,40 @@ repeat_param() {
|
||||
echo "$result"
|
||||
}
|
||||
|
||||
# 运行单个测试 (Forward + Backward)
|
||||
# 运行单个 GEMM 测试
|
||||
# 用法: run_gemm config kbatch a_layout b_layout Ms Ns Ks strides B label
|
||||
run_gemm() {
|
||||
local config=$1
|
||||
local kbatch=$2
|
||||
local a_layout=$3
|
||||
local b_layout=$4
|
||||
local Ms=$5
|
||||
local Ns=$6
|
||||
local Ks=$7
|
||||
local strides=$8
|
||||
local B=$9
|
||||
local label=${10}
|
||||
|
||||
local config_arg=""
|
||||
local kbatch_arg="-kbatch=$kbatch"
|
||||
|
||||
if [ "$config" = "compute_v3" ]; then
|
||||
config_arg="" # default
|
||||
else
|
||||
config_arg="-config=$config"
|
||||
fi
|
||||
|
||||
local config_name="$config"
|
||||
if [ "$kbatch" = "2" ]; then
|
||||
config_name="${config}_kb2"
|
||||
fi
|
||||
|
||||
echo " [$config_name] $label"
|
||||
$BINARY -Ms=$Ms -Ns=$Ns -Ks=$Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides \
|
||||
-group_count=$B -prec=bf16 -validate=0 -a_layout=$a_layout -b_layout=$b_layout $config_arg $kbatch_arg 2>&1 | grep -E "Config|Perf"
|
||||
}
|
||||
|
||||
# 运行单个测试 (Forward + Backward) 对比三种 config
|
||||
# 用法: run_test rank testid case B M N K
|
||||
run_test() {
|
||||
local rank=$1
|
||||
@@ -48,43 +88,48 @@ run_test() {
|
||||
echo " B=$B, M=$M, N=$N, K=$K"
|
||||
echo "========================================================================================================"
|
||||
|
||||
# ==================== Forward ====================
|
||||
# Forward: (M, K) @ (K, N) = (M, N)
|
||||
local fwd_Ms=$(repeat_param $M $B)
|
||||
local fwd_Ns=$(repeat_param $N $B)
|
||||
local fwd_Ks=$(repeat_param $K $B)
|
||||
local strides=$(repeat_param 0 $B)
|
||||
|
||||
echo ""
|
||||
echo "[Forward] GEMM: M=$M, N=$N, K=$K"
|
||||
echo " Command: $BINARY -Ms=$fwd_Ms -Ns=$fwd_Ns -Ks=$fwd_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1"
|
||||
$BINARY -Ms=$fwd_Ms -Ns=$fwd_Ns -Ks=$fwd_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1
|
||||
|
||||
# ==================== Backward grad_A ====================
|
||||
# grad_A = grad_Y @ W^T
|
||||
# (M, N) @ (N, K) = (M, K)
|
||||
# GEMM: M=M, N=K, K=N
|
||||
local bwd_a_Ms=$(repeat_param $M $B)
|
||||
local bwd_a_Ns=$(repeat_param $K $B)
|
||||
local bwd_a_Ks=$(repeat_param $N $B)
|
||||
|
||||
echo ""
|
||||
echo "[Backward grad_A] GEMM: M=$M, N=$K, K=$N"
|
||||
echo " Command: $BINARY -Ms=$bwd_a_Ms -Ns=$bwd_a_Ns -Ks=$bwd_a_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1"
|
||||
$BINARY -Ms=$bwd_a_Ms -Ns=$bwd_a_Ns -Ks=$bwd_a_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1
|
||||
|
||||
# ==================== Backward grad_B ====================
|
||||
# grad_B = X^T @ grad_Y
|
||||
# (K, M) @ (M, N) = (K, N)
|
||||
# GEMM: M=K, N=N, K=M
|
||||
local bwd_b_Ms=$(repeat_param $K $B)
|
||||
local bwd_b_Ns=$(repeat_param $N $B)
|
||||
local bwd_b_Ks=$(repeat_param $M $B)
|
||||
|
||||
echo ""
|
||||
echo "[Backward grad_B] GEMM: M=$K, N=$N, K=$M"
|
||||
echo " Command: $BINARY -Ms=$bwd_b_Ms -Ns=$bwd_b_Ns -Ks=$bwd_b_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1"
|
||||
$BINARY -Ms=$bwd_b_Ms -Ns=$bwd_b_Ns -Ks=$bwd_b_Ks -stride_As=$strides -stride_Bs=$strides -stride_Cs=$strides -group_count=$B -prec=bf16 -validate=1
|
||||
# 确定要测试的 configs
|
||||
local configs=""
|
||||
if [ "$TEST_CONFIG" = "all" ]; then
|
||||
configs="compute_v3 compute_v3_32x128 compute_v3_128x128 memory_intrawave"
|
||||
else
|
||||
configs="$TEST_CONFIG"
|
||||
fi
|
||||
|
||||
# 测试每个 config 的 kbatch=1 和 kbatch=2
|
||||
for cfg in $configs; do
|
||||
for kbatch in 1 2; do
|
||||
local cfg_name="$cfg"
|
||||
if [ "$kbatch" = "2" ]; then
|
||||
cfg_name="${cfg}_kb2"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "--- Config: $cfg_name ---"
|
||||
|
||||
echo "[Forward] M=$M, N=$N, K=$K (a_layout=R, b_layout=C)"
|
||||
run_gemm $cfg $kbatch R C "$fwd_Ms" "$fwd_Ns" "$fwd_Ks" "$strides" $B "Forward"
|
||||
|
||||
echo "[Backward grad_A] M=$M, N=$K, K=$N (a_layout=R, b_layout=R)"
|
||||
run_gemm $cfg $kbatch R R "$bwd_a_Ms" "$bwd_a_Ns" "$bwd_a_Ks" "$strides" $B "grad_A"
|
||||
|
||||
echo "[Backward grad_B] M=$K, N=$N, K=$M (a_layout=C, b_layout=R)"
|
||||
run_gemm $cfg $kbatch C R "$bwd_b_Ms" "$bwd_b_Ns" "$bwd_b_Ks" "$strides" $B "grad_B"
|
||||
done
|
||||
done
|
||||
|
||||
echo ""
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user