mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
updating test
This commit is contained in:
@@ -282,9 +282,7 @@ int run_grouped_gemm_example_with_n_check(int argc, char* argv[])
|
||||
// Check N alignment for all groups
|
||||
bool all_n_mod_256 = true;
|
||||
bool all_n_mod_128 = true;
|
||||
bool all_n_mod_64 = true;
|
||||
|
||||
|
||||
if(Ns.size() == static_cast<size_t>(group_count))
|
||||
{
|
||||
for(const auto& n : Ns)
|
||||
@@ -293,118 +291,12 @@ int run_grouped_gemm_example_with_n_check(int argc, char* argv[])
|
||||
all_n_mod_256 = false;
|
||||
if(n % 128 != 0)
|
||||
all_n_mod_128 = false;
|
||||
if(n % 64 != 0)
|
||||
all_n_mod_64 = false;
|
||||
}
|
||||
}
|
||||
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
// Memory pipeline configs
|
||||
if(config == "memory_interwave")
|
||||
{
|
||||
std::cout << "[Config] Memory Interwave 128x32" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigMemoryInterwave<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(config == "memory_intrawave")
|
||||
{
|
||||
std::cout << "[Config] Memory Intrawave 128x32" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigMemoryIntrawave<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(config == "mem_inter_128x128")
|
||||
{
|
||||
if(!all_n_mod_128)
|
||||
throw std::runtime_error("N must be multiple of 128");
|
||||
std::cout << "[Config] Memory Interwave 128x128" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigMemoryInterwave_128x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(config == "mem_intra_128x128")
|
||||
{
|
||||
if(!all_n_mod_128)
|
||||
throw std::runtime_error("N must be multiple of 128");
|
||||
std::cout << "[Config] Memory Intrawave 128x128" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigMemoryIntrawave_128x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(config == "mem_inter_256x128")
|
||||
{
|
||||
if(!all_n_mod_128)
|
||||
throw std::runtime_error("N must be multiple of 128");
|
||||
std::cout << "[Config] Memory Interwave 256x128" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigMemoryInterwave_256x128<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(config == "mem_inter_256x256")
|
||||
{
|
||||
if(!all_n_mod_256)
|
||||
throw std::runtime_error("N must be multiple of 256");
|
||||
std::cout << "[Config] Memory Interwave 256x256" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigMemoryInterwave_256x256<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
// 128x128 configs
|
||||
else if(config == "v3_128x128_16_k1" || config == "128w16k1")
|
||||
{
|
||||
if(!all_n_mod_128)
|
||||
throw std::runtime_error("N must be multiple of 128");
|
||||
std::cout << "[Config] 128x128, warp=16, kBlockPerCu=1" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_128x128_16_k1<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(config == "v3_128x128_16_k2" || config == "128w16k2")
|
||||
{
|
||||
if(!all_n_mod_128)
|
||||
throw std::runtime_error("N must be multiple of 128");
|
||||
std::cout << "[Config] 128x128, warp=16, kBlockPerCu=2" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_128x128_16_k2<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
// 64x64 config
|
||||
else if(config == "v3_64x64_16_k1" || config == "64w16k1")
|
||||
{
|
||||
if(!all_n_mod_64)
|
||||
throw std::runtime_error("N must be multiple of 64");
|
||||
std::cout << "[Config] 64x64, warp=16, kBlockPerCu=1" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_64x64_16_k1<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
// grad_B optimized configs
|
||||
else if(config == "256x256_k2")
|
||||
{
|
||||
if(!all_n_mod_256)
|
||||
throw std::runtime_error("N must be multiple of 256");
|
||||
std::cout << "[Config] 256x256, kBlockPerCu=2" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_256x256_k2<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(config == "128x256")
|
||||
{
|
||||
if(!all_n_mod_256)
|
||||
throw std::runtime_error("N must be multiple of 256");
|
||||
std::cout << "[Config] 128x256" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_128x256<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(config == "256x128_k2")
|
||||
{
|
||||
if(!all_n_mod_128)
|
||||
throw std::runtime_error("N must be multiple of 128");
|
||||
std::cout << "[Config] 256x128, kBlockPerCu=2" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_256x128_k2<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(config == "128x256_k2")
|
||||
{
|
||||
if(!all_n_mod_256)
|
||||
throw std::runtime_error("N must be multiple of 256");
|
||||
std::cout << "[Config] 128x256, kBlockPerCu=2" << std::endl;
|
||||
return run_gemm_example_prec_type<GemmConfigComputeV3_128x256_k2<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(config == "compute_v3" || config == "")
|
||||
if(config == "compute_v3" || config == "")
|
||||
{
|
||||
// Default: auto-select based on N alignment
|
||||
if(all_n_mod_256)
|
||||
@@ -426,7 +318,7 @@ int run_grouped_gemm_example_with_n_check(int argc, char* argv[])
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unknown config: " + config + ". Use: compute_v3, compute_v3_32x128, compute_v3_128x128, memory_interwave, memory_intrawave");
|
||||
throw std::runtime_error("Unknown config: " + config + ". Use: compute_v3, compute_v3_32x128, compute_v3_128x128");
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
60
run_worst_30_bf16_cases.sh
Executable file → Normal file
60
run_worst_30_bf16_cases.sh
Executable file → Normal file
@@ -135,36 +135,36 @@ run_test() {
|
||||
}
|
||||
|
||||
# 运行所有30个测试用例
|
||||
run_test 1 62 "DeepSeek-V2-Lite-Down" 2 512 2048 1408
|
||||
run_test 2 61 "DeepSeek-V2-Lite-GateUP" 2 512 2816 2048
|
||||
run_test 3 72 "DeepSeek-V2-Lite-Down" 4 512 2048 1408
|
||||
run_test 4 64 "DeepSeek-V2-Lite-Down" 2 1024 2048 1408
|
||||
run_test 5 162 "Mixtral-8x7B-Down" 1 512 4096 14336
|
||||
run_test 6 102 "Qwen3-30B-A3B-Down" 4 512 2048 2048
|
||||
run_test 7 71 "DeepSeek-V2-Lite-GateUP" 4 512 2816 2048
|
||||
run_test 8 172 "Mixtral-8x22B-Down" 1 512 6144 16384
|
||||
run_test 9 63 "DeepSeek-V2-Lite-GateUP" 2 1024 2816 2048
|
||||
run_test 10 92 "Grok-2-Down" 1 512 8192 16384
|
||||
run_test 11 101 "Qwen3-30B-A3B-GateUP" 4 512 4096 2048
|
||||
run_test 12 82 "DeepSeek-V2-Lite-Down" 8 512 2048 1408
|
||||
run_test 13 112 "Qwen3-30B-A3B-Down" 8 512 2048 2048
|
||||
run_test 14 74 "DeepSeek-V2-Lite-Down" 4 1024 2048 1408
|
||||
run_test 15 164 "Mixtral-8x7B-Down" 1 1024 4096 14336
|
||||
run_test 16 104 "Qwen3-30B-A3B-Down" 4 1024 2048 2048
|
||||
run_test 17 66 "DeepSeek-V2-Lite-Down" 2 2048 2048 1408
|
||||
run_test 18 32 "DeepSeek-V2-Down" 5 512 5120 1536
|
||||
run_test 19 132 "Qwen3-235B-A22B-Down" 4 512 4096 4096
|
||||
run_test 20 81 "DeepSeek-V2-Lite-GateUP" 8 512 2816 2048
|
||||
run_test 21 31 "DeepSeek-V2-GateUP" 5 512 3072 5120
|
||||
run_test 22 42 "DeepSeek-V2-Down" 10 512 5120 1536
|
||||
run_test 23 161 "Mixtral-8x7B-GateUP" 1 512 28672 4096
|
||||
run_test 24 73 "DeepSeek-V2-Lite-GateUP" 4 1024 2816 2048
|
||||
run_test 25 171 "Mixtral-8x22B-GateUP" 1 512 32768 6144
|
||||
run_test 26 174 "Mixtral-8x22B-Down" 1 1024 6144 16384
|
||||
run_test 27 84 "DeepSeek-V2-Lite-Down" 8 1024 2048 1408
|
||||
run_test 28 34 "DeepSeek-V2-Down" 5 1024 5120 1536
|
||||
run_test 29 65 "DeepSeek-V2-Lite-GateUP" 2 2048 2816 2048
|
||||
run_test 30 83 "DeepSeek-V2-Lite-GateUP" 8 1024 2816 2048
|
||||
run_test 1 62 "DeepSeek-V2-Lite-Down" 2 1024 2048 1408
|
||||
run_test 2 61 "DeepSeek-V2-Lite-GateUP" 2 1024 2816 2048
|
||||
run_test 3 72 "DeepSeek-V2-Lite-Down" 4 1024 2048 1408
|
||||
run_test 4 64 "DeepSeek-V2-Lite-Down" 2 2048 2048 1408
|
||||
run_test 5 162 "Mixtral-8x7B-Down" 1 1024 4096 14336
|
||||
run_test 6 102 "Qwen3-30B-A3B-Down" 4 1024 2048 2048
|
||||
run_test 7 71 "DeepSeek-V2-Lite-GateUP" 4 1024 2816 2048
|
||||
run_test 8 82 "DeepSeek-V2-Lite-Down" 8 1024 2048 1408
|
||||
run_test 9 32 "DeepSeek-V2-Down" 5 1024 5120 1536
|
||||
run_test 10 63 "DeepSeek-V2-Lite-GateUP" 2 2048 2816 2048
|
||||
run_test 11 74 "DeepSeek-V2-Lite-Down" 4 2048 2048 1408
|
||||
run_test 12 172 "Mixtral-8x22B-Down" 1 1024 6144 16384
|
||||
run_test 13 161 "Mixtral-8x7B-GateUP" 1 1024 28672 4096
|
||||
run_test 14 81 "DeepSeek-V2-Lite-GateUP" 8 1024 2816 2048
|
||||
run_test 15 66 "DeepSeek-V2-Lite-Down" 2 4096 2048 1408
|
||||
run_test 16 42 "DeepSeek-V2-Down" 10 1024 5120 1536
|
||||
run_test 17 101 "Qwen3-30B-A3B-GateUP" 4 1024 4096 2048
|
||||
run_test 18 73 "DeepSeek-V2-Lite-GateUP" 4 2048 2816 2048
|
||||
run_test 19 112 "Qwen3-30B-A3B-Down" 8 1024 2048 2048
|
||||
run_test 20 122 "Qwen3-30B-A3B-Down" 16 1024 2048 2048
|
||||
run_test 21 84 "DeepSeek-V2-Lite-Down" 8 2048 2048 1408
|
||||
run_test 22 52 "DeepSeek-V2-Down" 20 1024 5120 1536
|
||||
run_test 23 182 "MoE-1T-Down" 7 1024 8192 1920
|
||||
run_test 24 34 "DeepSeek-V2-Down" 5 2048 5120 1536
|
||||
run_test 25 76 "DeepSeek-V2-Lite-Down" 4 4096 2048 1408
|
||||
run_test 26 92 "Grok-2-Down" 1 1024 8192 16384
|
||||
run_test 27 65 "DeepSeek-V2-Lite-GateUP" 2 4096 2816 2048
|
||||
run_test 28 68 "DeepSeek-V2-Lite-Down" 2 8192 2048 1408
|
||||
run_test 29 80 "DeepSeek-V2-Lite-Down" 4 16384 2048 1408
|
||||
run_test 30 171 "Mixtral-8x22B-GateUP" 1 1024 32768 6144
|
||||
|
||||
echo "========================================================================================================"
|
||||
echo "All tests completed!"
|
||||
|
||||
Reference in New Issue
Block a user