[CK_TILE] Top-K with Sigmoid kernel (#3062)

* Add sigmoid option to topk_softmax

* fix formatting

* add to changelog

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Use else if

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
Sami Remes
2025-10-28 17:54:06 +00:00
committed by GitHub
parent 6f58d6e457
commit 515e283091
10 changed files with 319 additions and 83 deletions

View File

@@ -83,6 +83,26 @@ auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
reference_topk(y, y_values, y_indices, k, dim, largest, sorted);
}
/// @brief Host reference for top-k followed by element-wise sigmoid.
///
/// Selects the k largest (or smallest) elements of @p x along @p dim into
/// @p y_values / @p y_indices, then maps each selected value through
/// sigmoid(v) = 1 / (1 + e^{-v}).
///
/// @tparam InputType  element type of the input tensor
/// @tparam WeightType element type of the output values tensor
/// @tparam IndexType  element type of the output indices tensor
/// @param x         input tensor (host)
/// @param y_values  output: top-k values after sigmoid (written in place)
/// @param y_indices output: indices of the selected elements
/// @param k         number of elements to select per slice
/// @param dim       dimension to reduce over (-1 = last)
/// @param largest   select largest (true) or smallest (false) elements
/// @param sorted    return the selection in sorted order
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
auto reference_topk_sigmoid(const ck_tile::HostTensor<InputType>& x,
                            ck_tile::HostTensor<WeightType>& y_values,
                            ck_tile::HostTensor<IndexType>& y_indices,
                            ck_tile::index_t k,
                            ck_tile::index_t dim = -1,
                            bool largest = true,
                            bool sorted  = true)
{
    using namespace ck_tile;
    // Sigmoid is strictly monotonic, so the top-k selection is order-preserving:
    // run top-k on the raw inputs first, no need to apply the sigmoid to all of x.
    auto x_fp32 = x.template CopyAsType<float>();
    reference_topk(x_fp32, y_values, y_indices, k, dim, largest, sorted);
    // Apply sigmoid only to the k selected values. Use std::exp explicitly:
    // an unqualified exp() relies on the C global exp(double) being visible,
    // which is implementation-defined when only <cmath> is included.
    std::transform(y_values.begin(), y_values.end(), y_values.begin(), [](auto value) {
        return WeightType(1) / (WeightType(1) + std::exp(-value));
    });
}
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
@@ -133,7 +153,8 @@ auto create_args(int argc, char* argv[])
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "topk_softmax.json", "json file name to dump results");
.insert("jsonfile", "topk_softmax.json", "json file name to dump results")
.insert("activation", "softmax", "activation function to use: softmax or sigmoid");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -154,6 +175,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
int kname = args.get_int("kname");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
std::string activation = args.get_str("activation");
if(stride_input < 0)
{
@@ -204,7 +226,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
x_dev.ToDevice(x_host.data());
topk_softmax_trait trait{input_prec, weight_prec, experts};
topk_softmax_trait trait{input_prec, weight_prec, experts, activation};
topk_softmax_kargs karg{x_dev.GetDeviceBuffer(),
value_dev.GetDeviceBuffer(),
@@ -221,7 +243,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
warmup,
repeat};
auto ms = topk_softmax(trait, karg, sc);
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ",
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, activation:%s, ms:%f, ",
input_prec.c_str(),
weight_prec.c_str(),
tokens,
@@ -229,6 +251,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
topk,
stride_input,
stride_output,
activation.c_str(),
ms);
if(ms < 0)
printf("not supported\n");
@@ -247,8 +270,20 @@ bool test_topk_softmax(ck_tile::ArgParser args)
ck_tile::HostTensor<WeightType> value_ref({tokens, topk}, {stride_output, 1});
ck_tile::HostTensor<IndexType> index_ref({tokens, topk}, {stride_output, 1});
reference_topk_softmax<InputType, WeightType, IndexType>(
x_host, value_ref, index_ref, topk);
if(activation == "softmax")
{
reference_topk_softmax<InputType, WeightType, IndexType>(
x_host, value_ref, index_ref, topk);
}
else if(activation == "sigmoid")
{
reference_topk_sigmoid<InputType, WeightType, IndexType>(
x_host, value_ref, index_ref, topk);
}
else
{
throw std::runtime_error("unsupported activation type: " + activation);
}
auto [rtol, atol] = get_elimit<InputType>("");
for(int i_t = 0; i_t < tokens; i_t++)

View File

@@ -3,27 +3,31 @@
#include "topk_softmax_api.hpp"
#define TOPK_SOFTMAX_DISPATCH(experts_) \
constexpr ck_tile::index_t ts_experts = experts_; \
using ts_problem = ck_tile:: \
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>; \
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
\
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
\
auto kargs = kernel::MakeKargs(a); \
\
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(); \
\
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
\
#define TOPK_SOFTMAX_DISPATCH(experts_, use_softmax_) \
constexpr ck_tile::index_t ts_experts = experts_; \
constexpr bool ts_use_softmax = use_softmax_; \
using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem<ts_input_type, \
ts_weight_type, \
ts_index_type, \
ts_experts, \
ts_use_softmax>; \
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
\
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
\
auto kargs = kernel::MakeKargs(a); \
\
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(); \
\
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
\
return ave_time;
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
{
if(t.input_type == "fp16" && t.weight_type == "fp32")
if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax")
{
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
@@ -31,36 +35,36 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
#if 1
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8)
TOPK_SOFTMAX_DISPATCH(8, true)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16)
TOPK_SOFTMAX_DISPATCH(16, true)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32)
TOPK_SOFTMAX_DISPATCH(32, true)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64)
TOPK_SOFTMAX_DISPATCH(64, true)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
TOPK_SOFTMAX_DISPATCH(128, true)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192)
TOPK_SOFTMAX_DISPATCH(192, true)
}
#else
if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
TOPK_SOFTMAX_DISPATCH(128, true)
}
#endif
}
else if(t.input_type == "bf16" && t.weight_type == "fp32")
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax")
{
#if 1
using ts_input_type = ck_tile::bf16_t;
@@ -68,27 +72,96 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
using ts_index_type = ck_tile::index_t;
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8)
TOPK_SOFTMAX_DISPATCH(8, true)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16)
TOPK_SOFTMAX_DISPATCH(16, true)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32)
TOPK_SOFTMAX_DISPATCH(32, true)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64)
TOPK_SOFTMAX_DISPATCH(64, true)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
TOPK_SOFTMAX_DISPATCH(128, true)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192)
TOPK_SOFTMAX_DISPATCH(192, true)
}
#endif
}
else if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid")
{
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
#if 1
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8, false)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16, false)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32, false)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64, false)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128, false)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192, false)
}
#else
if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128, false)
}
#endif
}
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid")
{
#if 1
using ts_input_type = ck_tile::bf16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8, false)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16, false)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32, false)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64, false)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128, false)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192, false)
}
#endif
}

View File

@@ -12,6 +12,7 @@ struct topk_softmax_trait
std::string input_type;
std::string weight_type; // currently always float
int experts;
std::string activation; // "softmax" or "sigmoid"
};
struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs