[CK_TILE] Fix mock token id, support g1u1/g1u0 through same inline code block (#1808)

* fix mock token id

* prepare host for g1u1

* reformat inline-asm

* restructure uk_0

* restructure gate_up

* done

* change default to init=1

* update readme

* fix a bug in interleave pipeline

* rcp for silu

[ROCm/composable_kernel commit: 1ff50e78c6]
This commit is contained in:
carlushuang
2025-01-16 17:51:10 +08:00
committed by GitHub
parent 11eab9ca17
commit 21264b4e60
20 changed files with 1927 additions and 1413 deletions

View File

@@ -108,12 +108,14 @@ auto create_args(int argc, char* argv[])
.insert(
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
.insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
.insert("act", "0", "activation after first gemm. 0:gelu, 1:silu")
.insert("balance",
"0",
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
.insert("init",
"2",
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
"1",
"init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand "
"normalized[0, 1]"
"normalized(slow)")
.insert("seed", "11939", "seed used to do random")
.insert("warmup", "5", "cold iter")
@@ -135,6 +137,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t intermediate_size = arg_parser.get_int("i");
ck_tile::index_t stride = arg_parser.get_int("stride");
ck_tile::index_t block_m = arg_parser.get_int("bm");
ck_tile::index_t activation = arg_parser.get_int("act");
if(stride < 0)
stride = hidden_size;
std::string prec_i = arg_parser.get_str("prec_i");
@@ -194,11 +197,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
return std::string(", st:") + std::to_string(stride);
}();
std::cout << "[" << api_str << "|" << prec_str << "]"
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
<< ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
<< ", go:" << gate_only << ", q:" << fused_quant << std::flush;
std::cout
<< "[" << api_str << "|" << prec_str << "]"
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
<< ", act:"
<< activation
// << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
<< (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = typename TypeConfig::ADataType;
@@ -370,6 +376,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
prec_sq,
prec_kw,
block_m,
activation,
gate_only,
fused_quant};
@@ -389,7 +396,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_buf.GetDeviceBuffer(),
block_m,
hidden_size,
shared_intermediate_size_0,
intermediate_size / tp,
tokens,
experts,
topk,
@@ -408,6 +415,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< cal_tbps(ave_time) << " TB/s" << std::flush;
bool pass = true;
#define CPU_FUSED_MOE(act_type_) \
ck_tile::reference_fused_moe<AccDataType, act_type_>(a_host, \
g_host, \
d_host, \
sa_host, \
sg_host, \
sd_host, \
sy_host, \
o_host, \
sorted_token_ids_host, \
sorted_weight_host, \
sorted_expert_ids_host, \
num_sorted_tiles_host, \
topk_ids_host, \
block_m, \
tokens, \
experts, \
hidden_size, \
intermediate_size / tp, \
topk, \
gate_only)
if(do_validation)
{
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
@@ -419,28 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host.mData[0],
experts,
block_m);
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
a_host,
g_host,
d_host,
sa_host,
sg_host,
sd_host,
sy_host,
o_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host,
topk_ids_host,
block_m,
tokens,
experts,
hidden_size,
shared_intermediate_size_0,
topk,
gate_only);
if(activation == 0)
{
CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
}
else
{
CPU_FUSED_MOE(ck_tile::element_wise::Silu);
}
auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float");
@@ -491,6 +506,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
prec_sq,
prec_kw,
block_m,
activation,
gate_only,
fused_quant};
@@ -507,7 +523,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
sorted_expert_ids_buf.GetDeviceBuffer(),
num_sorted_tiles_buf.GetDeviceBuffer(),
hidden_size,
shared_intermediate_size_0,
intermediate_size / tp,
tokens,
experts,
topk,
@@ -529,27 +545,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation)
{
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
a_host,
g_host,
d_host,
sa_host,
sg_host,
sd_host,
sy_host,
o_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host,
topk_ids_host,
block_m,
tokens,
experts,
hidden_size,
shared_intermediate_size_0,
topk,
gate_only);
if(activation == 0)
{
CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
}
else
{
CPU_FUSED_MOE(ck_tile::element_wise::Silu);
}
auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float");