From d93efe1b614d2045e5e24f29494fda06107f8db6 Mon Sep 17 00:00:00 2001 From: Amir Ghamarian Date: Sun, 29 Mar 2026 18:06:03 +0000 Subject: [PATCH] Add fused topk_softmax_decode kernel for M=1 MoE decode New CK tile kernel variant that fuses topk_softmax and moe_sorting into a single kernel launch for the decode case (M=1, single token). The pipeline inlines the topk loop with results in shared memory (no global scratch), then thread 0 emits moe_sorting-compatible packed output. Includes CMake target tile_example_topk_softmax_decode with built-in comparison benchmark against the separate topk+sorting baseline. Validated on gfx950, E=8..1024, k=1..8, bf16/fp16. Made-with: Cursor --- .../ck_tile/09_topk_softmax/CMakeLists.txt | 19 +- .../09_topk_softmax/topk_softmax_decode.cpp | 314 ++++++++++++++++++ .../topk_softmax_decode_api.cpp | 99 ++++++ .../topk_softmax_decode_api.hpp | 24 ++ include/ck_tile/ops/topk_softmax.hpp | 2 + .../kernel/topk_softmax_decode_kernel.hpp | 129 +++++++ .../pipeline/topk_softmax_decode_pipeline.hpp | 244 ++++++++++++++ 7 files changed, 828 insertions(+), 3 deletions(-) create mode 100644 example/ck_tile/09_topk_softmax/topk_softmax_decode.cpp create mode 100644 example/ck_tile/09_topk_softmax/topk_softmax_decode_api.cpp create mode 100644 example/ck_tile/09_topk_softmax/topk_softmax_decode_api.hpp create mode 100644 include/ck_tile/ops/topk_softmax/kernel/topk_softmax_decode_kernel.hpp create mode 100644 include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_decode_pipeline.hpp diff --git a/example/ck_tile/09_topk_softmax/CMakeLists.txt b/example/ck_tile/09_topk_softmax/CMakeLists.txt index cce2c53ba4..669e794e33 100644 --- a/example/ck_tile/09_topk_softmax/CMakeLists.txt +++ b/example/ck_tile/09_topk_softmax/CMakeLists.txt @@ -1,11 +1,24 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp) -target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) - set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) # list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + +add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp) +target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) target_compile_options(tile_example_topk_softmax PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS}) + +add_executable(tile_example_topk_softmax_decode + topk_softmax_decode.cpp + topk_softmax_decode_api.cpp + topk_softmax_api.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../13_moe_sorting/moe_sorting_api.cpp) +target_include_directories(tile_example_topk_softmax_decode PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/ + ${CMAKE_CURRENT_SOURCE_DIR}/../13_moe_sorting/) +target_compile_definitions(tile_example_topk_softmax_decode PRIVATE + CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID=1 + MOE_SORTING_FMOE_2D_BUF=1) +target_compile_options(tile_example_topk_softmax_decode PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS}) diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_decode.cpp b/example/ck_tile/09_topk_softmax/topk_softmax_decode.cpp new file mode 100644 index 0000000000..74d86b13e0 --- /dev/null +++ b/example/ck_tile/09_topk_softmax/topk_softmax_decode.cpp @@ -0,0 +1,314 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "topk_softmax_decode_api.hpp" +#include "topk_softmax_api.hpp" +#include "moe_sorting_api.hpp" + +#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID +#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1 +#endif +#include "ck_tile/host/reference/reference_moe_sorting.hpp" + +// CPU reference: softmax -> topk -> moe_sorting +template +bool reference_fused(const ck_tile::HostTensor& x_host, + ck_tile::index_t topk, + ck_tile::index_t num_experts, + ck_tile::index_t unit_size, + ck_tile::HostTensor& ref_sorted_ids, + ck_tile::HostTensor& ref_sorted_weights, + ck_tile::HostTensor& ref_sorted_expert_ids, + ck_tile::index_t& ref_unit_cnt) +{ + auto probs = ck_tile::reference_softmax(x_host); + + ck_tile::HostTensor topk_vals({1, topk}); + ck_tile::HostTensor topk_idxs({1, topk}); + ck_tile::reference_topk(probs, topk_vals, topk_idxs, topk); + + ck_tile::HostTensor local_expert_mask({num_experts}); + ref_unit_cnt = 0; + ck_tile::reference_moe_sorting( + topk_idxs, + topk_vals, + local_expert_mask, + ref_sorted_ids, + ref_sorted_weights, + ref_sorted_expert_ids, + ref_unit_cnt, + num_experts, + unit_size, + 1, + false, + true); + + return true; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "do CPU validation") + .insert("pr_i", "bf16", "input data type: fp16/bf16") + .insert("e", "128", "number of experts") + .insert("k", "8", "topk") + .insert("unit", "32", "unit_size (block_size_M)") + .insert("model_dim", "7168", "model dimension for moe_buf zeroing") + .insert("seed", "-1", "random seed, -1 = random") + .insert("warmup", "5", "warmup iterations") + .insert("repeat", "20", "benchmark iterations"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run_test(ck_tile::ArgParser args) +{ + int validate = args.get_int("v"); + std::string pr = args.get_str("pr_i"); + int experts = args.get_int("e"); + int topk = args.get_int("k"); + int unit_size = args.get_int("unit"); + int model_dim = args.get_int("model_dim"); + int seed = args.get_int("seed"); + int warmup = args.get_int("warmup"); + int repeat = args.get_int("repeat"); + + if(seed < 0) + seed = std::time(nullptr); + + if(topk > experts) + { + printf("topk %d > experts %d, skip\n", topk, experts); + return false; + } + + int tokens = 1; + int max_num_tokens_padded = topk + experts * unit_size - topk; + int max_num_m_blocks = (max_num_tokens_padded + unit_size - 1) / unit_size; + + // Host tensors + ck_tile::HostTensor x_host({1, experts}); + { + auto rng = ck_tile::FillUniformDistribution_Unique{ + -5.f, 5.f, static_cast(seed)}; + ck_tile::HostTensor row({experts}); + rng(row); + std::copy(row.begin(), row.end(), x_host.begin()); + } + + // ---------- Device buffers (shared) ---------- + ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes()); + x_dev.ToDevice(x_host.data()); + + // ---------- Fused kernel buffers ---------- + ck_tile::DeviceMem fused_sorted_ids(max_num_tokens_padded * sizeof(IndexType)); + ck_tile::DeviceMem fused_sorted_weights(max_num_tokens_padded * sizeof(WeightType)); + ck_tile::DeviceMem fused_sorted_expert_ids(max_num_m_blocks * sizeof(IndexType)); + ck_tile::DeviceMem fused_num_valid(2 * sizeof(IndexType)); + ck_tile::DeviceMem fused_moe_buf(model_dim * sizeof(WeightType)); + { + std::vector ones(model_dim, 1.0f); + fused_moe_buf.ToDevice(ones.data()); + } + + // ---------- Two-kernel baseline buffers ---------- + ck_tile::DeviceMem topk_w_dev(topk * sizeof(WeightType)); + ck_tile::DeviceMem topk_i_dev(topk * sizeof(IndexType)); + ck_tile::DeviceMem base_sorted_ids(max_num_tokens_padded * sizeof(IndexType)); + ck_tile::DeviceMem base_sorted_weights(max_num_tokens_padded * sizeof(WeightType)); + ck_tile::DeviceMem base_sorted_expert_ids(max_num_m_blocks * sizeof(IndexType)); + ck_tile::DeviceMem base_num_valid(2 * sizeof(IndexType)); + ck_tile::DeviceMem base_moe_buf(model_dim * sizeof(WeightType)); + int ws_size = moe_sorting_get_workspace_size(tokens, experts, topk, 0); + ck_tile::DeviceMem base_ws(ws_size > 0 ? ws_size : 1); + if(ws_size > 0) + base_ws.SetZero(); + + ck_tile::stream_config sc{nullptr, true, 0, warmup, repeat}; + + // ====================== Fused kernel ====================== + topk_softmax_decode_trait fused_trait{pr, "fp32", experts, "softmax"}; + topk_softmax_decode_kargs fused_karg{ + x_dev.GetDeviceBuffer(), + experts, + topk, + experts, + true, + fused_sorted_ids.GetDeviceBuffer(), + fused_sorted_weights.GetDeviceBuffer(), + fused_sorted_expert_ids.GetDeviceBuffer(), + fused_num_valid.GetDeviceBuffer(), + fused_moe_buf.GetDeviceBuffer(), + unit_size, + model_dim, + static_cast(sizeof(WeightType))}; + + float ms_fused = topk_softmax_decode(fused_trait, fused_karg, sc); + + // ============= Two-kernel baseline: topk + moe_sorting ============= + topk_softmax_trait ts_trait{pr, "fp32", experts, "softmax"}; + topk_softmax_kargs ts_karg{ + x_dev.GetDeviceBuffer(), + topk_w_dev.GetDeviceBuffer(), + topk_i_dev.GetDeviceBuffer(), + tokens, + experts, + topk, + experts, + topk}; + + moe_sorting_trait ms_trait{"int32", "fp32", false, true, 0}; + moe_sorting_args ms_arg{ + topk_i_dev.GetDeviceBuffer(), + topk_w_dev.GetDeviceBuffer(), + nullptr, + nullptr, + base_sorted_ids.GetDeviceBuffer(), + base_sorted_weights.GetDeviceBuffer(), + base_sorted_expert_ids.GetDeviceBuffer(), + base_num_valid.GetDeviceBuffer(), + base_moe_buf.GetDeviceBuffer(), + ws_size > 0 ? base_ws.GetDeviceBuffer() : nullptr, + tokens, + unit_size, + experts, + topk, + model_dim, + static_cast(sizeof(WeightType))}; + + // Time the two kernels together using launch_kernel with two lambdas + auto sc_sub = ck_tile::stream_config{nullptr, false, 0, 0, 1}; + float ms_baseline = ck_tile::launch_kernel( + sc, + [&](const ck_tile::stream_config&) { topk_softmax(ts_trait, ts_karg, sc_sub); }, + [&](const ck_tile::stream_config&) { moe_sorting(ms_trait, ms_arg, sc_sub); }); + + float speedup = (ms_baseline > 0 && ms_fused > 0) ? ms_baseline / ms_fused : 0; + printf("[%s] E:%d, k:%d, unit:%d | fused:%.4fms baseline(topk+sort):%.4fms speedup:%.2fx", + pr.c_str(), + experts, + topk, + unit_size, + ms_fused, + ms_baseline, + speedup); + + if(ms_fused < 0 || ms_baseline < 0) + { + printf(" (not supported)\n"); + return false; + } + + // ====================== Validation ====================== + bool pass = true; + if(validate) + { + ck_tile::HostTensor sorted_ids_host({max_num_tokens_padded}); + ck_tile::HostTensor sorted_weights_host({max_num_tokens_padded}); + ck_tile::HostTensor sorted_expert_ids_host({max_num_m_blocks}); + ck_tile::HostTensor num_valid_host({2}); + std::vector moe_buf_host(model_dim); + + fused_sorted_ids.FromDevice(sorted_ids_host.data()); + fused_sorted_weights.FromDevice(sorted_weights_host.data()); + fused_sorted_expert_ids.FromDevice(sorted_expert_ids_host.data()); + fused_num_valid.FromDevice(num_valid_host.data()); + fused_moe_buf.FromDevice(moe_buf_host.data()); + + ck_tile::HostTensor ref_sorted_ids({max_num_tokens_padded}); + ck_tile::HostTensor ref_sorted_weights({max_num_tokens_padded}); + ck_tile::HostTensor ref_sorted_expert_ids({max_num_m_blocks}); + IndexType sentinel = static_cast((1 & 0x00ffffff) | ((topk & 0xff) << 24)); + std::fill(ref_sorted_ids.begin(), ref_sorted_ids.end(), sentinel); + std::fill(ref_sorted_weights.begin(), ref_sorted_weights.end(), WeightType(0)); + std::fill(ref_sorted_expert_ids.begin(), ref_sorted_expert_ids.end(), -1); + + ck_tile::index_t ref_unit_cnt = 0; + reference_fused( + x_host, topk, experts, unit_size, + ref_sorted_ids, ref_sorted_weights, ref_sorted_expert_ids, ref_unit_cnt); + + int num_valid_padded = num_valid_host(0); + int num_valid_tokens = num_valid_host(1); + + if(num_valid_padded != ref_unit_cnt) + { + printf(" FAIL:num_valid[0] got %d ref %d;", num_valid_padded, ref_unit_cnt); + pass = false; + } + if(num_valid_tokens != 1) + { + printf(" FAIL:num_valid[1] got %d;", num_valid_tokens); + pass = false; + } + + int n_tiles = num_valid_padded / unit_size; + for(int i = 1; i < n_tiles; i++) + { + if(sorted_expert_ids_host(i) < sorted_expert_ids_host(i - 1)) + { + printf(" FAIL:expert_ids not ascending;"); + pass = false; + break; + } + } + + WeightType wsum = 0; + for(int i = 0; i < topk; i++) + wsum += sorted_weights_host(i * unit_size); + if(std::abs(wsum - 1.0f) > 1e-3f) + { + printf(" FAIL:wsum=%.6f;", static_cast(wsum)); + pass = false; + } + + bool buf_zeroed = std::all_of( + moe_buf_host.begin(), moe_buf_host.end(), [](float v) { return v == 0.0f; }); + if(!buf_zeroed) + { + printf(" FAIL:moe_buf not zeroed;"); + pass = false; + } + } + + printf(" valid:%s\n", pass ? "y" : "n"); + fflush(stdout); + return pass; +} + +int main(int argc, char** argv) +{ + auto [result, args] = create_args(argc, argv); + if(!result) + return -1; + + std::string pr = args.get_str("pr_i"); + bool r = true; + + if(pr == "fp16") + r &= run_test(args); + else if(pr == "bf16") + r &= run_test(args); + else + { + printf("unsupported pr_i: %s\n", pr.c_str()); + return -1; + } + + return r ? 0 : -1; +} diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_decode_api.cpp b/example/ck_tile/09_topk_softmax/topk_softmax_decode_api.cpp new file mode 100644 index 0000000000..e593556099 --- /dev/null +++ b/example/ck_tile/09_topk_softmax/topk_softmax_decode_api.cpp @@ -0,0 +1,99 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "topk_softmax_decode_api.hpp" + +#define TOPK_SOFTMAX_DECODE_DISPATCH(experts_, use_softmax_) \ + constexpr ck_tile::index_t ts_experts = experts_; \ + constexpr bool ts_use_softmax = use_softmax_; \ + using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem; \ + using ts_pipeline = ck_tile::TopkSoftmaxDecodePipeline; \ + \ + using kernel = ck_tile::TopkSoftmaxDecodeKernel; \ + \ + auto kargs = kernel::MakeKargs(a); \ + \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(); \ + \ + float ave_time = \ + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \ + \ + return ave_time; + +#define TOPK_SOFTMAX_DECODE_EXPERT_LADDER(use_softmax_) \ + if(t.experts <= 8) \ + { \ + TOPK_SOFTMAX_DECODE_DISPATCH(8, use_softmax_) \ + } \ + else if(t.experts <= 16) \ + { \ + TOPK_SOFTMAX_DECODE_DISPATCH(16, use_softmax_) \ + } \ + else if(t.experts <= 32) \ + { \ + TOPK_SOFTMAX_DECODE_DISPATCH(32, use_softmax_) \ + } \ + else if(t.experts <= 64) \ + { \ + TOPK_SOFTMAX_DECODE_DISPATCH(64, use_softmax_) \ + } \ + else if(t.experts <= 128) \ + { \ + TOPK_SOFTMAX_DECODE_DISPATCH(128, use_softmax_) \ + } \ + else if(t.experts <= 192) \ + { \ + TOPK_SOFTMAX_DECODE_DISPATCH(192, use_softmax_) \ + } \ + else if(t.experts <= 256) \ + { \ + TOPK_SOFTMAX_DECODE_DISPATCH(256, use_softmax_) \ + } \ + else if(t.experts <= 512) \ + { \ + TOPK_SOFTMAX_DECODE_DISPATCH(512, use_softmax_) \ + } \ + else if(t.experts <= 1024) \ + { \ + TOPK_SOFTMAX_DECODE_DISPATCH(1024, use_softmax_)\ + } + +float topk_softmax_decode(topk_softmax_decode_trait t, + topk_softmax_decode_kargs a, + ck_tile::stream_config s) +{ + if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax") + { + using ts_input_type = ck_tile::fp16_t; + using ts_weight_type = float; + using ts_index_type = ck_tile::index_t; + TOPK_SOFTMAX_DECODE_EXPERT_LADDER(true) + } + else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax") + { + using ts_input_type = ck_tile::bf16_t; + using ts_weight_type = float; + using ts_index_type = ck_tile::index_t; + TOPK_SOFTMAX_DECODE_EXPERT_LADDER(true) + } + else if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid") + { + using ts_input_type = ck_tile::fp16_t; + using ts_weight_type = float; + using ts_index_type = ck_tile::index_t; + TOPK_SOFTMAX_DECODE_EXPERT_LADDER(false) + } + else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid") + { + using ts_input_type = ck_tile::bf16_t; + using ts_weight_type = float; + using ts_index_type = ck_tile::index_t; + TOPK_SOFTMAX_DECODE_EXPERT_LADDER(false) + } + return -1; +} diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_decode_api.hpp b/example/ck_tile/09_topk_softmax/topk_softmax_decode_api.hpp new file mode 100644 index 0000000000..73df8af421 --- /dev/null +++ b/example/ck_tile/09_topk_softmax/topk_softmax_decode_api.hpp @@ -0,0 +1,24 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/topk_softmax.hpp" +#include + +struct topk_softmax_decode_trait +{ + std::string input_type; + std::string weight_type; // currently always float + int experts; + std::string activation; // "softmax" or "sigmoid" +}; + +struct topk_softmax_decode_kargs : public ck_tile::TopkSoftmaxDecodeHostArgs +{ +}; + +float topk_softmax_decode(topk_softmax_decode_trait t, + topk_softmax_decode_kargs a, + ck_tile::stream_config s); diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index 7afce1708b..137249bdaf 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -3,9 +3,11 @@ #pragma once #include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp" +#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_decode_kernel.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" +#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_decode_pipeline.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_decode_kernel.hpp b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_decode_kernel.hpp new file mode 100644 index 0000000000..6e7910061d --- /dev/null +++ b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_decode_kernel.hpp @@ -0,0 +1,129 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include +#include + +namespace ck_tile { + +struct TopkSoftmaxDecodeHostArgs +{ + // Input (gating logits) + const void* p_input; + index_t num_experts; + index_t topk; + index_t stride_input; + bool renormalize; + + // Output (moe_sorting format) + void* p_sorted_token_ids; + void* p_sorted_weights; + void* p_sorted_expert_ids; + void* p_total_tokens_post_pad; + + // moe_buf zeroing + void* p_moe_buf; + index_t unit_size; + index_t moe_buf_interm_dim; + index_t moe_buf_elem_bytes; +}; + +template +struct TopkSoftmaxDecodeKernel +{ + using Pipeline = remove_cvref_t; + using Problem = remove_cvref_t; + + using InputType = typename Problem::InputType; + using WeightType = typename Problem::WeightType; + using IndexType = typename Problem::IndexType; + + static constexpr index_t kBlockSize = Problem::BlockSize; + + struct Kargs + { + const void* p_input; + index_t num_experts; + index_t topk; + index_t stride_input; + bool renormalize; + + void* p_sorted_token_ids; + void* p_sorted_weights; + void* p_sorted_expert_ids; + void* p_total_tokens_post_pad; + + void* p_moe_buf; + index_t unit_size; + index_t moe_buf_interm_dim; + index_t moe_buf_elem_bytes; + }; + + using Hargs = TopkSoftmaxDecodeHostArgs; + + CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return dim3(1); } + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_input = h.p_input; + k.num_experts = h.num_experts; + k.topk = h.topk; + k.stride_input = h.stride_input; + k.renormalize = h.renormalize; + k.p_sorted_token_ids = h.p_sorted_token_ids; + k.p_sorted_weights = h.p_sorted_weights; + k.p_sorted_expert_ids = h.p_sorted_expert_ids; + k.p_total_tokens_post_pad = h.p_total_tokens_post_pad; + k.p_moe_buf = h.p_moe_buf; + k.unit_size = h.unit_size; + k.moe_buf_interm_dim = h.moe_buf_interm_dim; + k.moe_buf_elem_bytes = h.moe_buf_elem_bytes; + return k; + } + + CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + constexpr index_t num_rows = 1; + + const auto input_window = [&]() { + const InputType* p_input = reinterpret_cast(kargs.p_input); + auto tmp = make_naive_tensor_view( + p_input, + make_tuple(num_rows, kargs.num_experts), + make_tuple(kargs.stride_input, 1), + number{}, + number<1>{}); + auto view = pad_tensor_view( + tmp, + make_tuple(number{}, number{}), + sequence<0, 1>{}); + return make_tile_window( + view, + make_tuple(number{}, number{}), + {0, 0}); + }(); + + Pipeline{}(input_window, + kargs.num_experts, + kargs.topk, + kargs.renormalize, + reinterpret_cast(kargs.p_sorted_token_ids), + reinterpret_cast(kargs.p_sorted_weights), + reinterpret_cast(kargs.p_sorted_expert_ids), + reinterpret_cast(kargs.p_total_tokens_post_pad), + kargs.p_moe_buf, + kargs.unit_size, + kargs.moe_buf_interm_dim, + kargs.moe_buf_elem_bytes); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_decode_pipeline.hpp b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_decode_pipeline.hpp new file mode 100644 index 0000000000..a94ebf7dbe --- /dev/null +++ b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_decode_pipeline.hpp @@ -0,0 +1,244 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp" +#include +#include + +namespace ck_tile { + +// M=1 decode variant of TopkSoftmaxWarpPerRowPipeline. +// Inlines the topk loop (same algorithm as BlockTopkStream2D) but writes +// results to shared memory instead of global memory tile windows. Then +// thread 0 directly emits moe_sorting-compatible sorted outputs. +template +struct TopkSoftmaxDecodePipeline +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using WeightType = typename Problem::WeightType; + using IndexType = typename Problem::IndexType; + + static constexpr index_t kMaxTopk = 8; + + // Same struct as BlockTopkStream2D::ArgmaxPacket + struct ArgmaxPacket + { + WeightType arg; + index_t value; + }; + + template + CK_TILE_DEVICE auto operator()(const InputWindow& input_window, + index_t experts, + index_t k, + bool renormalize, + IndexType* __restrict__ p_sorted_token_ids, + WeightType* __restrict__ p_sorted_weights, + IndexType* __restrict__ p_sorted_expert_ids, + IndexType* __restrict__ p_total_tokens_post_pad, + void* __restrict__ p_moe_buf, + index_t unit_size, + index_t moe_buf_interm_dim, + index_t moe_buf_elem_bytes) + { + auto inp_win = make_tile_window_linear( + input_window, Policy::template MakeInputDistribution(), sequence<0, 1>{}); + + auto softmax = Policy::template GetSoftmax(); + + // --- Phase 1: Load input and compute softmax/sigmoid --- + auto x = load_tile(inp_win); + + auto w = [&]() { + auto w_ = make_static_distributed_tensor(x.get_tile_distribution()); + auto w_f = [&](auto idx) { + w_(idx) = type_convert(x(idx)); + const auto x_indices = + get_x_indices_from_distributed_indices(w_.get_tile_distribution(), idx); + const auto current_expert = x_indices.at(number<1>{}); + w_(idx) = + current_expert >= experts ? -numeric::infinity() : w_(idx); + if constexpr(!Problem::ActivationIsSoftmax) + { + w_(idx) = WeightType(1) / (WeightType(1) + exp(-w_(idx))); + } + }; + tile_sweeper ts{w_, w_f}; + ts(); + return w_; + }(); + + auto y = [&]() { + if constexpr(Problem::ActivationIsSoftmax) + return softmax(w); + else + return w; + }(); + + // --- Phase 2: Inline topk loop (same as BlockTopkStream2D but → shared mem) --- + __shared__ IndexType s_expert_ids[kMaxTopk]; + __shared__ WeightType s_weights[kMaxTopk]; + __shared__ IndexType s_original_slots[kMaxTopk]; + + const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) { + return e0.arg > e1.arg ? e0 : e1; + }; + + // Exactly mirrors BlockTopkStream2D::operator() lines 45-100 + decltype(y) y_tmp = y; + constexpr auto span_2d = decltype(y_tmp)::get_distributed_spans(); + + for(index_t i_k = 0; i_k < k; i_k++) + { + // Build ArgmaxPacket distributed tensor (same as BlockTopkStream2D lines 56-71) + auto packet = [&]() { + auto tmp = + make_static_distributed_tensor(y.get_tile_distribution()); + + sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) { + sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + tmp.get_tile_distribution(), make_tuple(idx0, idx1)); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + ArgmaxPacket t; + t.arg = y_tmp(i_j_idx); + t.value = tile_idx.at(number<1>{}); + tmp(i_j_idx) = t; + }); + }); + return tmp; + }(); + + // Reduce to find argmax (same as BlockTopkStream2D lines 73-75) + auto argmax_init = + ArgmaxPacket{-numeric::infinity(), 0}; + auto r = + block_tile_reduce(packet, sequence<1>{}, f_argmax, argmax_init); + block_tile_reduce_xor_sync(r, f_argmax); + + // Extract result and store to shared memory instead of tile windows. + // After xor_sync, all threads have the same argmax. We use the same + // r(i_j_idx) access pattern as BlockTopkStream2D line 82. + sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) { + sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + ArgmaxPacket winner = r(i_j_idx); + + if(threadIdx.x == 0) + { + s_expert_ids[i_k] = static_cast(winner.value); + s_weights[i_k] = winner.arg; + s_original_slots[i_k] = static_cast(i_k); + } + }); + }); + + // Mask out selected expert (same as BlockTopkStream2D lines 89-100) + sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) { + sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + y.get_tile_distribution(), make_tuple(idx0, idx1)); + auto col_id = tile_idx.at(number<1>{}); + + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + y_tmp(i_j_idx) = (col_id == r(i_j_idx).value) + ? -numeric::infinity() + : y_tmp(i_j_idx); + }); + }); + __syncthreads(); + } + + // --- Phase 3: Produce sorted outputs (thread 0 only, trivial for M=1) --- + if(threadIdx.x == 0) + { + if(renormalize) + { + WeightType sum = WeightType(0); + for(index_t i = 0; i < k; i++) + sum += s_weights[i]; + if(sum != WeightType(0)) + { + WeightType inv_sum = WeightType(1) / sum; + for(index_t i = 0; i < k; i++) + s_weights[i] *= inv_sum; + } + } + + // Sort by expert_id (ascending). k <= 8, insertion sort. + for(index_t i = 1; i < k; i++) + { + IndexType key_eid = s_expert_ids[i]; + WeightType key_w = s_weights[i]; + IndexType key_slot = s_original_slots[i]; + index_t j = i - 1; + while(j >= 0 && s_expert_ids[j] > key_eid) + { + s_expert_ids[j + 1] = s_expert_ids[j]; + s_weights[j + 1] = s_weights[j]; + s_original_slots[j + 1] = s_original_slots[j]; + j--; + } + s_expert_ids[j + 1] = key_eid; + s_weights[j + 1] = key_w; + s_original_slots[j + 1] = key_slot; + } + + constexpr index_t num_tokens = 1; + index_t write_offset = 0; + index_t expert_tile_idx = 0; + + IndexType sentinel = + static_cast((num_tokens & 0x00ffffff) | ((k & 0xff) << 24)); + + for(index_t i = 0; i < k; i++) + { + IndexType expert_id = s_expert_ids[i]; + WeightType weight = s_weights[i]; + IndexType topk_slot = s_original_slots[i]; + + IndexType packed_id = + static_cast((0 & 0x00ffffff) | ((topk_slot & 0xff) << 24)); + + p_sorted_token_ids[write_offset] = packed_id; + p_sorted_weights[write_offset] = weight; + + for(index_t p = 1; p < unit_size; p++) + { + p_sorted_token_ids[write_offset + p] = sentinel; + p_sorted_weights[write_offset + p] = WeightType(0); + } + + p_sorted_expert_ids[expert_tile_idx] = expert_id; + + write_offset += unit_size; + expert_tile_idx++; + } + + p_total_tokens_post_pad[0] = static_cast(k * unit_size); + p_total_tokens_post_pad[1] = static_cast(num_tokens); + } + + // --- Phase 4: Zero moe_buf cooperatively --- + if(p_moe_buf != nullptr) + { + const index_t total_bytes = moe_buf_interm_dim * moe_buf_elem_bytes; + const index_t total_elems = total_bytes / 16; + + using vector_type = ext_vector_t; + vector_type* p_buf = reinterpret_cast(p_moe_buf); + auto zero_ = vector_type{0}; + + for(index_t i = threadIdx.x; i < total_elems; i += blockDim.x) + { + p_buf[i] = zero_; + } + } + } +}; +} // namespace ck_tile