[CK_TILE] Top-K with Sigmoid kernel (#3062)

* Add sigmoid option to topk_softmax

* fix formatting

* add to changelog

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Use else if

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
Sami Remes
2025-10-28 17:54:06 +00:00
committed by GitHub
parent 6f58d6e457
commit 515e283091
10 changed files with 319 additions and 83 deletions

View File

@@ -21,7 +21,7 @@ struct TopkSoftmaxHostArgs
index_t num_experts;
index_t topk;
index_t stride_input; // row stride for input, at least experts
index_t stride_output; // row stride for output/indices, at least tpok
index_t stride_output; // row stride for output/indices, at least topk
};
template <typename Pipeline_>
@@ -45,7 +45,7 @@ struct TopkSoftmaxKernel
index_t num_experts;
index_t topk;
index_t stride_input; // row stride for input, at least experts
index_t stride_output; // row stride for output/indices, at least tpok
index_t stride_output; // row stride for output/indices, at least topk
};
using Kargs = TopkSoftmaxKargs;

View File

@@ -90,6 +90,11 @@ struct TopkSoftmaxWarpPerRowPipeline
const auto current_expert = x_indices.at(number<1>{});
w_(idx) =
current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
if constexpr(!Problem::ActivationIsSoftmax)
{
// sigmoid can be pre-computed already here if not using softmax
w_(idx) = WeightType(1) / (WeightType(1) + exp(-w_(idx)));
}
};
tile_sweeper ts{w_, w_f};
ts();
@@ -97,10 +102,16 @@ struct TopkSoftmaxWarpPerRowPipeline
#endif
}();
// softmax
auto y = softmax(w);
topk(y, out_win, idx_win, k);
if constexpr(Problem::ActivationIsSoftmax)
{
auto y = softmax(w);
topk(y, out_win, idx_win, k);
}
else
{
// sigmoid was already pre-computed above, so only do topk now
topk(w, out_win, idx_win, k);
}
// check exit
if constexpr(Problem::LaunchType == 0)

View File

@@ -13,10 +13,11 @@ template <typename InputType_,
typename WeightType_,
typename IndexType_,
index_t Experts_,
index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
index_t BytesPerIssue_ = sizeof(InputType_),
index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy
index_t BlockSize_ = 256>
bool ActivationIsSoftmax_ = true, // false: sigmoid
index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
index_t BytesPerIssue_ = sizeof(InputType_),
index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy
index_t BlockSize_ = 256>
struct TopkSoftmaxWarpPerRowProblem
{
// TODO: this kernel only support warp per row
@@ -31,6 +32,8 @@ struct TopkSoftmaxWarpPerRowProblem
static constexpr index_t BlockSize = BlockSize_;
static constexpr index_t WarpSize = get_warp_size();
static constexpr bool ActivationIsSoftmax = ActivationIsSoftmax_;
static_assert(BytesPerIssue % sizeof(InputType) == 0);
static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);
static_assert(Experts % VectorSize == 0);