diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index f00d948f25..683f41228d 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -175,7 +175,8 @@ bool test_moe_sorting(ck_tile::ArgParser args) unit_size, num_experts, topk, - static_cast(moe_buf_size * sizeof(float))}; + static_cast(moe_buf_size * sizeof(float)), + false}; ck_tile::stream_config sc{nullptr, true, diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 109ec1b157..34731576b7 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -153,7 +153,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi } } #else - if(moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0) + if(!a.force_one_shoot && moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0) { return moe_sorting_mp(t, a, s); } diff --git a/example/ck_tile/15_fused_moe/fused_moe.hpp b/example/ck_tile/15_fused_moe/fused_moe.hpp index b354d1d347..384d43590e 100644 --- a/example/ck_tile/15_fused_moe/fused_moe.hpp +++ b/example/ck_tile/15_fused_moe/fused_moe.hpp @@ -36,6 +36,7 @@ struct fused_moe_args ck_tile::index_t topk; // need this? ck_tile::index_t stride_token; // for input/output, stride for each row, should >= hidden_size + bool force_one_shoot; }; // This is the public API, will be generated by script diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index f706cf19ff..c9fc1b1222 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -32,7 +32,8 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf a.block_m, // index_t unit_size; a.num_experts, // index_t num_experts; a.topk, // index_t topk; - static_cast(a.num_tokens) * a.stride_token * o_data_bytes // index_t moe_buf_bytes; + static_cast(a.num_tokens) * a.stride_token * o_data_bytes, // index_t moe_buf_bytes; + a.force_one_shoot }; auto t1 = fused_moegemm_traits{t.prec_i, diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index cb93ce8907..fcfc791313 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -413,7 +413,8 @@ bool run(const ck_tile::ArgParser& arg_parser) tokens, experts, topk, - stride}; + stride, + false}; float ave_time = fused_moe( traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 56297c2340..b8c8d4c09e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1530,18 +1530,8 @@ struct GridwiseMoeGemm tie(e_grid_desc_mblock_mperblock_nblock_nperblock), make_tuple(make_multi_index(0, 0, block_n_id, 0)), c_element_op}; - - // using BufferType = std::conditional_t< - // std::is_same_v, - // decltype(make_long_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize())), - // decltype(make_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize())) - // >; + auto c_grid_buf = make_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - // BufferType c_grid_buf = std::is_same_v ? - // make_long_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()) : - // make_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 56aadb9861..f18fbde766 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -193,6 +193,7 @@ struct MoeSortingHostArgs index_t num_experts; index_t topk; long_index_t moe_buf_bytes; // byte size of p_moe_buf + bool force_one_shoot; }; template