Add fused topk_softmax_decode kernel for M=1 MoE decode

New CK tile kernel variant that fuses topk_softmax and moe_sorting into
a single kernel launch for the decode case (M=1, single token). The
pipeline inlines the topk loop with results in shared memory (no global
scratch), then thread 0 emits moe_sorting-compatible packed output.

Includes CMake target tile_example_topk_softmax_decode with built-in
comparison benchmark against the separate topk+sorting baseline.

Validated on gfx950, E=8..1024, k=1..8, bf16/fp16.

Made-with: Cursor
This commit is contained in:
Amir Ghamarian
2026-03-29 18:06:03 +00:00
parent a7ded14537
commit d93efe1b61
7 changed files with 828 additions and 3 deletions

View File

@@ -1,11 +1,24 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp)
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp)
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
target_compile_options(tile_example_topk_softmax PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})
add_executable(tile_example_topk_softmax_decode
topk_softmax_decode.cpp
topk_softmax_decode_api.cpp
topk_softmax_api.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../13_moe_sorting/moe_sorting_api.cpp)
target_include_directories(tile_example_topk_softmax_decode PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/
${CMAKE_CURRENT_SOURCE_DIR}/../13_moe_sorting/)
target_compile_definitions(tile_example_topk_softmax_decode PRIVATE
CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID=1
MOE_SORTING_FMOE_2D_BUF=1)
target_compile_options(tile_example_topk_softmax_decode PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})

View File

@@ -0,0 +1,314 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <time.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "topk_softmax_decode_api.hpp"
#include "topk_softmax_api.hpp"
#include "moe_sorting_api.hpp"
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
// CPU reference: softmax -> topk -> moe_sorting
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
bool reference_fused(const ck_tile::HostTensor<InputType>& x_host,
ck_tile::index_t topk,
ck_tile::index_t num_experts,
ck_tile::index_t unit_size,
ck_tile::HostTensor<IndexType>& ref_sorted_ids,
ck_tile::HostTensor<WeightType>& ref_sorted_weights,
ck_tile::HostTensor<IndexType>& ref_sorted_expert_ids,
ck_tile::index_t& ref_unit_cnt)
{
auto probs = ck_tile::reference_softmax<InputType, WeightType, WeightType>(x_host);
ck_tile::HostTensor<WeightType> topk_vals({1, topk});
ck_tile::HostTensor<IndexType> topk_idxs({1, topk});
ck_tile::reference_topk(probs, topk_vals, topk_idxs, topk);
ck_tile::HostTensor<IndexType> local_expert_mask({num_experts});
ref_unit_cnt = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(
topk_idxs,
topk_vals,
local_expert_mask,
ref_sorted_ids,
ref_sorted_weights,
ref_sorted_expert_ids,
ref_unit_cnt,
num_experts,
unit_size,
1,
false,
true);
return true;
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "do CPU validation")
.insert("pr_i", "bf16", "input data type: fp16/bf16")
.insert("e", "128", "number of experts")
.insert("k", "8", "topk")
.insert("unit", "32", "unit_size (block_size_M)")
.insert("model_dim", "7168", "model dimension for moe_buf zeroing")
.insert("seed", "-1", "random seed, -1 = random")
.insert("warmup", "5", "warmup iterations")
.insert("repeat", "20", "benchmark iterations");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
bool run_test(ck_tile::ArgParser args)
{
int validate = args.get_int("v");
std::string pr = args.get_str("pr_i");
int experts = args.get_int("e");
int topk = args.get_int("k");
int unit_size = args.get_int("unit");
int model_dim = args.get_int("model_dim");
int seed = args.get_int("seed");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
if(seed < 0)
seed = std::time(nullptr);
if(topk > experts)
{
printf("topk %d > experts %d, skip\n", topk, experts);
return false;
}
int tokens = 1;
int max_num_tokens_padded = topk + experts * unit_size - topk;
int max_num_m_blocks = (max_num_tokens_padded + unit_size - 1) / unit_size;
// Host tensors
ck_tile::HostTensor<InputType> x_host({1, experts});
{
auto rng = ck_tile::FillUniformDistribution_Unique<InputType>{
-5.f, 5.f, static_cast<uint32_t>(seed)};
ck_tile::HostTensor<InputType> row({experts});
rng(row);
std::copy(row.begin(), row.end(), x_host.begin());
}
// ---------- Device buffers (shared) ----------
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
x_dev.ToDevice(x_host.data());
// ---------- Fused kernel buffers ----------
ck_tile::DeviceMem fused_sorted_ids(max_num_tokens_padded * sizeof(IndexType));
ck_tile::DeviceMem fused_sorted_weights(max_num_tokens_padded * sizeof(WeightType));
ck_tile::DeviceMem fused_sorted_expert_ids(max_num_m_blocks * sizeof(IndexType));
ck_tile::DeviceMem fused_num_valid(2 * sizeof(IndexType));
ck_tile::DeviceMem fused_moe_buf(model_dim * sizeof(WeightType));
{
std::vector<float> ones(model_dim, 1.0f);
fused_moe_buf.ToDevice(ones.data());
}
// ---------- Two-kernel baseline buffers ----------
ck_tile::DeviceMem topk_w_dev(topk * sizeof(WeightType));
ck_tile::DeviceMem topk_i_dev(topk * sizeof(IndexType));
ck_tile::DeviceMem base_sorted_ids(max_num_tokens_padded * sizeof(IndexType));
ck_tile::DeviceMem base_sorted_weights(max_num_tokens_padded * sizeof(WeightType));
ck_tile::DeviceMem base_sorted_expert_ids(max_num_m_blocks * sizeof(IndexType));
ck_tile::DeviceMem base_num_valid(2 * sizeof(IndexType));
ck_tile::DeviceMem base_moe_buf(model_dim * sizeof(WeightType));
int ws_size = moe_sorting_get_workspace_size(tokens, experts, topk, 0);
ck_tile::DeviceMem base_ws(ws_size > 0 ? ws_size : 1);
if(ws_size > 0)
base_ws.SetZero();
ck_tile::stream_config sc{nullptr, true, 0, warmup, repeat};
// ====================== Fused kernel ======================
topk_softmax_decode_trait fused_trait{pr, "fp32", experts, "softmax"};
topk_softmax_decode_kargs fused_karg{
x_dev.GetDeviceBuffer(),
experts,
topk,
experts,
true,
fused_sorted_ids.GetDeviceBuffer(),
fused_sorted_weights.GetDeviceBuffer(),
fused_sorted_expert_ids.GetDeviceBuffer(),
fused_num_valid.GetDeviceBuffer(),
fused_moe_buf.GetDeviceBuffer(),
unit_size,
model_dim,
static_cast<int>(sizeof(WeightType))};
float ms_fused = topk_softmax_decode(fused_trait, fused_karg, sc);
// ============= Two-kernel baseline: topk + moe_sorting =============
topk_softmax_trait ts_trait{pr, "fp32", experts, "softmax"};
topk_softmax_kargs ts_karg{
x_dev.GetDeviceBuffer(),
topk_w_dev.GetDeviceBuffer(),
topk_i_dev.GetDeviceBuffer(),
tokens,
experts,
topk,
experts,
topk};
moe_sorting_trait ms_trait{"int32", "fp32", false, true, 0};
moe_sorting_args ms_arg{
topk_i_dev.GetDeviceBuffer(),
topk_w_dev.GetDeviceBuffer(),
nullptr,
nullptr,
base_sorted_ids.GetDeviceBuffer(),
base_sorted_weights.GetDeviceBuffer(),
base_sorted_expert_ids.GetDeviceBuffer(),
base_num_valid.GetDeviceBuffer(),
base_moe_buf.GetDeviceBuffer(),
ws_size > 0 ? base_ws.GetDeviceBuffer() : nullptr,
tokens,
unit_size,
experts,
topk,
model_dim,
static_cast<int>(sizeof(WeightType))};
// Time the two kernels together using launch_kernel with two lambdas
auto sc_sub = ck_tile::stream_config{nullptr, false, 0, 0, 1};
float ms_baseline = ck_tile::launch_kernel(
sc,
[&](const ck_tile::stream_config&) { topk_softmax(ts_trait, ts_karg, sc_sub); },
[&](const ck_tile::stream_config&) { moe_sorting(ms_trait, ms_arg, sc_sub); });
float speedup = (ms_baseline > 0 && ms_fused > 0) ? ms_baseline / ms_fused : 0;
printf("[%s] E:%d, k:%d, unit:%d | fused:%.4fms baseline(topk+sort):%.4fms speedup:%.2fx",
pr.c_str(),
experts,
topk,
unit_size,
ms_fused,
ms_baseline,
speedup);
if(ms_fused < 0 || ms_baseline < 0)
{
printf(" (not supported)\n");
return false;
}
// ====================== Validation ======================
bool pass = true;
if(validate)
{
ck_tile::HostTensor<IndexType> sorted_ids_host({max_num_tokens_padded});
ck_tile::HostTensor<WeightType> sorted_weights_host({max_num_tokens_padded});
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_num_m_blocks});
ck_tile::HostTensor<IndexType> num_valid_host({2});
std::vector<float> moe_buf_host(model_dim);
fused_sorted_ids.FromDevice(sorted_ids_host.data());
fused_sorted_weights.FromDevice(sorted_weights_host.data());
fused_sorted_expert_ids.FromDevice(sorted_expert_ids_host.data());
fused_num_valid.FromDevice(num_valid_host.data());
fused_moe_buf.FromDevice(moe_buf_host.data());
ck_tile::HostTensor<IndexType> ref_sorted_ids({max_num_tokens_padded});
ck_tile::HostTensor<WeightType> ref_sorted_weights({max_num_tokens_padded});
ck_tile::HostTensor<IndexType> ref_sorted_expert_ids({max_num_m_blocks});
IndexType sentinel = static_cast<uint32_t>((1 & 0x00ffffff) | ((topk & 0xff) << 24));
std::fill(ref_sorted_ids.begin(), ref_sorted_ids.end(), sentinel);
std::fill(ref_sorted_weights.begin(), ref_sorted_weights.end(), WeightType(0));
std::fill(ref_sorted_expert_ids.begin(), ref_sorted_expert_ids.end(), -1);
ck_tile::index_t ref_unit_cnt = 0;
reference_fused<InputType, WeightType, IndexType>(
x_host, topk, experts, unit_size,
ref_sorted_ids, ref_sorted_weights, ref_sorted_expert_ids, ref_unit_cnt);
int num_valid_padded = num_valid_host(0);
int num_valid_tokens = num_valid_host(1);
if(num_valid_padded != ref_unit_cnt)
{
printf(" FAIL:num_valid[0] got %d ref %d;", num_valid_padded, ref_unit_cnt);
pass = false;
}
if(num_valid_tokens != 1)
{
printf(" FAIL:num_valid[1] got %d;", num_valid_tokens);
pass = false;
}
int n_tiles = num_valid_padded / unit_size;
for(int i = 1; i < n_tiles; i++)
{
if(sorted_expert_ids_host(i) < sorted_expert_ids_host(i - 1))
{
printf(" FAIL:expert_ids not ascending;");
pass = false;
break;
}
}
WeightType wsum = 0;
for(int i = 0; i < topk; i++)
wsum += sorted_weights_host(i * unit_size);
if(std::abs(wsum - 1.0f) > 1e-3f)
{
printf(" FAIL:wsum=%.6f;", static_cast<float>(wsum));
pass = false;
}
bool buf_zeroed = std::all_of(
moe_buf_host.begin(), moe_buf_host.end(), [](float v) { return v == 0.0f; });
if(!buf_zeroed)
{
printf(" FAIL:moe_buf not zeroed;");
pass = false;
}
}
printf(" valid:%s\n", pass ? "y" : "n");
fflush(stdout);
return pass;
}
int main(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string pr = args.get_str("pr_i");
bool r = true;
if(pr == "fp16")
r &= run_test<ck_tile::fp16_t, float>(args);
else if(pr == "bf16")
r &= run_test<ck_tile::bf16_t, float>(args);
else
{
printf("unsupported pr_i: %s\n", pr.c_str());
return -1;
}
return r ? 0 : -1;
}

View File

@@ -0,0 +1,99 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "topk_softmax_decode_api.hpp"
#define TOPK_SOFTMAX_DECODE_DISPATCH(experts_, use_softmax_) \
constexpr ck_tile::index_t ts_experts = experts_; \
constexpr bool ts_use_softmax = use_softmax_; \
using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem<ts_input_type, \
ts_weight_type, \
ts_index_type, \
ts_experts, \
ts_use_softmax>; \
using ts_pipeline = ck_tile::TopkSoftmaxDecodePipeline<ts_problem>; \
\
using kernel = ck_tile::TopkSoftmaxDecodeKernel<ts_pipeline>; \
\
auto kargs = kernel::MakeKargs(a); \
\
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(); \
\
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
\
return ave_time;
#define TOPK_SOFTMAX_DECODE_EXPERT_LADDER(use_softmax_) \
if(t.experts <= 8) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(8, use_softmax_) \
} \
else if(t.experts <= 16) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(16, use_softmax_) \
} \
else if(t.experts <= 32) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(32, use_softmax_) \
} \
else if(t.experts <= 64) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(64, use_softmax_) \
} \
else if(t.experts <= 128) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(128, use_softmax_) \
} \
else if(t.experts <= 192) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(192, use_softmax_) \
} \
else if(t.experts <= 256) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(256, use_softmax_) \
} \
else if(t.experts <= 512) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(512, use_softmax_) \
} \
else if(t.experts <= 1024) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(1024, use_softmax_)\
}
float topk_softmax_decode(topk_softmax_decode_trait t,
topk_softmax_decode_kargs a,
ck_tile::stream_config s)
{
if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax")
{
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(true)
}
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax")
{
using ts_input_type = ck_tile::bf16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(true)
}
else if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid")
{
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(false)
}
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid")
{
using ts_input_type = ck_tile::bf16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(false)
}
return -1;
}

View File

@@ -0,0 +1,24 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/topk_softmax.hpp"
#include <string>
struct topk_softmax_decode_trait
{
std::string input_type;
std::string weight_type; // currently always float
int experts;
std::string activation; // "softmax" or "sigmoid"
};
struct topk_softmax_decode_kargs : public ck_tile::TopkSoftmaxDecodeHostArgs
{
};
float topk_softmax_decode(topk_softmax_decode_trait t,
topk_softmax_decode_kargs a,
ck_tile::stream_config s);

View File

@@ -3,9 +3,11 @@
#pragma once
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_decode_kernel.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_decode_pipeline.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"

View File

@@ -0,0 +1,129 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
struct TopkSoftmaxDecodeHostArgs
{
// Input (gating logits)
const void* p_input;
index_t num_experts;
index_t topk;
index_t stride_input;
bool renormalize;
// Output (moe_sorting format)
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
void* p_total_tokens_post_pad;
// moe_buf zeroing
void* p_moe_buf;
index_t unit_size;
index_t moe_buf_interm_dim;
index_t moe_buf_elem_bytes;
};
template <typename Pipeline_>
struct TopkSoftmaxDecodeKernel
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
using InputType = typename Problem::InputType;
using WeightType = typename Problem::WeightType;
using IndexType = typename Problem::IndexType;
static constexpr index_t kBlockSize = Problem::BlockSize;
struct Kargs
{
const void* p_input;
index_t num_experts;
index_t topk;
index_t stride_input;
bool renormalize;
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
void* p_total_tokens_post_pad;
void* p_moe_buf;
index_t unit_size;
index_t moe_buf_interm_dim;
index_t moe_buf_elem_bytes;
};
using Hargs = TopkSoftmaxDecodeHostArgs;
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return dim3(1); }
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_input = h.p_input;
k.num_experts = h.num_experts;
k.topk = h.topk;
k.stride_input = h.stride_input;
k.renormalize = h.renormalize;
k.p_sorted_token_ids = h.p_sorted_token_ids;
k.p_sorted_weights = h.p_sorted_weights;
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
k.p_moe_buf = h.p_moe_buf;
k.unit_size = h.unit_size;
k.moe_buf_interm_dim = h.moe_buf_interm_dim;
k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
return k;
}
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
constexpr index_t num_rows = 1;
const auto input_window = [&]() {
const InputType* p_input = reinterpret_cast<const InputType*>(kargs.p_input);
auto tmp = make_naive_tensor_view<address_space_enum::global>(
p_input,
make_tuple(num_rows, kargs.num_experts),
make_tuple(kargs.stride_input, 1),
number<Problem::VectorSize>{},
number<1>{});
auto view = pad_tensor_view(
tmp,
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
sequence<0, 1>{});
return make_tile_window(
view,
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
{0, 0});
}();
Pipeline{}(input_window,
kargs.num_experts,
kargs.topk,
kargs.renormalize,
reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids),
reinterpret_cast<WeightType*>(kargs.p_sorted_weights),
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids),
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad),
kargs.p_moe_buf,
kargs.unit_size,
kargs.moe_buf_interm_dim,
kargs.moe_buf_elem_bytes);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,244 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
// M=1 decode variant of TopkSoftmaxWarpPerRowPipeline.
// Inlines the topk loop (same algorithm as BlockTopkStream2D) but writes
// results to shared memory instead of global memory tile windows. Then
// thread 0 directly emits moe_sorting-compatible sorted outputs.
template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
struct TopkSoftmaxDecodePipeline
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using WeightType = typename Problem::WeightType;
using IndexType = typename Problem::IndexType;
static constexpr index_t kMaxTopk = 8;
// Same struct as BlockTopkStream2D::ArgmaxPacket
struct ArgmaxPacket
{
WeightType arg;
index_t value;
};
template <typename InputWindow>
CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
index_t experts,
index_t k,
bool renormalize,
IndexType* __restrict__ p_sorted_token_ids,
WeightType* __restrict__ p_sorted_weights,
IndexType* __restrict__ p_sorted_expert_ids,
IndexType* __restrict__ p_total_tokens_post_pad,
void* __restrict__ p_moe_buf,
index_t unit_size,
index_t moe_buf_interm_dim,
index_t moe_buf_elem_bytes)
{
auto inp_win = make_tile_window_linear(
input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
auto softmax = Policy::template GetSoftmax<Problem>();
// --- Phase 1: Load input and compute softmax/sigmoid ---
auto x = load_tile(inp_win);
auto w = [&]() {
auto w_ = make_static_distributed_tensor<WeightType>(x.get_tile_distribution());
auto w_f = [&](auto idx) {
w_(idx) = type_convert<WeightType>(x(idx));
const auto x_indices =
get_x_indices_from_distributed_indices(w_.get_tile_distribution(), idx);
const auto current_expert = x_indices.at(number<1>{});
w_(idx) =
current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
if constexpr(!Problem::ActivationIsSoftmax)
{
w_(idx) = WeightType(1) / (WeightType(1) + exp(-w_(idx)));
}
};
tile_sweeper<decltype(w_), decltype(w_f)> ts{w_, w_f};
ts();
return w_;
}();
auto y = [&]() {
if constexpr(Problem::ActivationIsSoftmax)
return softmax(w);
else
return w;
}();
// --- Phase 2: Inline topk loop (same as BlockTopkStream2D but → shared mem) ---
__shared__ IndexType s_expert_ids[kMaxTopk];
__shared__ WeightType s_weights[kMaxTopk];
__shared__ IndexType s_original_slots[kMaxTopk];
const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
return e0.arg > e1.arg ? e0 : e1;
};
// Exactly mirrors BlockTopkStream2D::operator() lines 45-100
decltype(y) y_tmp = y;
constexpr auto span_2d = decltype(y_tmp)::get_distributed_spans();
for(index_t i_k = 0; i_k < k; i_k++)
{
// Build ArgmaxPacket distributed tensor (same as BlockTopkStream2D lines 56-71)
auto packet = [&]() {
auto tmp =
make_static_distributed_tensor<ArgmaxPacket>(y.get_tile_distribution());
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
tmp.get_tile_distribution(), make_tuple(idx0, idx1));
constexpr auto i_j_idx = make_tuple(idx0, idx1);
ArgmaxPacket t;
t.arg = y_tmp(i_j_idx);
t.value = tile_idx.at(number<1>{});
tmp(i_j_idx) = t;
});
});
return tmp;
}();
// Reduce to find argmax (same as BlockTopkStream2D lines 73-75)
auto argmax_init =
ArgmaxPacket{-numeric<WeightType>::infinity(), 0};
auto r =
block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
block_tile_reduce_xor_sync(r, f_argmax);
// Extract result and store to shared memory instead of tile windows.
// After xor_sync, all threads have the same argmax. We use the same
// r(i_j_idx) access pattern as BlockTopkStream2D line 82.
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
ArgmaxPacket winner = r(i_j_idx);
if(threadIdx.x == 0)
{
s_expert_ids[i_k] = static_cast<IndexType>(winner.value);
s_weights[i_k] = winner.arg;
s_original_slots[i_k] = static_cast<IndexType>(i_k);
}
});
});
// Mask out selected expert (same as BlockTopkStream2D lines 89-100)
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
y.get_tile_distribution(), make_tuple(idx0, idx1));
auto col_id = tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
y_tmp(i_j_idx) = (col_id == r(i_j_idx).value)
? -numeric<WeightType>::infinity()
: y_tmp(i_j_idx);
});
});
__syncthreads();
}
// --- Phase 3: Produce sorted outputs (thread 0 only, trivial for M=1) ---
if(threadIdx.x == 0)
{
if(renormalize)
{
WeightType sum = WeightType(0);
for(index_t i = 0; i < k; i++)
sum += s_weights[i];
if(sum != WeightType(0))
{
WeightType inv_sum = WeightType(1) / sum;
for(index_t i = 0; i < k; i++)
s_weights[i] *= inv_sum;
}
}
// Sort by expert_id (ascending). k <= 8, insertion sort.
for(index_t i = 1; i < k; i++)
{
IndexType key_eid = s_expert_ids[i];
WeightType key_w = s_weights[i];
IndexType key_slot = s_original_slots[i];
index_t j = i - 1;
while(j >= 0 && s_expert_ids[j] > key_eid)
{
s_expert_ids[j + 1] = s_expert_ids[j];
s_weights[j + 1] = s_weights[j];
s_original_slots[j + 1] = s_original_slots[j];
j--;
}
s_expert_ids[j + 1] = key_eid;
s_weights[j + 1] = key_w;
s_original_slots[j + 1] = key_slot;
}
constexpr index_t num_tokens = 1;
index_t write_offset = 0;
index_t expert_tile_idx = 0;
IndexType sentinel =
static_cast<uint32_t>((num_tokens & 0x00ffffff) | ((k & 0xff) << 24));
for(index_t i = 0; i < k; i++)
{
IndexType expert_id = s_expert_ids[i];
WeightType weight = s_weights[i];
IndexType topk_slot = s_original_slots[i];
IndexType packed_id =
static_cast<uint32_t>((0 & 0x00ffffff) | ((topk_slot & 0xff) << 24));
p_sorted_token_ids[write_offset] = packed_id;
p_sorted_weights[write_offset] = weight;
for(index_t p = 1; p < unit_size; p++)
{
p_sorted_token_ids[write_offset + p] = sentinel;
p_sorted_weights[write_offset + p] = WeightType(0);
}
p_sorted_expert_ids[expert_tile_idx] = expert_id;
write_offset += unit_size;
expert_tile_idx++;
}
p_total_tokens_post_pad[0] = static_cast<IndexType>(k * unit_size);
p_total_tokens_post_pad[1] = static_cast<IndexType>(num_tokens);
}
// --- Phase 4: Zero moe_buf cooperatively ---
if(p_moe_buf != nullptr)
{
const index_t total_bytes = moe_buf_interm_dim * moe_buf_elem_bytes;
const index_t total_elems = total_bytes / 16;
using vector_type = ext_vector_t<index_t, 4>;
vector_type* p_buf = reinterpret_cast<vector_type*>(p_moe_buf);
auto zero_ = vector_type{0};
for(index_t i = threadIdx.x; i < total_elems; i += blockDim.x)
{
p_buf[i] = zero_;
}
}
}
};
} // namespace ck_tile