[CK_TILE] moe_sorting support "local_tokens" feature for EP case (#2335)

* support local_token for hipgraph

* update README

* fix comment

* fix fmoe example
This commit is contained in:
carlushuang
2025-06-18 10:49:43 +08:00
committed by GitHub
parent c7c6a0ccb3
commit a4e1248dba
11 changed files with 495 additions and 162 deletions

View File

@@ -87,7 +87,18 @@ void topid_unique_gen(
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("t", "128", "num input tokens")
arg_parser
.insert("t",
"128",
"number of input tokens.\n"
"If \"local_t\" presents, this value indicates global concurrency of all ranks.")
.insert(
"local_t",
"-1",
"Number of local input tokens for curent rank.\n"
"This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n"
"This feature is to simulate EP case where where each rank has different tokens.\n"
"Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.")
.insert("e", "32", "num of experts")
.insert("k", "5", "topk")
.insert("h", "8192", "hidden_size of this model")
@@ -131,6 +142,7 @@ template <typename I, typename W, typename O, typename ST, typename SW, typename
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t tokens = arg_parser.get_int("t");
ck_tile::index_t local_tokens = arg_parser.get_int("local_t");
ck_tile::index_t experts = arg_parser.get_int("e");
ck_tile::index_t topk = arg_parser.get_int("k");
ck_tile::index_t hidden_size = arg_parser.get_int("h");
@@ -169,6 +181,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
// w1 (Down, N size)
ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp;
bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
if(local_tokens > tokens)
{
printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
return false;
}
auto prec_str = [&]() {
auto base_str = prec_i;
if(prec_i != prec_w)
@@ -198,11 +218,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
return std::string(", st:") + std::to_string(stride);
}();
std::cout << "[" << api_str << "|" << prec_str << "]"
<< " t:" << tokens;
if(is_local_token)
{
std::cout << "(" << local_tokens << ")";
}
std::cout
<< "[" << api_str << "|" << prec_str << "]"
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
<< ", act:"
<< ", e:" << experts << ", k:" << topk << stride_str << ", hidden:" << hidden_size
<< ", interm:" << intermediate_size << ", tp:" << tp << ", act:"
<< activation
// << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
<< (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;
@@ -377,6 +403,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
if(workspace_size != 0)
moe_sorting_ws.SetZero(); // note, clear here!!!!
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
if(is_local_token)
{
local_tokens_dev.ToDevice(&local_tokens);
}
fused_moe_traits traits{prec_i,
prec_w,
@@ -400,6 +431,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
: nullptr,
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
o_buf.GetDeviceBuffer(),
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
topk_ids_buf.GetDeviceBuffer(),
@@ -463,6 +495,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host.mData[0],
experts,
block_m,
is_local_token ? local_tokens : tokens,
local_expert_masking);
if(activation == 0)
{
@@ -495,6 +528,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host.mData[0],
experts,
block_m,
is_local_token ? local_tokens : tokens,
local_expert_masking);
// done, preparing GPU buffer
@@ -506,6 +540,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem sd_buf(sd_host);
ck_tile::DeviceMem sy_buf(sy_host);
ck_tile::DeviceMem o_buf(o_host);
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
if(is_local_token)
{
local_tokens_dev.ToDevice(&local_tokens);
}
// manually clear output buffer for atomic
o_buf.SetZero();
@@ -542,7 +581,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_buf.GetDeviceBuffer(),
hidden_size,
intermediate_size / tp,
tokens,
is_local_token ? local_tokens : tokens,
experts,
topk,
stride};