mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] moe sorting optimize local_token (#2469)
* fix bug in loops that need use local tokens to compute
* support extra chain local_token
* update
* update
* refine some main
* update
* support dispatch_policy
* fix 15 example
[ROCm/composable_kernel commit: cfe211cc60]
This commit is contained in:
@@ -35,7 +35,20 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("e", "8", "number of num_experts")
|
||||
.insert("k", "4", "topk")
|
||||
.insert("unit", "32", "unit_size")
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
.insert("moe_buf_interm_dim", "0", "interm_dim(col) of the following fmoe buf")
|
||||
.insert(
|
||||
"moe_buf_elem_bytes", "2", "fmoe buf element byte size, 1:8bit, 2:16bit, 4:32bit...")
|
||||
#else
|
||||
.insert("moe_buf_size", "0", "moe_buf_size")
|
||||
#endif
|
||||
.insert("ci",
|
||||
"1",
|
||||
"clear workspace inside API or not(if \"0\", require manually clear outside)")
|
||||
.insert(
|
||||
"dispatch",
|
||||
"0",
|
||||
"dispatch policy. 0:automatically pick up kernel, 1:use single kernel, 2:use mp kernel")
|
||||
.insert("local_eid",
|
||||
"-1",
|
||||
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
|
||||
@@ -88,10 +101,17 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
int topk = args.get_int("k");
|
||||
int seed = args.get_int("seed");
|
||||
int unit_size = args.get_int("unit");
|
||||
int64_t moe_buf_size = static_cast<int64_t>(args.get_uint64("moe_buf_size"));
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
int moe_buf_interm_dim = args.get_int("moe_buf_interm_dim");
|
||||
int moe_buf_elem_bytes = args.get_int("moe_buf_elem_bytes");
|
||||
#else
|
||||
int64_t moe_buf_size = static_cast<int64_t>(args.get_uint64("moe_buf_size"));
|
||||
#endif
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
bool clear_inside = args.get_int("ci") != 0;
|
||||
int dispatch_policy = args.get_int("dispatch");
|
||||
|
||||
int max_output_ids =
|
||||
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
|
||||
@@ -149,11 +169,26 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
|
||||
// for simplicity, below buffer allocate 2 dword
|
||||
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({2}, {1});
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
ck_tile::HostTensor<int8_t> moe_buf_host(
|
||||
{static_cast<std::size_t>(is_local_token ? local_tokens : tokens) * moe_buf_interm_dim *
|
||||
moe_buf_elem_bytes});
|
||||
auto moe_buf_bytes = moe_buf_interm_dim == 0 ? static_cast<std::size_t>(0)
|
||||
: moe_buf_host.get_element_space_size_in_bytes();
|
||||
#else
|
||||
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
|
||||
auto moe_buf_bytes = moe_buf_size == 0 ? static_cast<std::size_t>(0)
|
||||
: moe_buf_host.get_element_space_size_in_bytes();
|
||||
#endif
|
||||
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
ck_tile::FillUniformDistribution<int8_t>{-.5f, .5f}(moe_buf_host);
|
||||
#else
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
|
||||
#endif
|
||||
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
|
||||
|
||||
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
|
||||
@@ -176,7 +211,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
|
||||
topk_ids_dev.ToDevice(topk_ids_host.data());
|
||||
weights_dev.ToDevice(weights_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
if(moe_buf_bytes > 0)
|
||||
{
|
||||
moe_buf_dev.ToDevice(moe_buf_host.data());
|
||||
}
|
||||
@@ -184,29 +219,31 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
ck_tile::index_t workspace_size =
|
||||
moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
if(workspace_size != 0 && clear_inside == false)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};
|
||||
moe_sorting_trait trait{
|
||||
index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy};
|
||||
|
||||
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
|
||||
weights_dev.GetDeviceBuffer(),
|
||||
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
sorted_ids_dev.GetDeviceBuffer(),
|
||||
sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
sorted_id_cnt_dev.GetDeviceBuffer(),
|
||||
moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
tokens,
|
||||
unit_size,
|
||||
num_experts,
|
||||
topk,
|
||||
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))};
|
||||
moe_sorting_args karg
|
||||
{
|
||||
topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(),
|
||||
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() : nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(),
|
||||
moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, tokens, unit_size,
|
||||
num_experts, topk,
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
moe_buf_interm_dim, moe_buf_elem_bytes
|
||||
#else
|
||||
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))
|
||||
#endif
|
||||
};
|
||||
|
||||
ck_tile::stream_config sc{nullptr,
|
||||
true,
|
||||
@@ -219,7 +256,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
|
||||
#if 0
|
||||
{
|
||||
ck_tile::HostTensor<char> ws_host({workspace_size}, {1});
|
||||
ck_tile::HostTensor<char> ws_host({workspace_size}, {1});
|
||||
moe_sorting_ws.FromDevice(ws_host.data());
|
||||
|
||||
int * p_mesh = reinterpret_cast<int*>(ws_host.data());
|
||||
@@ -268,7 +305,12 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
}
|
||||
#endif
|
||||
|
||||
printf("[%s|%s]tokens:%d", index_prec.c_str(), weight_prec.c_str(), tokens);
|
||||
printf("[%s|%s|%s|%d]tokens:%d",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"),
|
||||
dispatch_policy,
|
||||
tokens);
|
||||
if(is_local_token)
|
||||
{
|
||||
printf("(%d)", local_tokens);
|
||||
@@ -280,6 +322,19 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
printf("local_eid:%s, ", args.get_str("local_eid").c_str());
|
||||
}
|
||||
|
||||
if(moe_buf_bytes > 0)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
printf("moe_buf:%lu(%d,%d), ",
|
||||
static_cast<uint64_t>(moe_buf_bytes),
|
||||
moe_buf_interm_dim,
|
||||
moe_buf_elem_bytes);
|
||||
#else
|
||||
|
||||
printf("moe_buf:%lu, ", static_cast<uint64_t>(moe_buf_bytes));
|
||||
#endif
|
||||
}
|
||||
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
else
|
||||
@@ -294,7 +349,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
sorted_weights_dev.FromDevice(sorted_weights_host.data());
|
||||
sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
|
||||
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
if(moe_buf_bytes > 0)
|
||||
{
|
||||
moe_buf_dev.FromDevice(moe_buf_host.data());
|
||||
}
|
||||
@@ -340,6 +395,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
std::string("OUT Error: Incorrect eid!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
// if(is_local_token)
|
||||
{
|
||||
auto t_ = is_local_token ? local_tokens : tokens;
|
||||
bool _f = t_ == sorted_id_cnt_host.mData[1];
|
||||
rtn &= _f;
|
||||
if(!_f)
|
||||
{
|
||||
printf("not equal token buffer pad %d(%d)\n", t_, sorted_id_cnt_host.mData[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -347,9 +412,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
rtn = false;
|
||||
}
|
||||
|
||||
if(moe_buf_size)
|
||||
if(moe_buf_bytes)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
ck_tile::HostTensor<int8_t> moe_buf_ref({moe_buf_bytes});
|
||||
#else
|
||||
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
|
||||
#endif
|
||||
rtn &= ck_tile::check_err(
|
||||
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
|
||||
}
|
||||
|
||||
@@ -175,7 +175,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
}
|
||||
}
|
||||
#else
|
||||
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0)
|
||||
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk, t.dispatch_policy) != 0)
|
||||
{
|
||||
return moe_sorting_mp(t, a, s);
|
||||
}
|
||||
@@ -293,6 +293,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
@@ -302,6 +303,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
@@ -314,6 +316,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
@@ -323,6 +326,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
@@ -330,6 +334,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
} \
|
||||
}
|
||||
|
||||
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
|
||||
[&]() { \
|
||||
using problem_ = \
|
||||
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
|
||||
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
@@ -338,6 +353,22 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
|
||||
auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) {
|
||||
if(t.clear_workspace_inside_api)
|
||||
{
|
||||
if(is_local_token)
|
||||
{
|
||||
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1);
|
||||
k(s_);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1);
|
||||
k(s_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
|
||||
ck_tile::get_smem_capacity())
|
||||
{
|
||||
@@ -345,6 +376,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
@@ -354,6 +386,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
else
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
@@ -405,7 +438,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
return -1;
|
||||
}
|
||||
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
|
||||
}
|
||||
|
||||
@@ -10,8 +10,14 @@
|
||||
struct moe_sorting_trait
|
||||
{
|
||||
std::string index_type;
|
||||
std::string weight_type; // currently always float
|
||||
bool local_expert_masking; // if mask experts as local expert
|
||||
std::string weight_type; // currently always float
|
||||
bool local_expert_masking; // if mask experts as local expert
|
||||
bool clear_workspace_inside_api; // if true, no need clear workspace outsize (will take care of
|
||||
// it inside API)
|
||||
int dispatch_policy; // 0 - let the API choose kernel for you. 1 - always use single kerenl. 2 -
|
||||
// always use mp kernel NOTE: moe_sorting_get_workspace_size() need use
|
||||
// same dispatch_policy value. it will be undefined behavior if ppl using
|
||||
// different value when get ws and call the kernel
|
||||
};
|
||||
|
||||
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
@@ -22,6 +28,6 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
// if return non zero, means need workspace, you need to allocate a GPU buffer
|
||||
// and set to moe_sorting_args.p_ws
|
||||
// NOTE: workspace size are required to clear zero before use the API
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk);
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy);
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# #!/bin/sh
|
||||
|
||||
EXE=./build/bin/tile_example_moe_sorting
|
||||
MOE_BUF="12"
|
||||
|
||||
if [ "x$MOE_BUF" = "x1" ] ; then
|
||||
$EXE -t=80 -e=17 -moe_buf_size=16
|
||||
$EXE -t=111 -e=117 -moe_buf_size=4
|
||||
$EXE -t=1000 -e=55 -moe_buf_size=1024
|
||||
@@ -42,3 +44,46 @@ $EXE -t=23 -local_t=9 -e=1 -k=1
|
||||
$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33
|
||||
$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33
|
||||
$EXE -t=133940 -local_t=111921 -e=256 -k=17 -moe_buf_size=133940
|
||||
else
|
||||
$EXE -t=80 -e=17 -moe_buf_interm_dim=16 -moe_buf_elem_bytes=4
|
||||
$EXE -t=111 -e=117 -moe_buf_interm_dim=4 -moe_buf_elem_bytes=4
|
||||
$EXE -t=1000 -e=55 -moe_buf_interm_dim=1024 -moe_buf_elem_bytes=1
|
||||
$EXE -t=99 -e=120 -moe_buf_interm_dim=10244 -moe_buf_elem_bytes=2
|
||||
$EXE -t=175 -e=64 -k=8
|
||||
$EXE -t=65 -e=8 -k=2
|
||||
$EXE -t=1 -e=25
|
||||
$EXE -t=31 -e=19 -k=15
|
||||
$EXE -t=81 -e=37 -k=7
|
||||
$EXE -t=23 -e=1 -k=1
|
||||
$EXE -t=127 -e=99 -k=19
|
||||
$EXE -t=71 -e=11 -k=11
|
||||
$EXE -t=1 -e=1 -k=1
|
||||
$EXE -t=99 -e=2 -k=1
|
||||
$EXE -t=333 -e=99 -k=13
|
||||
$EXE -t=11 -e=256 -k=5
|
||||
$EXE -t=64 -e=455 -k=8
|
||||
$EXE -t=777 -e=802 -k=99
|
||||
$EXE -t=4097 -e=906 -k=51
|
||||
$EXE -t=128 -e=32 -k=5 -local_t=6 -moe_buf_interm_dim=262144
|
||||
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
|
||||
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
|
||||
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
|
||||
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
|
||||
$EXE -t=128 -e=128 -k=6 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=1
|
||||
$EXE -t=8192 -e=32 -k=5 -local_t=11 -moe_buf_interm_dim=163840
|
||||
$EXE -t=8192 -e=32 -k=8 -local_t=12 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=1
|
||||
$EXE -t=8192 -e=256 -k=5 -local_t=13 -moe_buf_interm_dim=163840
|
||||
$EXE -t=8192 -e=256 -k=8 -local_t=8 -moe_buf_interm_dim=163840
|
||||
$EXE -t=163840 -e=256 -k=8 -local_t=4 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=4
|
||||
$EXE -t=12 -local_t=3 -e=256 -k=5 -local_eid=9,10,199,145
|
||||
$EXE -t=67 -local_t=9 -e=555 -k=5 -local_eid=19,23,24,25,26,99
|
||||
$EXE -t=99 -local_t=93 -e=121 -local_t=4 -moe_buf_interm_dim=10244
|
||||
$EXE -t=536 -local_t=345 -e=802 -k=99
|
||||
$EXE -t=331 -local_t=39 -e=83 -k=33
|
||||
$EXE -t=765 -local_t=654 -e=783 -k=8
|
||||
$EXE -t=23 -local_t=9 -e=1 -k=1
|
||||
$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33
|
||||
$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33
|
||||
$EXE -t=133940 -local_t=111921 -e=256 -k=17 -local_t=2 -moe_buf_interm_dim=133940 -moe_buf_elem_bytes=1
|
||||
|
||||
fi
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
|
||||
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
return ck_tile::moe_sorting_get_workspace_size(
|
||||
tokens, num_experts, topk, 0 /*dispatch policy*/);
|
||||
}
|
||||
|
||||
float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
|
||||
@@ -24,23 +25,28 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
}();
|
||||
|
||||
auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
|
||||
auto a0 = fused_moesorting_args{
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.local_tokens,
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.ws_ptr, // void* p_ws;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
static_cast<ck_tile::long_index_t>(a.num_tokens) * a.stride_token *
|
||||
o_data_bytes // index_t moe_buf_bytes;
|
||||
auto a0 = fused_moesorting_args
|
||||
{
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.local_tokens,
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.ws_ptr, // void* p_ws;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
a.stride_token, o_data_bytes,
|
||||
#else
|
||||
static_cast<ck_tile::long_index_t>(a.num_tokens) *
|
||||
a.stride_token* o_data_bytes // index_t moe_buf_bytes;
|
||||
#endif
|
||||
};
|
||||
|
||||
auto t1 = fused_moegemm_traits{t.prec_i,
|
||||
|
||||
@@ -413,5 +413,6 @@ float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
|
||||
int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
return ck_tile::moe_sorting_get_workspace_size(
|
||||
tokens, num_experts, topk, 0 /*dispatch policy*/);
|
||||
}
|
||||
|
||||
@@ -399,7 +399,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size =
|
||||
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk);
|
||||
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk, 0 /*dispatch_policy*/);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
@@ -23,6 +23,11 @@ namespace ck_tile {
|
||||
#define MOE_SORTING_FUSE_MP_01 0
|
||||
#endif
|
||||
|
||||
// weather use 2d buffer indexing for fmoe ws or 1d
|
||||
#ifndef MOE_SORTING_FMOE_2D_BUF
|
||||
#define MOE_SORTING_FMOE_2D_BUF 1
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
// [indexing implementation-1]
|
||||
// using M_a as constexpr block_size to partition all tokens into different slices
|
||||
@@ -171,7 +176,7 @@ struct MoeSortingHostArgs
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_total_tokens_post_pad;
|
||||
void* p_total_tokens_post_pad; // [2], [0]:outputed tokens_post_padded, [1]:actual tokens on current rank (local_tokens or tokens)
|
||||
// we fused the setzero of output of fused-moe buffer
|
||||
// set this pointer to nullptr will skip this operation
|
||||
void* p_moe_buf;
|
||||
@@ -182,7 +187,18 @@ struct MoeSortingHostArgs
|
||||
index_t unit_size; // this is the M_a of fused-moe kernel
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
// NOTE:
|
||||
// moe_buf_* is a 2d ws buffer used for the following fmoe kernel
|
||||
// arranged as row*col, where row=tokens(or local_token), col=interm_dim
|
||||
// we fuse this clearing inside sorting kernel
|
||||
// Besides, we require inter_dim to be multiple of 16 byte(make sure when alloc ws for fmoe)
|
||||
index_t moe_buf_interm_dim; // p_moe_buf interm_dim
|
||||
index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.)
|
||||
#else
|
||||
long_index_t moe_buf_bytes; // byte size of p_moe_buf
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
template <typename Problem_>
|
||||
@@ -197,6 +213,9 @@ struct MoeSortingKernel
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_topk_ids;
|
||||
@@ -210,8 +229,12 @@ struct MoeSortingKernel
|
||||
void* p_moe_buf;
|
||||
index_t tokens;
|
||||
index_t num_experts;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
index_t moe_buf_interm_dim; // p_moe_buf interm_dim
|
||||
index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.)
|
||||
#else
|
||||
long_index_t moe_buf_bytes;
|
||||
|
||||
#endif
|
||||
index_t tokens_per_thread;
|
||||
index_t smem_rows;
|
||||
mdiv unit_size_mdiv;
|
||||
@@ -220,10 +243,27 @@ struct MoeSortingKernel
|
||||
// mdiv sub_tokens_mdiv;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
(void)h;
|
||||
return get_num_cu() * OCCUPANCY;
|
||||
#else
|
||||
// TODO: assume num-experts not too much
|
||||
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BlockSize(h).x * 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
|
||||
@@ -263,7 +303,12 @@ struct MoeSortingKernel
|
||||
k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
k.moe_buf_interm_dim = h.moe_buf_interm_dim;
|
||||
k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
|
||||
#else
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
#endif
|
||||
|
||||
const auto blocks = BlockSize(h);
|
||||
// NOTE: tokens could from p_local_tokens, so here this variable is useless
|
||||
@@ -431,6 +476,24 @@ struct MoeSortingKernel
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
moe_buf_set_zero_kernel_2d(void* buf, index_t row, index_t col, index_t elem_bytes) const
|
||||
{
|
||||
const long_index_t total_pixels = static_cast<long_index_t>(row) * col;
|
||||
const long_index_t total_bytes = total_pixels * elem_bytes;
|
||||
const long_index_t total_elems = total_bytes / 16; // always use dwordx4
|
||||
|
||||
using vector_type = ext_vector_t<index_t, 4>;
|
||||
vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(long_index_t i = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; i < total_elems;
|
||||
i += (gridDim.x - 1) * BLOCK_SIZE)
|
||||
{
|
||||
p_buf[i] = zero_;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
|
||||
const WeightType* __restrict__ weights,
|
||||
index_t* p_sorted_token_ids,
|
||||
@@ -863,7 +926,8 @@ struct MoeSortingKernel
|
||||
}
|
||||
if((lid + i_e_ - get_warp_size()) == (num_experts - 1))
|
||||
{
|
||||
*p_total_tokens_post_pad = local_cumsum_;
|
||||
*p_total_tokens_post_pad = local_cumsum_;
|
||||
p_total_tokens_post_pad[1] = tokens;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
@@ -1005,20 +1069,6 @@ struct MoeSortingKernel
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
if(blockIdx.x > 0)
|
||||
{
|
||||
if(kargs.p_moe_buf)
|
||||
{
|
||||
moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes);
|
||||
}
|
||||
return;
|
||||
}
|
||||
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
|
||||
extern __shared__ char smem[];
|
||||
|
||||
#if MOE_SORTING_USE_EX_KERNEL
|
||||
(void)numel;
|
||||
index_t tokens_ = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
@@ -1029,6 +1079,25 @@ struct MoeSortingKernel
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
|
||||
if(blockIdx.x > 0)
|
||||
{
|
||||
if(kargs.p_moe_buf)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
moe_buf_set_zero_kernel_2d(
|
||||
kargs.p_moe_buf, tokens_, kargs.moe_buf_interm_dim, kargs.moe_buf_elem_bytes);
|
||||
#else
|
||||
moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes);
|
||||
#endif
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
extern __shared__ char smem[];
|
||||
|
||||
#if MOE_SORTING_USE_EX_KERNEL
|
||||
return moe_align_block_size_kernel_ex(
|
||||
static_cast<const IndexType*>(kargs.p_topk_ids),
|
||||
static_cast<const WeightType*>(kargs.p_weights),
|
||||
@@ -1045,6 +1114,7 @@ struct MoeSortingKernel
|
||||
kargs.smem_rows,
|
||||
smem);
|
||||
#else
|
||||
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
|
||||
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
|
||||
static_cast<const WeightType*>(kargs.p_weights),
|
||||
static_cast<IndexType*>(kargs.p_sorted_token_ids),
|
||||
@@ -1066,6 +1136,8 @@ namespace impl {
|
||||
// [expert, padded_tokens]
|
||||
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_stride(index_t tokens)
|
||||
{
|
||||
// Pad to multiply of 32. This can make sure even if the mesh is in 8bit,
|
||||
// we can still use dwordx4 load/store
|
||||
constexpr index_t chunk = 32;
|
||||
return (tokens + chunk - 1) / chunk * chunk;
|
||||
};
|
||||
@@ -1261,6 +1333,24 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_by
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t BLOCK_SIZE = 256>
|
||||
CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d(
|
||||
void* buf, index_t row, index_t col, index_t elem_bytes, index_t gid, index_t blocks)
|
||||
{
|
||||
const long_index_t total_pixels = static_cast<long_index_t>(row) * col;
|
||||
const long_index_t total_bytes = total_pixels * elem_bytes;
|
||||
const long_index_t total_elems = total_bytes / 16; // always use dwordx4
|
||||
|
||||
using vector_type = ext_vector_t<index_t, 4>;
|
||||
vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(long_index_t i = gid * BLOCK_SIZE + threadIdx.x; i < total_elems; i += blocks * BLOCK_SIZE)
|
||||
{
|
||||
p_buf[i] = zero_;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
|
||||
// TODO: tokens could be from
|
||||
@@ -1292,12 +1382,29 @@ CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_expe
|
||||
}
|
||||
|
||||
// return size in byte
|
||||
CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_, int topk_)
|
||||
// dispatch_policy: 0-automatically pick up kerel. 1-always use single kernel, 2-always use mp
|
||||
// kernel
|
||||
CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_,
|
||||
int num_experts_,
|
||||
int topk_,
|
||||
int dispatch_policy_)
|
||||
{
|
||||
#if 1
|
||||
if(moe_sorting_is_oneshot(tokens_, num_experts_))
|
||||
// return 0;
|
||||
if(dispatch_policy_ == 0)
|
||||
{
|
||||
return 0;
|
||||
if(moe_sorting_is_oneshot(tokens_, num_experts_))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_);
|
||||
}
|
||||
}
|
||||
else if(dispatch_policy_ == 1)
|
||||
{
|
||||
return 0; // always use single kernel
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1308,6 +1415,98 @@ CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem_>
|
||||
struct MoeSortingClearWorkspaceKernel
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
static constexpr index_t BLOCK_SIZE = Problem::BlockSize;
|
||||
static constexpr index_t OCCUPANCY = Problem::Occu;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
|
||||
// used for ws/LDS calculation
|
||||
index_t num_experts;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
index_t mesh_byte_size;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.mesh_byte_size = impl::moe_sorting_mesh_byte_size(h.tokens, h.num_experts, h.topk);
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
|
||||
index_t mesh_stride = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
|
||||
index_t row_size = mesh_stride; // impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
index_t pixels = kargs.num_experts * row_size;
|
||||
index_t total_bytes = pixels * kargs.mesh_byte_size;
|
||||
index_t total_elems = total_bytes / 16; // always use dwordx4
|
||||
|
||||
using vector_type = ext_vector_t<index_t, 4>;
|
||||
vector_type* p_expert_mesh = reinterpret_cast<vector_type*>(kargs.p_expert_mesh);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elems;
|
||||
i += gridDim.x * BLOCK_SIZE)
|
||||
{
|
||||
p_expert_mesh[i] = zero_;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// below kernel is multi-phase implementation for large token and/or expert case
|
||||
|
||||
// write into a buffer to record the token cnt
|
||||
@@ -1435,6 +1634,16 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
else
|
||||
return tokens;
|
||||
}();
|
||||
index_t mesh_stride = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
@@ -1449,12 +1658,11 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
if(static_cast<index_t>(curr_token_id) < tokens)
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
p_expert_mesh[eid * mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
}
|
||||
else
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1479,6 +1687,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
@@ -1488,6 +1697,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
{
|
||||
Kargs k;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
@@ -1511,12 +1721,9 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
{
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
int eid = blockIdx.x;
|
||||
|
||||
int eid = blockIdx.x;
|
||||
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
|
||||
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
|
||||
r_t* p_expert_mesh = reinterpret_cast<r_t*>(
|
||||
reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * kargs.mesh_stride);
|
||||
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
@@ -1524,7 +1731,32 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
|
||||
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
|
||||
|
||||
int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0; // will not use if not LocalToken
|
||||
}
|
||||
}();
|
||||
|
||||
index_t mesh_stride = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
|
||||
r_t* p_expert_mesh = reinterpret_cast<r_t*>(
|
||||
reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * mesh_stride);
|
||||
|
||||
int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
@@ -1538,7 +1770,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
{
|
||||
int position = i * BLOCK_SIZE + threadIdx.x;
|
||||
r_t v{0};
|
||||
if(position < (kargs.mesh_stride / index_pack))
|
||||
if(position < (mesh_stride / index_pack))
|
||||
v = p_expert_mesh[position];
|
||||
index_t local_sum = 0;
|
||||
static_for<0, index_pack, 1>{}(
|
||||
@@ -1835,7 +2067,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_total_tokens_post_pad; // [1]
|
||||
void* p_total_tokens_post_pad; // [2]
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_moe_buf;
|
||||
index_t tokens;
|
||||
@@ -1863,15 +2095,36 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
|
||||
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
k.moe_buf_interm_dim = h.moe_buf_interm_dim;
|
||||
k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
|
||||
#else
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
#endif
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
|
||||
#else
|
||||
// use 1 block to cumsum
|
||||
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
@@ -1888,11 +2141,21 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
{
|
||||
if(blockIdx.x > 0)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
impl::moe_buf_set_zero_kernel_2d<BLOCK_SIZE>(kargs.p_moe_buf,
|
||||
kargs.tokens,
|
||||
kargs.moe_buf_interm_dim,
|
||||
kargs.moe_buf_elem_bytes,
|
||||
blockIdx.x - 1,
|
||||
gridDim.x - 1);
|
||||
return;
|
||||
#else
|
||||
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
|
||||
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes,
|
||||
blockIdx.x - 1);
|
||||
return;
|
||||
#endif
|
||||
}
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
@@ -2223,7 +2486,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_total_tokens_post_pad; // [1]
|
||||
void* p_total_tokens_post_pad; // [2]
|
||||
void* p_sorted_expert_ids;
|
||||
|
||||
void* p_sorted_token_ids;
|
||||
@@ -2235,7 +2498,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
mdiv unit_size_mdiv;
|
||||
mdiv topk_mdiv;
|
||||
long_index_t moe_buf_bytes;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
// NOTE:
|
||||
// moe_buf_* is a 2d ws buffer used for the following fmoe kernel
|
||||
// arranged as row*col, where row=tokens(or local_token), col=interm_dim
|
||||
// we fuse this clearing inside sorting kernel
|
||||
// Besides, we require inter_dim to be multiple of 16 byte(make sure when alloc ws for fmoe)
|
||||
index_t moe_buf_interm_dim; // p_moe_buf interm_dim
|
||||
index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.)
|
||||
#else
|
||||
long_index_t moe_buf_bytes; // byte size of p_moe_buf
|
||||
#endif
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
@@ -2262,16 +2535,37 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
k.moe_buf_interm_dim = h.moe_buf_interm_dim;
|
||||
k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
|
||||
#else
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
#endif
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
|
||||
#else
|
||||
// use 1 block to cumsum
|
||||
// return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
@@ -2287,13 +2581,34 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
// reduce single pixel within a wave
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
|
||||
if(static_cast<index_t>(blockIdx.x) >= kargs.num_experts)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
impl::moe_buf_set_zero_kernel_2d<BLOCK_SIZE>(kargs.p_moe_buf,
|
||||
tokens,
|
||||
kargs.moe_buf_interm_dim,
|
||||
kargs.moe_buf_elem_bytes,
|
||||
blockIdx.x - kargs.num_experts,
|
||||
gridDim.x - kargs.num_experts);
|
||||
return;
|
||||
#else
|
||||
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
|
||||
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes,
|
||||
blockIdx.x - kargs.num_experts);
|
||||
return;
|
||||
#endif
|
||||
}
|
||||
|
||||
extern __shared__ char smem[];
|
||||
@@ -2428,13 +2743,15 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
{
|
||||
auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor;
|
||||
if(blockIdx.x == 0)
|
||||
{
|
||||
p_total_tokens_post_pad[0] = total_tokens_post_pad;
|
||||
p_total_tokens_post_pad[1] = tokens;
|
||||
}
|
||||
p_expert_cumsum_smem[kargs.num_experts] = total_tokens_post_pad;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
{
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
@@ -2463,14 +2780,14 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
return; // skip empty expert
|
||||
}
|
||||
|
||||
index_t tokens = [&]() {
|
||||
index_t mesh_stride = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
return impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -2478,7 +2795,8 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
|
||||
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
|
||||
using d_t = ext_vector_t<index_t, index_pack>;
|
||||
int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
int prev_cumsum = 0;
|
||||
|
||||
for(int i = 0; i < loops; i++)
|
||||
@@ -2487,8 +2805,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
r_t x_v = 0;
|
||||
if(i_token_pack < (tokens + index_pack - 1) / index_pack)
|
||||
{
|
||||
x_v = reinterpret_cast<r_t*>(p_expert_mesh +
|
||||
eid * kargs.mesh_stride)[i_token_pack];
|
||||
x_v = reinterpret_cast<r_t*>(p_expert_mesh + eid * mesh_stride)[i_token_pack];
|
||||
}
|
||||
|
||||
r_t x_r;
|
||||
|
||||
@@ -73,4 +73,12 @@ struct MoeSortingProblemMp
|
||||
SubTokenTile == 8 || SubTokenTile == 16);
|
||||
};
|
||||
|
||||
template <bool LocalToken_, index_t BlockSize_ = 1024, index_t Occu_ = 1>
|
||||
struct MoeSortingClearWorkspaceProblem
|
||||
{
|
||||
static constexpr bool LocalToken = LocalToken_;
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
static constexpr index_t Occu = Occu_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user