add sorting params

This commit is contained in:
coderfeli
2025-03-17 07:44:49 +00:00
parent 7dbdff9f9f
commit 04dc9908d0
7 changed files with 10 additions and 15 deletions

View File

@@ -175,7 +175,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
unit_size,
num_experts,
topk,
static_cast<ck_tile::index_t>(moe_buf_size * sizeof(float))};
static_cast<ck_tile::index_t>(moe_buf_size * sizeof(float)),
false};
ck_tile::stream_config sc{nullptr,
true,

View File

@@ -153,7 +153,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
}
}
#else
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0)
if(!a.force_one_shoot && moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0)
{
return moe_sorting_mp(t, a, s);
}

View File

@@ -36,6 +36,7 @@ struct fused_moe_args
ck_tile::index_t topk; // need this?
ck_tile::index_t stride_token; // for input/output, stride for each row, should >= hidden_size
bool force_one_shoot;
};
// This is the public API, will be generated by script

View File

@@ -32,7 +32,8 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
a.block_m, // index_t unit_size;
a.num_experts, // index_t num_experts;
a.topk, // index_t topk;
static_cast<ck::long_index_t>(a.num_tokens) * a.stride_token * o_data_bytes // index_t moe_buf_bytes;
static_cast<ck::long_index_t>(a.num_tokens) * a.stride_token * o_data_bytes, // index_t moe_buf_bytes;
a.force_one_shoot
};
auto t1 = fused_moegemm_traits{t.prec_i,

View File

@@ -413,7 +413,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
tokens,
experts,
topk,
stride};
stride,
false};
float ave_time = fused_moe(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});

View File

@@ -1530,18 +1530,8 @@ struct GridwiseMoeGemm
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op};
// using BufferType = std::conditional_t<
// std::is_same_v<IndexType, long_index_t>,
// decltype(make_long_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize())),
// decltype(make_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()))
// >;
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// BufferType c_grid_buf = std::is_same_v<IndexType, long_index_t> ?
// make_long_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()) :
// make_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,

View File

@@ -193,6 +193,7 @@ struct MoeSortingHostArgs
index_t num_experts;
index_t topk;
long_index_t moe_buf_bytes; // byte size of p_moe_buf
bool force_one_shoot;
};
template <typename Problem_>