From eec0fed606827cd53c99df96c833c96fd9f935d6 Mon Sep 17 00:00:00 2001 From: dummycoderfe Date: Sat, 9 Nov 2024 17:57:27 +0800 Subject: [PATCH] Ck tile/moe sorting (#1624) * add moe_sorting & check ok * fix comments & typo * Run remod.py under include/ck_tile & example/ck_tile directories * format codes * fix output ci check bug * fix moe sorting readme and error commit file * use magiv div to accelerate compute * add an loop unroll for moe lds ops * add extblocksnel to set zeros for moebufs * [Ck_tile] moe set zero run ok, add size check and fix ref check * [Ck_tile]fix moe_sorting fuse set_zero remod * [Ck_tile] change name style, fix zero buffer size err, change folder * [Ck_tile] moe_sorting: fix name style * [Ck_tile] moe_sorting, remove useless params in traits * [Ck_tile] change outputtile cnt * unit_size; change output buf alloc --------- Co-authored-by: dummycoderfe Co-authored-by: Po Yen, Chen Co-authored-by: carlushuang [ROCm/composable_kernel commit: bec6fbc65fe766ab23fe563675703defdb0dd2be] --- example/ck_tile/13_moe_sorting/CMakeLists.txt | 8 + example/ck_tile/13_moe_sorting/README.md | 27 ++ .../ck_tile/13_moe_sorting/moe_sorting.cpp | 223 +++++++++++++++++ .../13_moe_sorting/moe_sorting_api.cpp | 73 ++++++ .../13_moe_sorting/moe_sorting_api.hpp | 20 ++ .../13_moe_sorting/script/smoke_test.sh | 19 ++ example/ck_tile/CMakeLists.txt | 1 + include/ck_tile/host.hpp | 1 + .../host/reference/reference_moe_sorting.hpp | 78 ++++++ .../fused_moe/kernel/moe_sorting_kernel.hpp | 232 ++++++++++++++++++ .../pipeline/moe_sorting_pipeline.hpp | 39 +++ .../fused_moe/pipeline/moe_sorting_policy.hpp | 15 ++ .../pipeline/moe_sorting_problem.hpp | 23 ++ include/ck_tile/ops/moe_sorting.hpp | 11 + 14 files changed, 770 insertions(+) create mode 100644 example/ck_tile/13_moe_sorting/CMakeLists.txt create mode 100644 example/ck_tile/13_moe_sorting/README.md create mode 100644 example/ck_tile/13_moe_sorting/moe_sorting.cpp create mode 100644 
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp create mode 100644 example/ck_tile/13_moe_sorting/moe_sorting_api.hpp create mode 100644 example/ck_tile/13_moe_sorting/script/smoke_test.sh create mode 100644 include/ck_tile/host/reference/reference_moe_sorting.hpp create mode 100644 include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp create mode 100644 include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp create mode 100644 include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp create mode 100644 include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp create mode 100644 include/ck_tile/ops/moe_sorting.hpp diff --git a/example/ck_tile/13_moe_sorting/CMakeLists.txt b/example/ck_tile/13_moe_sorting/CMakeLists.txt new file mode 100644 index 0000000000..09f3e4ac4e --- /dev/null +++ b/example/ck_tile/13_moe_sorting/CMakeLists.txt @@ -0,0 +1,8 @@ +add_executable(tile_example_moe_sorting EXCLUDE_FROM_ALL moe_sorting.cpp moe_sorting_api.cpp) +target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) + +set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS) +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +# list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) +target_compile_options(tile_example_moe_sorting PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS}) diff --git a/example/ck_tile/13_moe_sorting/README.md b/example/ck_tile/13_moe_sorting/README.md new file mode 100644 index 0000000000..7b6792dd95 --- /dev/null +++ b/example/ck_tile/13_moe_sorting/README.md @@ -0,0 +1,27 @@ +# moe-sorting + +This folder contains example for moe-sorting kernel using ck_tile tile-programming implementation. This kernel is often used in Moe model, before launching the fused-moe-gemm block. The input&weight is a `token*topk` 2d matrix. 
The op rearranges the input weight ids into different experts and feeds them into the fused-moe-gemm kernel. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_example_moe_sorting -j +``` +This will result in an executable `build/bin/tile_example_moe_sorting` + +## example +``` +args: + -v whether to do CPU validation or not (default:1) + -pr_i index data type. (currently only int32 supported now) (default:int32) + -pr_w output weight data type(currently only fp32 supported now) (default:fp32) + -t number of input tokens (default:128) + -e number of experts (default:8) + -k topk (default:4) + -st_i row stride of input, -1 means same as experts (default:-1) + -seed seed to be used, -1 means random every time (default:-1) + -kname when set to 1 it will print kernel name (default:0) + +``` diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp new file mode 100644 index 0000000000..d2c4df1058 --- /dev/null +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "moe_sorting_api.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("pr_i", "int32", "index data type. 
(currently only int32 supported now)") + .insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)") + .insert("t", "128", "number of input tokens") + .insert("e", "8", "number of num_experts") + .insert("k", "4", "topk") + .insert("unit", "32", "unit_size") + .insert("moe_buf_size", "0", "moe_buf_size") + .insert("seed", "-1", "seed to be used, -1 means random every time") + .insert("kname", "0", "when set to 1 it will print kernel name") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +void topid_unique_gen( + std::vector& host_tensor, int tokens, int topk, int num_expert, int seed) +{ + size_t total_size = topk * tokens; + std::srand(seed); + std::set unique_set; + IndexType current_v; + for(size_t i = 0; i < total_size; i++) + { + if(i % topk == 0) + { + unique_set.clear(); + } + current_v = std::rand() % num_expert; + while(unique_set.find(current_v) != unique_set.end()) + { + current_v = std::rand() % num_expert; + } + unique_set.insert(current_v); + host_tensor[i] = current_v; + } +} + +template +bool test_moe_sorting(ck_tile::ArgParser args) +{ + int validate = args.get_int("v"); + std::string index_prec = args.get_str("pr_i"); + std::string weight_prec = args.get_str("pr_w"); + int tokens = args.get_int("t"); + int num_experts = args.get_int("e"); + int topk = args.get_int("k"); + int seed = args.get_int("seed"); + int unit_size = args.get_int("unit"); + int moe_buf_size = args.get_int("moe_buf_size"); + int kname = args.get_int("kname"); + int warmup = args.get_int("warmup"); + int repeat = args.get_int("repeat"); + int max_output_ids = + ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size); + + if(seed < 0) + { + seed = std::time(nullptr); + } + + if(topk > num_experts) + { + 
printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n", + topk, + num_experts); + return false; + } + + // tokens already considered batch size + ck_tile::HostTensor topk_ids_host({tokens, topk}, {topk, 1}); + ck_tile::HostTensor weights_host({tokens, topk}, {topk, 1}); + ck_tile::HostTensor sorted_ids_host({max_output_ids}, {1}); + ck_tile::HostTensor sorted_weights_host({max_output_ids}, {1}); + ck_tile::HostTensor sorted_expert_ids_host({max_output_ids / unit_size}, {1}); + ck_tile::HostTensor sorted_id_cnt_host({1}, {1}); + ck_tile::HostTensor moe_buf_host({moe_buf_size}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(weights_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(moe_buf_host); + topid_unique_gen(topk_ids_host.mData, tokens, topk, num_experts, seed); + + ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_expert_ids_dev( + sorted_expert_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes()); + + topk_ids_dev.ToDevice(topk_ids_host.data()); + weights_dev.ToDevice(weights_host.data()); + if(moe_buf_size > 0) + { + moe_buf_dev.ToDevice(moe_buf_host.data()); + } + + moe_sorting_trait trait{index_prec, weight_prec}; + + moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), + weights_dev.GetDeviceBuffer(), + sorted_ids_dev.GetDeviceBuffer(), + sorted_weights_dev.GetDeviceBuffer(), + sorted_expert_ids_dev.GetDeviceBuffer(), + sorted_id_cnt_dev.GetDeviceBuffer(), + moe_buf_size > 0 ? 
moe_buf_dev.GetDeviceBuffer() : nullptr, + tokens, + unit_size, + num_experts, + topk, + static_cast(moe_buf_size * sizeof(float))}; + + ck_tile::stream_config sc{nullptr, + true, + /* log_level = */ (kname ? 1 : 0), + warmup, + repeat}; + auto ms = moe_sorting(trait, karg, sc); + printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ", + index_prec.c_str(), + weight_prec.c_str(), + tokens, + num_experts, + topk, + ms); + if(ms < 0) + printf("not supported\n"); + fflush(stdout); + if(ms < 0) + { + return false; + } + + sorted_ids_dev.FromDevice(sorted_ids_host.data()); + sorted_weights_dev.FromDevice(sorted_weights_host.data()); + sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data()); + sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data()); + if(moe_buf_size > 0) + { + moe_buf_dev.FromDevice(moe_buf_host.data()); + } + + bool rtn = true; + if(validate) + { + ck_tile::HostTensor sorted_ids_ref({max_output_ids}, {1}); + ck_tile::HostTensor sorted_weights_ref({max_output_ids}, {1}); + ck_tile::HostTensor sorted_expert_ids_ref({max_output_ids / unit_size}, {1}); + + int32_t ref_total_tokens_post_pad = 0; + ck_tile::reference_moe_sorting(topk_ids_host, + weights_host, + sorted_ids_ref, + sorted_weights_ref, + sorted_expert_ids_ref, + ref_total_tokens_post_pad, + num_experts, + unit_size); + rtn &= ck_tile::check_err( + sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6); + rtn &= ck_tile::check_err(sorted_weights_host, + sorted_weights_ref, + std::string("OUT Error: Incorrect w!"), + 1e-6, + 1e-6); + rtn &= ck_tile::check_err(sorted_expert_ids_host, + sorted_expert_ids_ref, + std::string("OUT Error: Incorrect eid!"), + 1e-6, + 1e-6); + if(moe_buf_size) + { + ck_tile::HostTensor moe_buf_ref({moe_buf_size}); + rtn &= ck_tile::check_err( + moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0); + } + rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]; + } + + printf("valid:%s\n", rtn 
? "y" : "n"); + fflush(stdout); + return rtn; +} + +int main(int argc, char** argv) +{ + auto [result, args] = create_args(argc, argv); + if(!result) + return -1; + std::string index_prec = args.get_str("pr_i"); + std::string weight_prec = args.get_str("pr_w"); + + bool r = true; + if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0) + { + r &= test_moe_sorting(args); + } + return r ? 0 : -1; +} diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp new file mode 100644 index 0000000000..25e99c5306 --- /dev/null +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "moe_sorting_api.hpp" + +#define MOE_SORTING_DISPATCH(unroll_num_) \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + using ms_problem = ck_tile::MoeSortingProblem; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) +{ + if(t.weight_type == "fp32" && t.index_type == "int32") + { + if(a.num_experts >= 127) // 127 experts need >64KiB LDS per GetSmemSize (4-byte index_t) + { + printf("lds size exceed, only support experts <127 \n"); + return -1; + } + if(a.moe_buf_bytes % 16) + { + printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes); + return -1; + } + using index_t = ck_tile::index_t; + using ms_weight_type = float; + index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64); + switch(smem_io_unroll_num) + { + case(1): { + MOE_SORTING_DISPATCH(1); + } + case(2): { + MOE_SORTING_DISPATCH(2); + } + 
case(3): { + MOE_SORTING_DISPATCH(3); + } + case(5): { + MOE_SORTING_DISPATCH(5); + } + case(6): { + MOE_SORTING_DISPATCH(6); + } + case(7): { + MOE_SORTING_DISPATCH(7); + } + case(8): { + MOE_SORTING_DISPATCH(8); + } + case(9): { + MOE_SORTING_DISPATCH(9); + } + case(10): { + MOE_SORTING_DISPATCH(10); + } + case(11): { + MOE_SORTING_DISPATCH(11); + } + default: { + MOE_SORTING_DISPATCH(4); + } + } + } + return -1; +} diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp new file mode 100644 index 0000000000..91b54932ce --- /dev/null +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/moe_sorting.hpp" + +struct moe_sorting_trait +{ + std::string index_type; + std::string weight_type; // currently always float +}; + +struct moe_sorting_args : public ck_tile::MoeSortingHostArgs +{ +}; + +float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh new file mode 100644 index 0000000000..1fc5eafcb0 --- /dev/null +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -0,0 +1,19 @@ +# #!/bin/sh + +EXE=./build/bin/tile_example_moe_sorting + +$EXE -t=80 -e=17 -moe_buf_size=16 +$EXE -t=111 -e=117 -moe_buf_size=4 +$EXE -t=1000 -e=55 -moe_buf_size=1024 +$EXE -t=99 -e=120 -moe_buf_size=10244 +$EXE -t=175 -e=64 -k=8 +$EXE -t=65 -e=8 -k=2 +$EXE -t=1 -e=25 +$EXE -t=31 -e=19 -k=15 +$EXE -t=81 -e=37 -k=7 +$EXE -t=23 -e=1 -k=1 +$EXE -t=127 -e=99 -k=19 +$EXE -t=71 -e=11 -k=11 +$EXE -t=1 -e=1 -k=1 +$EXE -t=99 -e=2 -k=1 +$EXE -t=333 -e=99 -k=13 \ No newline at end of file diff --git a/example/ck_tile/CMakeLists.txt 
b/example/ck_tile/CMakeLists.txt index 9dd9a6ca3c..15db0f46c4 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -12,3 +12,4 @@ add_subdirectory(09_topk_softmax) add_subdirectory(10_rmsnorm2d) add_subdirectory(11_add_rmsnorm2d_rdquant) add_subdirectory(12_smoothquant) +add_subdirectory(13_moe_sorting) diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index c0ab13ce3d..2e96009ace 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -23,6 +23,7 @@ #include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" +#include "ck_tile/host/reference/reference_moe_sorting.hpp" #include "ck_tile/host/reference/reference_permute.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp" diff --git a/include/ck_tile/host/reference/reference_moe_sorting.hpp b/include/ck_tile/host/reference/reference_moe_sorting.hpp new file mode 100644 index 0000000000..c8eb7edb55 --- /dev/null +++ b/include/ck_tile/host/reference/reference_moe_sorting.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, + const HostTensor& weights, + HostTensor& p_sorted_token_ids, + HostTensor& sorted_weight, + HostTensor& sorted_expert_ids, + index_t& unit_cnt, + const index_t experts, + const index_t unit_size) +{ + const index_t num_token = topk_ids.mDesc.get_lengths()[0]; + const index_t topk = topk_ids.mDesc.get_lengths()[1]; + std::vector> expert_tokens(experts, + std::vector(unit_size, num_token)); + std::vector> expert_token_weights( + experts, std::vector(unit_size, 0)); + std::vector expert_slices(experts, 1); + std::vector expert_slice_idxs(experts, 0); + + for(index_t t = 0; t < num_token; t++) + { + for(index_t k = 0; k < topk; k++) + { + IndexType e = topk_ids(t, k); + WeightType w = weights(t, k); + index_t idx = expert_slice_idxs[e]; + if(idx > expert_slices[e] * unit_size - 1) + { + expert_slices[e]++; + index_t new_size = expert_slices[e] * unit_size; + expert_tokens[e].resize(new_size); + expert_token_weights[e].resize(new_size); + for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++) + { + expert_tokens[e][i] = num_token; + expert_token_weights[e][i] = 0; + } + } + + expert_tokens[e][idx] = t; + expert_token_weights[e][idx] = w; + expert_slice_idxs[e]++; + } + } + + IndexType* out_tokens = p_sorted_token_ids.data(); + WeightType* out_weights = sorted_weight.data(); + IndexType* out_expert_id = sorted_expert_ids.data(); + for(index_t e = 0; e < experts; e++) + { + memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size); + out_tokens += expert_slices[e] * unit_size; + memcpy(out_weights, + expert_token_weights[e].data(), + sizeof(WeightType) * expert_slices[e] * unit_size); + out_weights += expert_slices[e] * unit_size; + + for(index_t s = 0; s < expert_slices[e]; s++) + { + out_expert_id[s] = e; + unit_cnt++; + } + 
out_expert_id += expert_slices[e]; + } + unit_cnt *= unit_size; + return; +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp new file mode 100644 index 0000000000..1c6acec70e --- /dev/null +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include +#include + +namespace ck_tile { + +struct MoeSortingHostArgs +{ + const void* p_topk_ids; + const void* p_weights; + void* p_sorted_token_ids; + void* p_sorted_weights; + void* p_sorted_expert_ids; + void* p_total_tokens_post_pad; + void* p_moe_buf; + index_t tokens; + index_t unit_size; + index_t num_experts; + index_t topk; + index_t moe_buf_bytes; +}; + +template +struct MoeSortingKernel +{ + using Problem = remove_cvref_t; + + using IndexType = typename Problem::IndexType; + using WeightType = typename Problem::WeightType; + + typedef MoeSortingHostArgs MoeSortingKargs; + + using Hargs = MoeSortingHostArgs; + + struct Kargs + { + const void* p_topk_ids; + const void* p_weights; + void* p_sorted_token_ids; + void* p_sorted_weights; + void* p_sorted_expert_ids; + void* p_total_tokens_post_pad; + void* p_moe_buf; + index_t tokens; + index_t num_experts; + index_t moe_buf_bytes; + + index_t tokens_per_thread; + mdiv unit_size_mdiv; + mdiv topk_mdiv; + }; + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) + { + // TODO: assume num-experts not too much + return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BlockSize(h).x * 16)); + } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h) + { + return dim3(ck_tile::integer_least_multiple(h.num_experts, 
ck_tile::get_warp_size())); + } + + // in byte + CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h) + { + const auto blocks = BlockSize(h); + return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t); + } + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_topk_ids = h.p_topk_ids; + k.p_weights = h.p_weights; + k.p_sorted_token_ids = h.p_sorted_token_ids; + k.p_sorted_weights = h.p_sorted_weights; + k.p_sorted_expert_ids = h.p_sorted_expert_ids; + k.p_moe_buf = h.p_moe_buf; + k.p_total_tokens_post_pad = h.p_total_tokens_post_pad; + k.tokens = h.tokens; + k.num_experts = h.num_experts; + k.moe_buf_bytes = h.moe_buf_bytes; + + const auto blocks = BlockSize(h); + k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x); + k.unit_size_mdiv = mdiv{static_cast(h.unit_size)}; + k.topk_mdiv = mdiv{static_cast(h.topk)}; + return k; + } + + CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const + { + return row * total_col + col; + } + + CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes) const + { + const index_t offset = (blockIdx.x - 1) * blockDim.x + threadIdx.x; + if(offset < buf_bytes / 16) + { + buf[offset] = uint8x16_t{0}; + } + } + + CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id, + const WeightType* __restrict__ weights, + index_t* p_sorted_token_ids, + WeightType* p_sorted_weights, + index_t* p_sorted_expert_ids, + index_t* p_total_tokens_post_pad, + const index_t num_experts, + const index_t tokens_per_thread, + const index_t numel, + const mdiv unit_size_mdiv, + const mdiv topk_mdiv, + void* smem) const + { + const index_t tid = static_cast(threadIdx.x); + const index_t start_idx = tid * tokens_per_thread; + + index_t* shared_mem = reinterpret_cast(smem); + + index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts) + index_t* cumsum = shared_mem + (blockDim.x + 1) * 
num_experts; // 1: (num_experts + 1) + for(int i = 0; i < num_experts; ++i) + { + tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0; + } +#pragma unroll Problem_::InternalLoadUnroll + for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) + { + ++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])]; + } + __syncthreads(); + + if(tid < num_experts) + { + tokens_cnts[calc_index(num_experts, 0, tid)] = 0; + for(int i = 1; i <= static_cast(blockDim.x); ++i) + { + tokens_cnts[calc_index(num_experts, i, tid)] += + tokens_cnts[calc_index(num_experts, i - 1, tid)]; + } + } + + __syncthreads(); // thread 0 reads column totals written by the other threads + if(tid == 0) + { + cumsum[0] = 0; + for(int i = 1; i <= num_experts; ++i) + { + auto current_units = [&]() { + index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] + + unit_size_mdiv.divisor - 1; + index_t y_ = unit_size_mdiv.div(x_); + return max(y_, 1) * unit_size_mdiv.divisor; + }(); + cumsum[i] = cumsum[i - 1] + current_units; + } + *p_total_tokens_post_pad = cumsum[num_experts]; + } + __syncthreads(); + if(tid < num_experts) + { + for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor) + { + p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid; + } + } + +#pragma unroll Problem_::InternalLoadUnroll + for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) + { + index_t expert_id = topk_id[i]; + index_t rank_post_pad = + tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id]; + p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i); + p_sorted_weights[rank_post_pad] = weights[i]; + ++tokens_cnts[calc_index(num_experts, tid, expert_id)]; + } + + const index_t prefill_token = topk_mdiv.div(numel); + if(tid < num_experts) + { + index_t expert_offset = + cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)]; + while(expert_offset < cumsum[tid + 1]) + { + p_sorted_token_ids[expert_offset] = prefill_token; + p_sorted_weights[expert_offset] = static_cast(0.0); + 
expert_offset++; + } + } + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + if(blockIdx.x > 0) + { + if(kargs.p_moe_buf) + { + moe_buf_set_zero_kernel(reinterpret_cast(kargs.p_moe_buf), + kargs.moe_buf_bytes); + } + return; + } + const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor; + extern __shared__ char smem[]; + return moe_align_block_size_kernel(static_cast(kargs.p_topk_ids), + static_cast(kargs.p_weights), + static_cast(kargs.p_sorted_token_ids), + static_cast(kargs.p_sorted_weights), + static_cast(kargs.p_sorted_expert_ids), + static_cast(kargs.p_total_tokens_post_pad), + kargs.num_experts, + kargs.tokens_per_thread, + numel, + kargs.unit_size_mdiv, + kargs.topk_mdiv, + smem); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp new file mode 100644 index 0000000000..bbd47352d4 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" +#include +#include + +#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW +#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0 +#endif + +namespace ck_tile { + +// template +// struct MoeSortingPipeline +// { +// // TODO: this kernel only support warp per row +// using Problem = remove_cvref_t; +// using Policy = remove_cvref_t; +// using WeightType = typename Problem::WeightType; + +// template +// CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window, +// const WeightWindow& weight_window, +// index_t* p_sorted_token_ids, +// WeightType* p_sorted_weights, +// index_t* p_sorted_expert_ids, +// index_t* p_total_tokens_post_pad, +// const index_t num_experts, +// const index_t unit_size, +// const size_t numel, +// const index_t topk) +// { +// } +// }; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp new file mode 100644 index 0000000000..f5218a93e2 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/softmax.hpp" +#include "ck_tile/ops/topk.hpp" + +namespace ck_tile { + +struct MoeSortingPolicy +{ +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp new file mode 100644 index 0000000000..adde59e356 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include +#include + +namespace ck_tile { + +template +struct MoeSortingProblem +{ + // TODO: this kernel only support warp per row + using WeightType = remove_cvref_t; + using IndexType = remove_cvref_t; + + static constexpr index_t WarpSize = get_warp_size(); + static constexpr index_t WarpsPerBlock = 1; + static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_; +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/moe_sorting.hpp b/include/ck_tile/ops/moe_sorting.hpp new file mode 100644 index 0000000000..b74607f061 --- /dev/null +++ b/include/ck_tile/ops/moe_sorting.hpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp" +#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" +#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" +#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp"