mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
add sorting params
This commit is contained in:
@@ -175,7 +175,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
unit_size,
|
||||
num_experts,
|
||||
topk,
|
||||
static_cast<ck_tile::index_t>(moe_buf_size * sizeof(float))};
|
||||
static_cast<ck_tile::index_t>(moe_buf_size * sizeof(float)),
|
||||
false};
|
||||
|
||||
ck_tile::stream_config sc{nullptr,
|
||||
true,
|
||||
|
||||
@@ -153,7 +153,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
}
|
||||
}
|
||||
#else
|
||||
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0)
|
||||
if(!a.force_one_shoot && moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0)
|
||||
{
|
||||
return moe_sorting_mp(t, a, s);
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ struct fused_moe_args
|
||||
ck_tile::index_t topk; // need this?
|
||||
|
||||
ck_tile::index_t stride_token; // for input/output: stride of each row, should be >= hidden_size
|
||||
bool force_one_shoot;
|
||||
};
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
|
||||
@@ -32,7 +32,8 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
static_cast<ck::long_index_t>(a.num_tokens) * a.stride_token * o_data_bytes // index_t moe_buf_bytes;
|
||||
static_cast<ck::long_index_t>(a.num_tokens) * a.stride_token * o_data_bytes, // long_index_t moe_buf_bytes;
|
||||
a.force_one_shoot
|
||||
};
|
||||
|
||||
auto t1 = fused_moegemm_traits{t.prec_i,
|
||||
|
||||
@@ -413,7 +413,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride};
|
||||
stride,
|
||||
false};
|
||||
float ave_time = fused_moe(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
|
||||
@@ -1530,18 +1530,8 @@ struct GridwiseMoeGemm
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
|
||||
// using BufferType = std::conditional_t<
|
||||
// std::is_same_v<IndexType, long_index_t>,
|
||||
// decltype(make_long_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize())),
|
||||
// decltype(make_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()))
|
||||
// >;
|
||||
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
// BufferType c_grid_buf = std::is_same_v<IndexType, long_index_t> ?
|
||||
// make_long_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()) :
|
||||
// make_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
|
||||
@@ -193,6 +193,7 @@ struct MoeSortingHostArgs
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
long_index_t moe_buf_bytes; // byte size of p_moe_buf
|
||||
bool force_one_shoot;
|
||||
};
|
||||
|
||||
template <typename Problem_>
|
||||
|
||||
Reference in New Issue
Block a user