mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 05:37:34 +00:00
Add fused topk_softmax_decode kernel for M=1 MoE decode
New CK tile kernel variant that fuses topk_softmax and moe_sorting into a single kernel launch for the decode case (M=1, single token). The pipeline inlines the topk loop with results in shared memory (no global scratch), then thread 0 emits moe_sorting-compatible packed output. Includes CMake target tile_example_topk_softmax_decode with built-in comparison benchmark against the separate topk+sorting baseline. Validated on gfx950, E=8..1024, k=1..8, bf16/fp16. Made-with: Cursor
This commit is contained in:
@@ -1,11 +1,24 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp)
|
||||
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
|
||||
|
||||
set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
|
||||
add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp)
|
||||
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
|
||||
target_compile_options(tile_example_topk_softmax PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_topk_softmax_decode
|
||||
topk_softmax_decode.cpp
|
||||
topk_softmax_decode_api.cpp
|
||||
topk_softmax_api.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../13_moe_sorting/moe_sorting_api.cpp)
|
||||
target_include_directories(tile_example_topk_softmax_decode PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../13_moe_sorting/)
|
||||
target_compile_definitions(tile_example_topk_softmax_decode PRIVATE
|
||||
CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID=1
|
||||
MOE_SORTING_FMOE_2D_BUF=1)
|
||||
target_compile_options(tile_example_topk_softmax_decode PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})
|
||||
|
||||
314
example/ck_tile/09_topk_softmax/topk_softmax_decode.cpp
Normal file
314
example/ck_tile/09_topk_softmax/topk_softmax_decode.cpp
Normal file
@@ -0,0 +1,314 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include <time.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "topk_softmax_decode_api.hpp"
|
||||
#include "topk_softmax_api.hpp"
|
||||
#include "moe_sorting_api.hpp"
|
||||
|
||||
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
|
||||
#endif
|
||||
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
|
||||
|
||||
// CPU reference: softmax -> topk -> moe_sorting
|
||||
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
bool reference_fused(const ck_tile::HostTensor<InputType>& x_host,
|
||||
ck_tile::index_t topk,
|
||||
ck_tile::index_t num_experts,
|
||||
ck_tile::index_t unit_size,
|
||||
ck_tile::HostTensor<IndexType>& ref_sorted_ids,
|
||||
ck_tile::HostTensor<WeightType>& ref_sorted_weights,
|
||||
ck_tile::HostTensor<IndexType>& ref_sorted_expert_ids,
|
||||
ck_tile::index_t& ref_unit_cnt)
|
||||
{
|
||||
auto probs = ck_tile::reference_softmax<InputType, WeightType, WeightType>(x_host);
|
||||
|
||||
ck_tile::HostTensor<WeightType> topk_vals({1, topk});
|
||||
ck_tile::HostTensor<IndexType> topk_idxs({1, topk});
|
||||
ck_tile::reference_topk(probs, topk_vals, topk_idxs, topk);
|
||||
|
||||
ck_tile::HostTensor<IndexType> local_expert_mask({num_experts});
|
||||
ref_unit_cnt = 0;
|
||||
ck_tile::reference_moe_sorting<WeightType, IndexType>(
|
||||
topk_idxs,
|
||||
topk_vals,
|
||||
local_expert_mask,
|
||||
ref_sorted_ids,
|
||||
ref_sorted_weights,
|
||||
ref_sorted_expert_ids,
|
||||
ref_unit_cnt,
|
||||
num_experts,
|
||||
unit_size,
|
||||
1,
|
||||
false,
|
||||
true);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "do CPU validation")
|
||||
.insert("pr_i", "bf16", "input data type: fp16/bf16")
|
||||
.insert("e", "128", "number of experts")
|
||||
.insert("k", "8", "topk")
|
||||
.insert("unit", "32", "unit_size (block_size_M)")
|
||||
.insert("model_dim", "7168", "model dimension for moe_buf zeroing")
|
||||
.insert("seed", "-1", "random seed, -1 = random")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
bool run_test(ck_tile::ArgParser args)
|
||||
{
|
||||
int validate = args.get_int("v");
|
||||
std::string pr = args.get_str("pr_i");
|
||||
int experts = args.get_int("e");
|
||||
int topk = args.get_int("k");
|
||||
int unit_size = args.get_int("unit");
|
||||
int model_dim = args.get_int("model_dim");
|
||||
int seed = args.get_int("seed");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
|
||||
if(seed < 0)
|
||||
seed = std::time(nullptr);
|
||||
|
||||
if(topk > experts)
|
||||
{
|
||||
printf("topk %d > experts %d, skip\n", topk, experts);
|
||||
return false;
|
||||
}
|
||||
|
||||
int tokens = 1;
|
||||
int max_num_tokens_padded = topk + experts * unit_size - topk;
|
||||
int max_num_m_blocks = (max_num_tokens_padded + unit_size - 1) / unit_size;
|
||||
|
||||
// Host tensors
|
||||
ck_tile::HostTensor<InputType> x_host({1, experts});
|
||||
{
|
||||
auto rng = ck_tile::FillUniformDistribution_Unique<InputType>{
|
||||
-5.f, 5.f, static_cast<uint32_t>(seed)};
|
||||
ck_tile::HostTensor<InputType> row({experts});
|
||||
rng(row);
|
||||
std::copy(row.begin(), row.end(), x_host.begin());
|
||||
}
|
||||
|
||||
// ---------- Device buffers (shared) ----------
|
||||
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
|
||||
x_dev.ToDevice(x_host.data());
|
||||
|
||||
// ---------- Fused kernel buffers ----------
|
||||
ck_tile::DeviceMem fused_sorted_ids(max_num_tokens_padded * sizeof(IndexType));
|
||||
ck_tile::DeviceMem fused_sorted_weights(max_num_tokens_padded * sizeof(WeightType));
|
||||
ck_tile::DeviceMem fused_sorted_expert_ids(max_num_m_blocks * sizeof(IndexType));
|
||||
ck_tile::DeviceMem fused_num_valid(2 * sizeof(IndexType));
|
||||
ck_tile::DeviceMem fused_moe_buf(model_dim * sizeof(WeightType));
|
||||
{
|
||||
std::vector<float> ones(model_dim, 1.0f);
|
||||
fused_moe_buf.ToDevice(ones.data());
|
||||
}
|
||||
|
||||
// ---------- Two-kernel baseline buffers ----------
|
||||
ck_tile::DeviceMem topk_w_dev(topk * sizeof(WeightType));
|
||||
ck_tile::DeviceMem topk_i_dev(topk * sizeof(IndexType));
|
||||
ck_tile::DeviceMem base_sorted_ids(max_num_tokens_padded * sizeof(IndexType));
|
||||
ck_tile::DeviceMem base_sorted_weights(max_num_tokens_padded * sizeof(WeightType));
|
||||
ck_tile::DeviceMem base_sorted_expert_ids(max_num_m_blocks * sizeof(IndexType));
|
||||
ck_tile::DeviceMem base_num_valid(2 * sizeof(IndexType));
|
||||
ck_tile::DeviceMem base_moe_buf(model_dim * sizeof(WeightType));
|
||||
int ws_size = moe_sorting_get_workspace_size(tokens, experts, topk, 0);
|
||||
ck_tile::DeviceMem base_ws(ws_size > 0 ? ws_size : 1);
|
||||
if(ws_size > 0)
|
||||
base_ws.SetZero();
|
||||
|
||||
ck_tile::stream_config sc{nullptr, true, 0, warmup, repeat};
|
||||
|
||||
// ====================== Fused kernel ======================
|
||||
topk_softmax_decode_trait fused_trait{pr, "fp32", experts, "softmax"};
|
||||
topk_softmax_decode_kargs fused_karg{
|
||||
x_dev.GetDeviceBuffer(),
|
||||
experts,
|
||||
topk,
|
||||
experts,
|
||||
true,
|
||||
fused_sorted_ids.GetDeviceBuffer(),
|
||||
fused_sorted_weights.GetDeviceBuffer(),
|
||||
fused_sorted_expert_ids.GetDeviceBuffer(),
|
||||
fused_num_valid.GetDeviceBuffer(),
|
||||
fused_moe_buf.GetDeviceBuffer(),
|
||||
unit_size,
|
||||
model_dim,
|
||||
static_cast<int>(sizeof(WeightType))};
|
||||
|
||||
float ms_fused = topk_softmax_decode(fused_trait, fused_karg, sc);
|
||||
|
||||
// ============= Two-kernel baseline: topk + moe_sorting =============
|
||||
topk_softmax_trait ts_trait{pr, "fp32", experts, "softmax"};
|
||||
topk_softmax_kargs ts_karg{
|
||||
x_dev.GetDeviceBuffer(),
|
||||
topk_w_dev.GetDeviceBuffer(),
|
||||
topk_i_dev.GetDeviceBuffer(),
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
experts,
|
||||
topk};
|
||||
|
||||
moe_sorting_trait ms_trait{"int32", "fp32", false, true, 0};
|
||||
moe_sorting_args ms_arg{
|
||||
topk_i_dev.GetDeviceBuffer(),
|
||||
topk_w_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
base_sorted_ids.GetDeviceBuffer(),
|
||||
base_sorted_weights.GetDeviceBuffer(),
|
||||
base_sorted_expert_ids.GetDeviceBuffer(),
|
||||
base_num_valid.GetDeviceBuffer(),
|
||||
base_moe_buf.GetDeviceBuffer(),
|
||||
ws_size > 0 ? base_ws.GetDeviceBuffer() : nullptr,
|
||||
tokens,
|
||||
unit_size,
|
||||
experts,
|
||||
topk,
|
||||
model_dim,
|
||||
static_cast<int>(sizeof(WeightType))};
|
||||
|
||||
// Time the two kernels together using launch_kernel with two lambdas
|
||||
auto sc_sub = ck_tile::stream_config{nullptr, false, 0, 0, 1};
|
||||
float ms_baseline = ck_tile::launch_kernel(
|
||||
sc,
|
||||
[&](const ck_tile::stream_config&) { topk_softmax(ts_trait, ts_karg, sc_sub); },
|
||||
[&](const ck_tile::stream_config&) { moe_sorting(ms_trait, ms_arg, sc_sub); });
|
||||
|
||||
float speedup = (ms_baseline > 0 && ms_fused > 0) ? ms_baseline / ms_fused : 0;
|
||||
printf("[%s] E:%d, k:%d, unit:%d | fused:%.4fms baseline(topk+sort):%.4fms speedup:%.2fx",
|
||||
pr.c_str(),
|
||||
experts,
|
||||
topk,
|
||||
unit_size,
|
||||
ms_fused,
|
||||
ms_baseline,
|
||||
speedup);
|
||||
|
||||
if(ms_fused < 0 || ms_baseline < 0)
|
||||
{
|
||||
printf(" (not supported)\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// ====================== Validation ======================
|
||||
bool pass = true;
|
||||
if(validate)
|
||||
{
|
||||
ck_tile::HostTensor<IndexType> sorted_ids_host({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<WeightType> sorted_weights_host({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_num_m_blocks});
|
||||
ck_tile::HostTensor<IndexType> num_valid_host({2});
|
||||
std::vector<float> moe_buf_host(model_dim);
|
||||
|
||||
fused_sorted_ids.FromDevice(sorted_ids_host.data());
|
||||
fused_sorted_weights.FromDevice(sorted_weights_host.data());
|
||||
fused_sorted_expert_ids.FromDevice(sorted_expert_ids_host.data());
|
||||
fused_num_valid.FromDevice(num_valid_host.data());
|
||||
fused_moe_buf.FromDevice(moe_buf_host.data());
|
||||
|
||||
ck_tile::HostTensor<IndexType> ref_sorted_ids({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<WeightType> ref_sorted_weights({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<IndexType> ref_sorted_expert_ids({max_num_m_blocks});
|
||||
IndexType sentinel = static_cast<uint32_t>((1 & 0x00ffffff) | ((topk & 0xff) << 24));
|
||||
std::fill(ref_sorted_ids.begin(), ref_sorted_ids.end(), sentinel);
|
||||
std::fill(ref_sorted_weights.begin(), ref_sorted_weights.end(), WeightType(0));
|
||||
std::fill(ref_sorted_expert_ids.begin(), ref_sorted_expert_ids.end(), -1);
|
||||
|
||||
ck_tile::index_t ref_unit_cnt = 0;
|
||||
reference_fused<InputType, WeightType, IndexType>(
|
||||
x_host, topk, experts, unit_size,
|
||||
ref_sorted_ids, ref_sorted_weights, ref_sorted_expert_ids, ref_unit_cnt);
|
||||
|
||||
int num_valid_padded = num_valid_host(0);
|
||||
int num_valid_tokens = num_valid_host(1);
|
||||
|
||||
if(num_valid_padded != ref_unit_cnt)
|
||||
{
|
||||
printf(" FAIL:num_valid[0] got %d ref %d;", num_valid_padded, ref_unit_cnt);
|
||||
pass = false;
|
||||
}
|
||||
if(num_valid_tokens != 1)
|
||||
{
|
||||
printf(" FAIL:num_valid[1] got %d;", num_valid_tokens);
|
||||
pass = false;
|
||||
}
|
||||
|
||||
int n_tiles = num_valid_padded / unit_size;
|
||||
for(int i = 1; i < n_tiles; i++)
|
||||
{
|
||||
if(sorted_expert_ids_host(i) < sorted_expert_ids_host(i - 1))
|
||||
{
|
||||
printf(" FAIL:expert_ids not ascending;");
|
||||
pass = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
WeightType wsum = 0;
|
||||
for(int i = 0; i < topk; i++)
|
||||
wsum += sorted_weights_host(i * unit_size);
|
||||
if(std::abs(wsum - 1.0f) > 1e-3f)
|
||||
{
|
||||
printf(" FAIL:wsum=%.6f;", static_cast<float>(wsum));
|
||||
pass = false;
|
||||
}
|
||||
|
||||
bool buf_zeroed = std::all_of(
|
||||
moe_buf_host.begin(), moe_buf_host.end(), [](float v) { return v == 0.0f; });
|
||||
if(!buf_zeroed)
|
||||
{
|
||||
printf(" FAIL:moe_buf not zeroed;");
|
||||
pass = false;
|
||||
}
|
||||
}
|
||||
|
||||
printf(" valid:%s\n", pass ? "y" : "n");
|
||||
fflush(stdout);
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string pr = args.get_str("pr_i");
|
||||
bool r = true;
|
||||
|
||||
if(pr == "fp16")
|
||||
r &= run_test<ck_tile::fp16_t, float>(args);
|
||||
else if(pr == "bf16")
|
||||
r &= run_test<ck_tile::bf16_t, float>(args);
|
||||
else
|
||||
{
|
||||
printf("unsupported pr_i: %s\n", pr.c_str());
|
||||
return -1;
|
||||
}
|
||||
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
99
example/ck_tile/09_topk_softmax/topk_softmax_decode_api.cpp
Normal file
99
example/ck_tile/09_topk_softmax/topk_softmax_decode_api.cpp
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "topk_softmax_decode_api.hpp"
|
||||
|
||||
#define TOPK_SOFTMAX_DECODE_DISPATCH(experts_, use_softmax_) \
|
||||
constexpr ck_tile::index_t ts_experts = experts_; \
|
||||
constexpr bool ts_use_softmax = use_softmax_; \
|
||||
using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem<ts_input_type, \
|
||||
ts_weight_type, \
|
||||
ts_index_type, \
|
||||
ts_experts, \
|
||||
ts_use_softmax>; \
|
||||
using ts_pipeline = ck_tile::TopkSoftmaxDecodePipeline<ts_problem>; \
|
||||
\
|
||||
using kernel = ck_tile::TopkSoftmaxDecodeKernel<ts_pipeline>; \
|
||||
\
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
\
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(); \
|
||||
\
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
|
||||
\
|
||||
return ave_time;
|
||||
|
||||
#define TOPK_SOFTMAX_DECODE_EXPERT_LADDER(use_softmax_) \
|
||||
if(t.experts <= 8) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(8, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 16) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(16, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 32) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(32, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 64) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(64, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 128) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(128, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 192) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(192, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 256) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(256, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 512) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(512, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 1024) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(1024, use_softmax_)\
|
||||
}
|
||||
|
||||
float topk_softmax_decode(topk_softmax_decode_trait t,
|
||||
topk_softmax_decode_kargs a,
|
||||
ck_tile::stream_config s)
|
||||
{
|
||||
if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax")
|
||||
{
|
||||
using ts_input_type = ck_tile::fp16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(true)
|
||||
}
|
||||
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax")
|
||||
{
|
||||
using ts_input_type = ck_tile::bf16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(true)
|
||||
}
|
||||
else if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid")
|
||||
{
|
||||
using ts_input_type = ck_tile::fp16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(false)
|
||||
}
|
||||
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid")
|
||||
{
|
||||
using ts_input_type = ck_tile::bf16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(false)
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
24
example/ck_tile/09_topk_softmax/topk_softmax_decode_api.hpp
Normal file
24
example/ck_tile/09_topk_softmax/topk_softmax_decode_api.hpp
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/topk_softmax.hpp"
|
||||
#include <string>
|
||||
|
||||
struct topk_softmax_decode_trait
|
||||
{
|
||||
std::string input_type;
|
||||
std::string weight_type; // currently always float
|
||||
int experts;
|
||||
std::string activation; // "softmax" or "sigmoid"
|
||||
};
|
||||
|
||||
struct topk_softmax_decode_kargs : public ck_tile::TopkSoftmaxDecodeHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
float topk_softmax_decode(topk_softmax_decode_trait t,
|
||||
topk_softmax_decode_kargs a,
|
||||
ck_tile::stream_config s);
|
||||
@@ -3,9 +3,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_decode_kernel.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_decode_pipeline.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct TopkSoftmaxDecodeHostArgs
|
||||
{
|
||||
// Input (gating logits)
|
||||
const void* p_input;
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t stride_input;
|
||||
bool renormalize;
|
||||
|
||||
// Output (moe_sorting format)
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_total_tokens_post_pad;
|
||||
|
||||
// moe_buf zeroing
|
||||
void* p_moe_buf;
|
||||
index_t unit_size;
|
||||
index_t moe_buf_interm_dim;
|
||||
index_t moe_buf_elem_bytes;
|
||||
};
|
||||
|
||||
template <typename Pipeline_>
|
||||
struct TopkSoftmaxDecodeKernel
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = remove_cvref_t<typename Pipeline::Problem>;
|
||||
|
||||
using InputType = typename Problem::InputType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using IndexType = typename Problem::IndexType;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockSize;
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_input;
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t stride_input;
|
||||
bool renormalize;
|
||||
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_total_tokens_post_pad;
|
||||
|
||||
void* p_moe_buf;
|
||||
index_t unit_size;
|
||||
index_t moe_buf_interm_dim;
|
||||
index_t moe_buf_elem_bytes;
|
||||
};
|
||||
|
||||
using Hargs = TopkSoftmaxDecodeHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return dim3(1); }
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_input = h.p_input;
|
||||
k.num_experts = h.num_experts;
|
||||
k.topk = h.topk;
|
||||
k.stride_input = h.stride_input;
|
||||
k.renormalize = h.renormalize;
|
||||
k.p_sorted_token_ids = h.p_sorted_token_ids;
|
||||
k.p_sorted_weights = h.p_sorted_weights;
|
||||
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
|
||||
k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
|
||||
k.p_moe_buf = h.p_moe_buf;
|
||||
k.unit_size = h.unit_size;
|
||||
k.moe_buf_interm_dim = h.moe_buf_interm_dim;
|
||||
k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
constexpr index_t num_rows = 1;
|
||||
|
||||
const auto input_window = [&]() {
|
||||
const InputType* p_input = reinterpret_cast<const InputType*>(kargs.p_input);
|
||||
auto tmp = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_input,
|
||||
make_tuple(num_rows, kargs.num_experts),
|
||||
make_tuple(kargs.stride_input, 1),
|
||||
number<Problem::VectorSize>{},
|
||||
number<1>{});
|
||||
auto view = pad_tensor_view(
|
||||
tmp,
|
||||
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
|
||||
sequence<0, 1>{});
|
||||
return make_tile_window(
|
||||
view,
|
||||
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
|
||||
{0, 0});
|
||||
}();
|
||||
|
||||
Pipeline{}(input_window,
|
||||
kargs.num_experts,
|
||||
kargs.topk,
|
||||
kargs.renormalize,
|
||||
reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids),
|
||||
reinterpret_cast<WeightType*>(kargs.p_sorted_weights),
|
||||
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids),
|
||||
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad),
|
||||
kargs.p_moe_buf,
|
||||
kargs.unit_size,
|
||||
kargs.moe_buf_interm_dim,
|
||||
kargs.moe_buf_elem_bytes);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,244 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// M=1 decode variant of TopkSoftmaxWarpPerRowPipeline.
|
||||
// Inlines the topk loop (same algorithm as BlockTopkStream2D) but writes
|
||||
// results to shared memory instead of global memory tile windows. Then
|
||||
// thread 0 directly emits moe_sorting-compatible sorted outputs.
|
||||
template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
|
||||
struct TopkSoftmaxDecodePipeline
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using IndexType = typename Problem::IndexType;
|
||||
|
||||
static constexpr index_t kMaxTopk = 8;
|
||||
|
||||
// Same struct as BlockTopkStream2D::ArgmaxPacket
|
||||
struct ArgmaxPacket
|
||||
{
|
||||
WeightType arg;
|
||||
index_t value;
|
||||
};
|
||||
|
||||
template <typename InputWindow>
|
||||
CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
|
||||
index_t experts,
|
||||
index_t k,
|
||||
bool renormalize,
|
||||
IndexType* __restrict__ p_sorted_token_ids,
|
||||
WeightType* __restrict__ p_sorted_weights,
|
||||
IndexType* __restrict__ p_sorted_expert_ids,
|
||||
IndexType* __restrict__ p_total_tokens_post_pad,
|
||||
void* __restrict__ p_moe_buf,
|
||||
index_t unit_size,
|
||||
index_t moe_buf_interm_dim,
|
||||
index_t moe_buf_elem_bytes)
|
||||
{
|
||||
auto inp_win = make_tile_window_linear(
|
||||
input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
|
||||
|
||||
auto softmax = Policy::template GetSoftmax<Problem>();
|
||||
|
||||
// --- Phase 1: Load input and compute softmax/sigmoid ---
|
||||
auto x = load_tile(inp_win);
|
||||
|
||||
auto w = [&]() {
|
||||
auto w_ = make_static_distributed_tensor<WeightType>(x.get_tile_distribution());
|
||||
auto w_f = [&](auto idx) {
|
||||
w_(idx) = type_convert<WeightType>(x(idx));
|
||||
const auto x_indices =
|
||||
get_x_indices_from_distributed_indices(w_.get_tile_distribution(), idx);
|
||||
const auto current_expert = x_indices.at(number<1>{});
|
||||
w_(idx) =
|
||||
current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
|
||||
if constexpr(!Problem::ActivationIsSoftmax)
|
||||
{
|
||||
w_(idx) = WeightType(1) / (WeightType(1) + exp(-w_(idx)));
|
||||
}
|
||||
};
|
||||
tile_sweeper<decltype(w_), decltype(w_f)> ts{w_, w_f};
|
||||
ts();
|
||||
return w_;
|
||||
}();
|
||||
|
||||
auto y = [&]() {
|
||||
if constexpr(Problem::ActivationIsSoftmax)
|
||||
return softmax(w);
|
||||
else
|
||||
return w;
|
||||
}();
|
||||
|
||||
// --- Phase 2: Inline topk loop (same as BlockTopkStream2D but → shared mem) ---
|
||||
__shared__ IndexType s_expert_ids[kMaxTopk];
|
||||
__shared__ WeightType s_weights[kMaxTopk];
|
||||
__shared__ IndexType s_original_slots[kMaxTopk];
|
||||
|
||||
const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
|
||||
return e0.arg > e1.arg ? e0 : e1;
|
||||
};
|
||||
|
||||
// Exactly mirrors BlockTopkStream2D::operator() lines 45-100
|
||||
decltype(y) y_tmp = y;
|
||||
constexpr auto span_2d = decltype(y_tmp)::get_distributed_spans();
|
||||
|
||||
for(index_t i_k = 0; i_k < k; i_k++)
|
||||
{
|
||||
// Build ArgmaxPacket distributed tensor (same as BlockTopkStream2D lines 56-71)
|
||||
auto packet = [&]() {
|
||||
auto tmp =
|
||||
make_static_distributed_tensor<ArgmaxPacket>(y.get_tile_distribution());
|
||||
|
||||
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
tmp.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
ArgmaxPacket t;
|
||||
t.arg = y_tmp(i_j_idx);
|
||||
t.value = tile_idx.at(number<1>{});
|
||||
tmp(i_j_idx) = t;
|
||||
});
|
||||
});
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
// Reduce to find argmax (same as BlockTopkStream2D lines 73-75)
|
||||
auto argmax_init =
|
||||
ArgmaxPacket{-numeric<WeightType>::infinity(), 0};
|
||||
auto r =
|
||||
block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
|
||||
block_tile_reduce_xor_sync(r, f_argmax);
|
||||
|
||||
// Extract result and store to shared memory instead of tile windows.
|
||||
// After xor_sync, all threads have the same argmax. We use the same
|
||||
// r(i_j_idx) access pattern as BlockTopkStream2D line 82.
|
||||
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
ArgmaxPacket winner = r(i_j_idx);
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
s_expert_ids[i_k] = static_cast<IndexType>(winner.value);
|
||||
s_weights[i_k] = winner.arg;
|
||||
s_original_slots[i_k] = static_cast<IndexType>(i_k);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Mask out selected expert (same as BlockTopkStream2D lines 89-100)
|
||||
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
y.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
auto col_id = tile_idx.at(number<1>{});
|
||||
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
y_tmp(i_j_idx) = (col_id == r(i_j_idx).value)
|
||||
? -numeric<WeightType>::infinity()
|
||||
: y_tmp(i_j_idx);
|
||||
});
|
||||
});
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// --- Phase 3: Produce sorted outputs (thread 0 only, trivial for M=1) ---
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
if(renormalize)
|
||||
{
|
||||
WeightType sum = WeightType(0);
|
||||
for(index_t i = 0; i < k; i++)
|
||||
sum += s_weights[i];
|
||||
if(sum != WeightType(0))
|
||||
{
|
||||
WeightType inv_sum = WeightType(1) / sum;
|
||||
for(index_t i = 0; i < k; i++)
|
||||
s_weights[i] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by expert_id (ascending). k <= 8, insertion sort.
|
||||
for(index_t i = 1; i < k; i++)
|
||||
{
|
||||
IndexType key_eid = s_expert_ids[i];
|
||||
WeightType key_w = s_weights[i];
|
||||
IndexType key_slot = s_original_slots[i];
|
||||
index_t j = i - 1;
|
||||
while(j >= 0 && s_expert_ids[j] > key_eid)
|
||||
{
|
||||
s_expert_ids[j + 1] = s_expert_ids[j];
|
||||
s_weights[j + 1] = s_weights[j];
|
||||
s_original_slots[j + 1] = s_original_slots[j];
|
||||
j--;
|
||||
}
|
||||
s_expert_ids[j + 1] = key_eid;
|
||||
s_weights[j + 1] = key_w;
|
||||
s_original_slots[j + 1] = key_slot;
|
||||
}
|
||||
|
||||
constexpr index_t num_tokens = 1;
|
||||
index_t write_offset = 0;
|
||||
index_t expert_tile_idx = 0;
|
||||
|
||||
IndexType sentinel =
|
||||
static_cast<uint32_t>((num_tokens & 0x00ffffff) | ((k & 0xff) << 24));
|
||||
|
||||
for(index_t i = 0; i < k; i++)
|
||||
{
|
||||
IndexType expert_id = s_expert_ids[i];
|
||||
WeightType weight = s_weights[i];
|
||||
IndexType topk_slot = s_original_slots[i];
|
||||
|
||||
IndexType packed_id =
|
||||
static_cast<uint32_t>((0 & 0x00ffffff) | ((topk_slot & 0xff) << 24));
|
||||
|
||||
p_sorted_token_ids[write_offset] = packed_id;
|
||||
p_sorted_weights[write_offset] = weight;
|
||||
|
||||
for(index_t p = 1; p < unit_size; p++)
|
||||
{
|
||||
p_sorted_token_ids[write_offset + p] = sentinel;
|
||||
p_sorted_weights[write_offset + p] = WeightType(0);
|
||||
}
|
||||
|
||||
p_sorted_expert_ids[expert_tile_idx] = expert_id;
|
||||
|
||||
write_offset += unit_size;
|
||||
expert_tile_idx++;
|
||||
}
|
||||
|
||||
p_total_tokens_post_pad[0] = static_cast<IndexType>(k * unit_size);
|
||||
p_total_tokens_post_pad[1] = static_cast<IndexType>(num_tokens);
|
||||
}
|
||||
|
||||
// --- Phase 4: Zero moe_buf cooperatively ---
|
||||
if(p_moe_buf != nullptr)
|
||||
{
|
||||
const index_t total_bytes = moe_buf_interm_dim * moe_buf_elem_bytes;
|
||||
const index_t total_elems = total_bytes / 16;
|
||||
|
||||
using vector_type = ext_vector_t<index_t, 4>;
|
||||
vector_type* p_buf = reinterpret_cast<vector_type*>(p_moe_buf);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(index_t i = threadIdx.x; i < total_elems; i += blockDim.x)
|
||||
{
|
||||
p_buf[i] = zero_;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user