mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[CK_TILE] moe sorting optimize local_token (#2469)
* fix bug in loops that need use local tokens to compute * support extra chain local_token * update * update * refine some main * update * support dispatch_policy * fix 15 example
This commit is contained in:
@@ -35,7 +35,20 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("e", "8", "number of num_experts")
|
||||
.insert("k", "4", "topk")
|
||||
.insert("unit", "32", "unit_size")
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
.insert("moe_buf_interm_dim", "0", "interm_dim(col) of the following fmoe buf")
|
||||
.insert(
|
||||
"moe_buf_elem_bytes", "2", "fmoe buf element byte size, 1:8bit, 2:16bit, 4:32bit...")
|
||||
#else
|
||||
.insert("moe_buf_size", "0", "moe_buf_size")
|
||||
#endif
|
||||
.insert("ci",
|
||||
"1",
|
||||
"clear workspace inside API or not(if \"0\", require manually clear outside)")
|
||||
.insert(
|
||||
"dispatch",
|
||||
"0",
|
||||
"dispatch policy. 0:automatically pick up kernel, 1:use single kernel, 2:use mp kernel")
|
||||
.insert("local_eid",
|
||||
"-1",
|
||||
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
|
||||
@@ -88,10 +101,17 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
int topk = args.get_int("k");
|
||||
int seed = args.get_int("seed");
|
||||
int unit_size = args.get_int("unit");
|
||||
int64_t moe_buf_size = static_cast<int64_t>(args.get_uint64("moe_buf_size"));
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
int moe_buf_interm_dim = args.get_int("moe_buf_interm_dim");
|
||||
int moe_buf_elem_bytes = args.get_int("moe_buf_elem_bytes");
|
||||
#else
|
||||
int64_t moe_buf_size = static_cast<int64_t>(args.get_uint64("moe_buf_size"));
|
||||
#endif
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
bool clear_inside = args.get_int("ci") != 0;
|
||||
int dispatch_policy = args.get_int("dispatch");
|
||||
|
||||
int max_output_ids =
|
||||
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
|
||||
@@ -149,11 +169,26 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
|
||||
// for simplicity, below buffer allocate 2 dword
|
||||
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({2}, {1});
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
ck_tile::HostTensor<int8_t> moe_buf_host(
|
||||
{static_cast<std::size_t>(is_local_token ? local_tokens : tokens) * moe_buf_interm_dim *
|
||||
moe_buf_elem_bytes});
|
||||
auto moe_buf_bytes = moe_buf_interm_dim == 0 ? static_cast<std::size_t>(0)
|
||||
: moe_buf_host.get_element_space_size_in_bytes();
|
||||
#else
|
||||
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
|
||||
auto moe_buf_bytes = moe_buf_size == 0 ? static_cast<std::size_t>(0)
|
||||
: moe_buf_host.get_element_space_size_in_bytes();
|
||||
#endif
|
||||
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
ck_tile::FillUniformDistribution<int8_t>{-.5f, .5f}(moe_buf_host);
|
||||
#else
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
|
||||
#endif
|
||||
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
|
||||
|
||||
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
|
||||
@@ -176,7 +211,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
|
||||
topk_ids_dev.ToDevice(topk_ids_host.data());
|
||||
weights_dev.ToDevice(weights_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
if(moe_buf_bytes > 0)
|
||||
{
|
||||
moe_buf_dev.ToDevice(moe_buf_host.data());
|
||||
}
|
||||
@@ -184,29 +219,31 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
ck_tile::index_t workspace_size =
|
||||
moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
if(workspace_size != 0 && clear_inside == false)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};
|
||||
moe_sorting_trait trait{
|
||||
index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy};
|
||||
|
||||
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
|
||||
weights_dev.GetDeviceBuffer(),
|
||||
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
sorted_ids_dev.GetDeviceBuffer(),
|
||||
sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
sorted_id_cnt_dev.GetDeviceBuffer(),
|
||||
moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
tokens,
|
||||
unit_size,
|
||||
num_experts,
|
||||
topk,
|
||||
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))};
|
||||
moe_sorting_args karg
|
||||
{
|
||||
topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(),
|
||||
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() : nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(),
|
||||
moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, tokens, unit_size,
|
||||
num_experts, topk,
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
moe_buf_interm_dim, moe_buf_elem_bytes
|
||||
#else
|
||||
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))
|
||||
#endif
|
||||
};
|
||||
|
||||
ck_tile::stream_config sc{nullptr,
|
||||
true,
|
||||
@@ -219,7 +256,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
|
||||
#if 0
|
||||
{
|
||||
ck_tile::HostTensor<char> ws_host({workspace_size}, {1});
|
||||
ck_tile::HostTensor<char> ws_host({workspace_size}, {1});
|
||||
moe_sorting_ws.FromDevice(ws_host.data());
|
||||
|
||||
int * p_mesh = reinterpret_cast<int*>(ws_host.data());
|
||||
@@ -268,7 +305,12 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
}
|
||||
#endif
|
||||
|
||||
printf("[%s|%s]tokens:%d", index_prec.c_str(), weight_prec.c_str(), tokens);
|
||||
printf("[%s|%s|%s|%d]tokens:%d",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"),
|
||||
dispatch_policy,
|
||||
tokens);
|
||||
if(is_local_token)
|
||||
{
|
||||
printf("(%d)", local_tokens);
|
||||
@@ -280,6 +322,19 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
printf("local_eid:%s, ", args.get_str("local_eid").c_str());
|
||||
}
|
||||
|
||||
if(moe_buf_bytes > 0)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
printf("moe_buf:%lu(%d,%d), ",
|
||||
static_cast<uint64_t>(moe_buf_bytes),
|
||||
moe_buf_interm_dim,
|
||||
moe_buf_elem_bytes);
|
||||
#else
|
||||
|
||||
printf("moe_buf:%lu, ", static_cast<uint64_t>(moe_buf_bytes));
|
||||
#endif
|
||||
}
|
||||
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
else
|
||||
@@ -294,7 +349,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
sorted_weights_dev.FromDevice(sorted_weights_host.data());
|
||||
sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
|
||||
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
if(moe_buf_bytes > 0)
|
||||
{
|
||||
moe_buf_dev.FromDevice(moe_buf_host.data());
|
||||
}
|
||||
@@ -340,6 +395,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
std::string("OUT Error: Incorrect eid!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
// if(is_local_token)
|
||||
{
|
||||
auto t_ = is_local_token ? local_tokens : tokens;
|
||||
bool _f = t_ == sorted_id_cnt_host.mData[1];
|
||||
rtn &= _f;
|
||||
if(!_f)
|
||||
{
|
||||
printf("not equal token buffer pad %d(%d)\n", t_, sorted_id_cnt_host.mData[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -347,9 +412,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
rtn = false;
|
||||
}
|
||||
|
||||
if(moe_buf_size)
|
||||
if(moe_buf_bytes)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
ck_tile::HostTensor<int8_t> moe_buf_ref({moe_buf_bytes});
|
||||
#else
|
||||
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
|
||||
#endif
|
||||
rtn &= ck_tile::check_err(
|
||||
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
|
||||
}
|
||||
|
||||
@@ -175,7 +175,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
}
|
||||
}
|
||||
#else
|
||||
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0)
|
||||
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk, t.dispatch_policy) != 0)
|
||||
{
|
||||
return moe_sorting_mp(t, a, s);
|
||||
}
|
||||
@@ -293,6 +293,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
@@ -302,6 +303,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
@@ -314,6 +316,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
@@ -323,6 +326,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
@@ -330,6 +334,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
} \
|
||||
}
|
||||
|
||||
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
|
||||
[&]() { \
|
||||
using problem_ = \
|
||||
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
|
||||
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
@@ -338,6 +353,22 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
|
||||
auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) {
|
||||
if(t.clear_workspace_inside_api)
|
||||
{
|
||||
if(is_local_token)
|
||||
{
|
||||
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1);
|
||||
k(s_);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1);
|
||||
k(s_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
|
||||
ck_tile::get_smem_capacity())
|
||||
{
|
||||
@@ -345,6 +376,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
@@ -354,6 +386,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
else
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
@@ -405,7 +438,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
return -1;
|
||||
}
|
||||
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
|
||||
}
|
||||
|
||||
@@ -10,8 +10,14 @@
|
||||
struct moe_sorting_trait
|
||||
{
|
||||
std::string index_type;
|
||||
std::string weight_type; // currently always float
|
||||
bool local_expert_masking; // if mask experts as local expert
|
||||
std::string weight_type; // currently always float
|
||||
bool local_expert_masking; // if mask experts as local expert
|
||||
bool clear_workspace_inside_api; // if true, no need clear workspace outsize (will take care of
|
||||
// it inside API)
|
||||
int dispatch_policy; // 0 - let the API choose kernel for you. 1 - always use single kerenl. 2 -
|
||||
// always use mp kernel NOTE: moe_sorting_get_workspace_size() need use
|
||||
// same dispatch_policy value. it will be undefined behavior if ppl using
|
||||
// different value when get ws and call the kernel
|
||||
};
|
||||
|
||||
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
@@ -22,6 +28,6 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
// if return non zero, means need workspace, you need to allocate a GPU buffer
|
||||
// and set to moe_sorting_args.p_ws
|
||||
// NOTE: workspace size are required to clear zero before use the API
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk);
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy);
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# #!/bin/sh
|
||||
|
||||
EXE=./build/bin/tile_example_moe_sorting
|
||||
MOE_BUF="12"
|
||||
|
||||
if [ "x$MOE_BUF" = "x1" ] ; then
|
||||
$EXE -t=80 -e=17 -moe_buf_size=16
|
||||
$EXE -t=111 -e=117 -moe_buf_size=4
|
||||
$EXE -t=1000 -e=55 -moe_buf_size=1024
|
||||
@@ -42,3 +44,46 @@ $EXE -t=23 -local_t=9 -e=1 -k=1
|
||||
$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33
|
||||
$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33
|
||||
$EXE -t=133940 -local_t=111921 -e=256 -k=17 -moe_buf_size=133940
|
||||
else
|
||||
$EXE -t=80 -e=17 -moe_buf_interm_dim=16 -moe_buf_elem_bytes=4
|
||||
$EXE -t=111 -e=117 -moe_buf_interm_dim=4 -moe_buf_elem_bytes=4
|
||||
$EXE -t=1000 -e=55 -moe_buf_interm_dim=1024 -moe_buf_elem_bytes=1
|
||||
$EXE -t=99 -e=120 -moe_buf_interm_dim=10244 -moe_buf_elem_bytes=2
|
||||
$EXE -t=175 -e=64 -k=8
|
||||
$EXE -t=65 -e=8 -k=2
|
||||
$EXE -t=1 -e=25
|
||||
$EXE -t=31 -e=19 -k=15
|
||||
$EXE -t=81 -e=37 -k=7
|
||||
$EXE -t=23 -e=1 -k=1
|
||||
$EXE -t=127 -e=99 -k=19
|
||||
$EXE -t=71 -e=11 -k=11
|
||||
$EXE -t=1 -e=1 -k=1
|
||||
$EXE -t=99 -e=2 -k=1
|
||||
$EXE -t=333 -e=99 -k=13
|
||||
$EXE -t=11 -e=256 -k=5
|
||||
$EXE -t=64 -e=455 -k=8
|
||||
$EXE -t=777 -e=802 -k=99
|
||||
$EXE -t=4097 -e=906 -k=51
|
||||
$EXE -t=128 -e=32 -k=5 -local_t=6 -moe_buf_interm_dim=262144
|
||||
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
|
||||
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
|
||||
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
|
||||
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
|
||||
$EXE -t=128 -e=128 -k=6 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=1
|
||||
$EXE -t=8192 -e=32 -k=5 -local_t=11 -moe_buf_interm_dim=163840
|
||||
$EXE -t=8192 -e=32 -k=8 -local_t=12 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=1
|
||||
$EXE -t=8192 -e=256 -k=5 -local_t=13 -moe_buf_interm_dim=163840
|
||||
$EXE -t=8192 -e=256 -k=8 -local_t=8 -moe_buf_interm_dim=163840
|
||||
$EXE -t=163840 -e=256 -k=8 -local_t=4 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=4
|
||||
$EXE -t=12 -local_t=3 -e=256 -k=5 -local_eid=9,10,199,145
|
||||
$EXE -t=67 -local_t=9 -e=555 -k=5 -local_eid=19,23,24,25,26,99
|
||||
$EXE -t=99 -local_t=93 -e=121 -local_t=4 -moe_buf_interm_dim=10244
|
||||
$EXE -t=536 -local_t=345 -e=802 -k=99
|
||||
$EXE -t=331 -local_t=39 -e=83 -k=33
|
||||
$EXE -t=765 -local_t=654 -e=783 -k=8
|
||||
$EXE -t=23 -local_t=9 -e=1 -k=1
|
||||
$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33
|
||||
$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33
|
||||
$EXE -t=133940 -local_t=111921 -e=256 -k=17 -local_t=2 -moe_buf_interm_dim=133940 -moe_buf_elem_bytes=1
|
||||
|
||||
fi
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
|
||||
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
return ck_tile::moe_sorting_get_workspace_size(
|
||||
tokens, num_experts, topk, 0 /*dispatch policy*/);
|
||||
}
|
||||
|
||||
float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
|
||||
@@ -24,23 +25,28 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
}();
|
||||
|
||||
auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
|
||||
auto a0 = fused_moesorting_args{
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.local_tokens,
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.ws_ptr, // void* p_ws;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
static_cast<ck_tile::long_index_t>(a.num_tokens) * a.stride_token *
|
||||
o_data_bytes // index_t moe_buf_bytes;
|
||||
auto a0 = fused_moesorting_args
|
||||
{
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.local_tokens,
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.ws_ptr, // void* p_ws;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
a.stride_token, o_data_bytes,
|
||||
#else
|
||||
static_cast<ck_tile::long_index_t>(a.num_tokens) *
|
||||
a.stride_token* o_data_bytes // index_t moe_buf_bytes;
|
||||
#endif
|
||||
};
|
||||
|
||||
auto t1 = fused_moegemm_traits{t.prec_i,
|
||||
|
||||
@@ -413,5 +413,6 @@ float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
|
||||
int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
return ck_tile::moe_sorting_get_workspace_size(
|
||||
tokens, num_experts, topk, 0 /*dispatch policy*/);
|
||||
}
|
||||
|
||||
@@ -399,7 +399,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size =
|
||||
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk);
|
||||
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk, 0 /*dispatch_policy*/);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
Reference in New Issue
Block a user