updating test

This commit is contained in:
kyle-256
2026-01-16 07:05:05 +00:00
parent e2311b8dc7
commit f6330af670
2 changed files with 32 additions and 140 deletions

View File

@@ -282,9 +282,7 @@ int run_grouped_gemm_example_with_n_check(int argc, char* argv[])
// Check N alignment for all groups
bool all_n_mod_256 = true;
bool all_n_mod_128 = true;
bool all_n_mod_64 = true;
if(Ns.size() == static_cast<size_t>(group_count))
{
for(const auto& n : Ns)
@@ -293,118 +291,12 @@ int run_grouped_gemm_example_with_n_check(int argc, char* argv[])
all_n_mod_256 = false;
if(n % 128 != 0)
all_n_mod_128 = false;
if(n % 64 != 0)
all_n_mod_64 = false;
}
}
if(data_type == "bf16")
{
// Memory pipeline configs
if(config == "memory_interwave")
{
std::cout << "[Config] Memory Interwave 128x32" << std::endl;
return run_gemm_example_prec_type<GemmConfigMemoryInterwave<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(config == "memory_intrawave")
{
std::cout << "[Config] Memory Intrawave 128x32" << std::endl;
return run_gemm_example_prec_type<GemmConfigMemoryIntrawave<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(config == "mem_inter_128x128")
{
if(!all_n_mod_128)
throw std::runtime_error("N must be multiple of 128");
std::cout << "[Config] Memory Interwave 128x128" << std::endl;
return run_gemm_example_prec_type<GemmConfigMemoryInterwave_128x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(config == "mem_intra_128x128")
{
if(!all_n_mod_128)
throw std::runtime_error("N must be multiple of 128");
std::cout << "[Config] Memory Intrawave 128x128" << std::endl;
return run_gemm_example_prec_type<GemmConfigMemoryIntrawave_128x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(config == "mem_inter_256x128")
{
if(!all_n_mod_128)
throw std::runtime_error("N must be multiple of 128");
std::cout << "[Config] Memory Interwave 256x128" << std::endl;
return run_gemm_example_prec_type<GemmConfigMemoryInterwave_256x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(config == "mem_inter_256x256")
{
if(!all_n_mod_256)
throw std::runtime_error("N must be multiple of 256");
std::cout << "[Config] Memory Interwave 256x256" << std::endl;
return run_gemm_example_prec_type<GemmConfigMemoryInterwave_256x256<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
// 128x128 configs
else if(config == "v3_128x128_16_k1" || config == "128w16k1")
{
if(!all_n_mod_128)
throw std::runtime_error("N must be multiple of 128");
std::cout << "[Config] 128x128, warp=16, kBlockPerCu=1" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_128x128_16_k1<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(config == "v3_128x128_16_k2" || config == "128w16k2")
{
if(!all_n_mod_128)
throw std::runtime_error("N must be multiple of 128");
std::cout << "[Config] 128x128, warp=16, kBlockPerCu=2" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_128x128_16_k2<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
// 64x64 config
else if(config == "v3_64x64_16_k1" || config == "64w16k1")
{
if(!all_n_mod_64)
throw std::runtime_error("N must be multiple of 64");
std::cout << "[Config] 64x64, warp=16, kBlockPerCu=1" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_64x64_16_k1<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
// grad_B optimized configs
else if(config == "256x256_k2")
{
if(!all_n_mod_256)
throw std::runtime_error("N must be multiple of 256");
std::cout << "[Config] 256x256, kBlockPerCu=2" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_256x256_k2<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(config == "128x256")
{
if(!all_n_mod_256)
throw std::runtime_error("N must be multiple of 256");
std::cout << "[Config] 128x256" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_128x256<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(config == "256x128_k2")
{
if(!all_n_mod_128)
throw std::runtime_error("N must be multiple of 128");
std::cout << "[Config] 256x128, kBlockPerCu=2" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_256x128_k2<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(config == "128x256_k2")
{
if(!all_n_mod_256)
throw std::runtime_error("N must be multiple of 256");
std::cout << "[Config] 128x256, kBlockPerCu=2" << std::endl;
return run_gemm_example_prec_type<GemmConfigComputeV3_128x256_k2<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(config == "compute_v3" || config == "")
if(config == "compute_v3" || config == "")
{
// Default: auto-select based on N alignment
if(all_n_mod_256)
@@ -426,7 +318,7 @@ int run_grouped_gemm_example_with_n_check(int argc, char* argv[])
}
else
{
throw std::runtime_error("Unknown config: " + config + ". Use: compute_v3, compute_v3_32x128, compute_v3_128x128, memory_interwave, memory_intrawave");
throw std::runtime_error("Unknown config: " + config + ". Use: compute_v3, compute_v3_32x128, compute_v3_128x128");
}
}
else

60
run_worst_30_bf16_cases.sh Executable file → Normal file
View File

@@ -135,36 +135,36 @@ run_test() {
}
# 运行所有30个测试用例
run_test 1 62 "DeepSeek-V2-Lite-Down" 2 512 2048 1408
run_test 2 61 "DeepSeek-V2-Lite-GateUP" 2 512 2816 2048
run_test 3 72 "DeepSeek-V2-Lite-Down" 4 512 2048 1408
run_test 4 64 "DeepSeek-V2-Lite-Down" 2 1024 2048 1408
run_test 5 162 "Mixtral-8x7B-Down" 1 512 4096 14336
run_test 6 102 "Qwen3-30B-A3B-Down" 4 512 2048 2048
run_test 7 71 "DeepSeek-V2-Lite-GateUP" 4 512 2816 2048
run_test 8 172 "Mixtral-8x22B-Down" 1 512 6144 16384
run_test 9 63 "DeepSeek-V2-Lite-GateUP" 2 1024 2816 2048
run_test 10 92 "Grok-2-Down" 1 512 8192 16384
run_test 11 101 "Qwen3-30B-A3B-GateUP" 4 512 4096 2048
run_test 12 82 "DeepSeek-V2-Lite-Down" 8 512 2048 1408
run_test 13 112 "Qwen3-30B-A3B-Down" 8 512 2048 2048
run_test 14 74 "DeepSeek-V2-Lite-Down" 4 1024 2048 1408
run_test 15 164 "Mixtral-8x7B-Down" 1 1024 4096 14336
run_test 16 104 "Qwen3-30B-A3B-Down" 4 1024 2048 2048
run_test 17 66 "DeepSeek-V2-Lite-Down" 2 2048 2048 1408
run_test 18 32 "DeepSeek-V2-Down" 5 512 5120 1536
run_test 19 132 "Qwen3-235B-A22B-Down" 4 512 4096 4096
run_test 20 81 "DeepSeek-V2-Lite-GateUP" 8 512 2816 2048
run_test 21 31 "DeepSeek-V2-GateUP" 5 512 3072 5120
run_test 22 42 "DeepSeek-V2-Down" 10 512 5120 1536
run_test 23 161 "Mixtral-8x7B-GateUP" 1 512 28672 4096
run_test 24 73 "DeepSeek-V2-Lite-GateUP" 4 1024 2816 2048
run_test 25 171 "Mixtral-8x22B-GateUP" 1 512 32768 6144
run_test 26 174 "Mixtral-8x22B-Down" 1 1024 6144 16384
run_test 27 84 "DeepSeek-V2-Lite-Down" 8 1024 2048 1408
run_test 28 34 "DeepSeek-V2-Down" 5 1024 5120 1536
run_test 29 65 "DeepSeek-V2-Lite-GateUP" 2 2048 2816 2048
run_test 30 83 "DeepSeek-V2-Lite-GateUP" 8 1024 2816 2048
run_test 1 62 "DeepSeek-V2-Lite-Down" 2 1024 2048 1408
run_test 2 61 "DeepSeek-V2-Lite-GateUP" 2 1024 2816 2048
run_test 3 72 "DeepSeek-V2-Lite-Down" 4 1024 2048 1408
run_test 4 64 "DeepSeek-V2-Lite-Down" 2 2048 2048 1408
run_test 5 162 "Mixtral-8x7B-Down" 1 1024 4096 14336
run_test 6 102 "Qwen3-30B-A3B-Down" 4 1024 2048 2048
run_test 7 71 "DeepSeek-V2-Lite-GateUP" 4 1024 2816 2048
run_test 8 82 "DeepSeek-V2-Lite-Down" 8 1024 2048 1408
run_test 9 32 "DeepSeek-V2-Down" 5 1024 5120 1536
run_test 10 63 "DeepSeek-V2-Lite-GateUP" 2 2048 2816 2048
run_test 11 74 "DeepSeek-V2-Lite-Down" 4 2048 2048 1408
run_test 12 172 "Mixtral-8x22B-Down" 1 1024 6144 16384
run_test 13 161 "Mixtral-8x7B-GateUP" 1 1024 28672 4096
run_test 14 81 "DeepSeek-V2-Lite-GateUP" 8 1024 2816 2048
run_test 15 66 "DeepSeek-V2-Lite-Down" 2 4096 2048 1408
run_test 16 42 "DeepSeek-V2-Down" 10 1024 5120 1536
run_test 17 101 "Qwen3-30B-A3B-GateUP" 4 1024 4096 2048
run_test 18 73 "DeepSeek-V2-Lite-GateUP" 4 2048 2816 2048
run_test 19 112 "Qwen3-30B-A3B-Down" 8 1024 2048 2048
run_test 20 122 "Qwen3-30B-A3B-Down" 16 1024 2048 2048
run_test 21 84 "DeepSeek-V2-Lite-Down" 8 2048 2048 1408
run_test 22 52 "DeepSeek-V2-Down" 20 1024 5120 1536
run_test 23 182 "MoE-1T-Down" 7 1024 8192 1920
run_test 24 34 "DeepSeek-V2-Down" 5 2048 5120 1536
run_test 25 76 "DeepSeek-V2-Lite-Down" 4 4096 2048 1408
run_test 26 92 "Grok-2-Down" 1 1024 8192 16384
run_test 27 65 "DeepSeek-V2-Lite-GateUP" 2 4096 2816 2048
run_test 28 68 "DeepSeek-V2-Lite-Down" 2 8192 2048 1408
run_test 29 80 "DeepSeek-V2-Lite-Down" 4 16384 2048 1408
run_test 30 171 "Mixtral-8x22B-GateUP" 1 1024 32768 6144
echo "========================================================================================================"
echo "All tests completed!"