mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Top-K with Sigmoid kernel (#3062)
* Add sigmoid option to topk_softmax * fix formatting * add to changelog * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Use else if Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -21,7 +21,7 @@ struct TopkSoftmaxHostArgs
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t stride_input; // row stride for input, at least experts
|
||||
index_t stride_output; // row stride for output/indices, at least tpok
|
||||
index_t stride_output; // row stride for output/indices, at least topk
|
||||
};
|
||||
|
||||
template <typename Pipeline_>
|
||||
@@ -45,7 +45,7 @@ struct TopkSoftmaxKernel
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t stride_input; // row stride for input, at least experts
|
||||
index_t stride_output; // row stride for output/indices, at least tpok
|
||||
index_t stride_output; // row stride for output/indices, at least topk
|
||||
};
|
||||
|
||||
using Kargs = TopkSoftmaxKargs;
|
||||
|
||||
@@ -90,6 +90,11 @@ struct TopkSoftmaxWarpPerRowPipeline
|
||||
const auto current_expert = x_indices.at(number<1>{});
|
||||
w_(idx) =
|
||||
current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
|
||||
if constexpr(!Problem::ActivationIsSoftmax)
|
||||
{
|
||||
// sigmoid can be pre-computed already here if not using softmax
|
||||
w_(idx) = WeightType(1) / (WeightType(1) + exp(-w_(idx)));
|
||||
}
|
||||
};
|
||||
tile_sweeper ts{w_, w_f};
|
||||
ts();
|
||||
@@ -97,10 +102,16 @@ struct TopkSoftmaxWarpPerRowPipeline
|
||||
#endif
|
||||
}();
|
||||
|
||||
// softmax
|
||||
auto y = softmax(w);
|
||||
|
||||
topk(y, out_win, idx_win, k);
|
||||
if constexpr(Problem::ActivationIsSoftmax)
|
||||
{
|
||||
auto y = softmax(w);
|
||||
topk(y, out_win, idx_win, k);
|
||||
}
|
||||
else
|
||||
{
|
||||
// sigmoid was already pre-computed above, so only do topk now
|
||||
topk(w, out_win, idx_win, k);
|
||||
}
|
||||
|
||||
// check exit
|
||||
if constexpr(Problem::LaunchType == 0)
|
||||
|
||||
@@ -13,10 +13,11 @@ template <typename InputType_,
|
||||
typename WeightType_,
|
||||
typename IndexType_,
|
||||
index_t Experts_,
|
||||
index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
|
||||
index_t BytesPerIssue_ = sizeof(InputType_),
|
||||
index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy
|
||||
index_t BlockSize_ = 256>
|
||||
bool ActivationIsSoftmax_ = true, // false: sigmoid
|
||||
index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
|
||||
index_t BytesPerIssue_ = sizeof(InputType_),
|
||||
index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy
|
||||
index_t BlockSize_ = 256>
|
||||
struct TopkSoftmaxWarpPerRowProblem
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
@@ -31,6 +32,8 @@ struct TopkSoftmaxWarpPerRowProblem
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
static constexpr index_t WarpSize = get_warp_size();
|
||||
|
||||
static constexpr bool ActivationIsSoftmax = ActivationIsSoftmax_;
|
||||
|
||||
static_assert(BytesPerIssue % sizeof(InputType) == 0);
|
||||
static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);
|
||||
static_assert(Experts % VectorSize == 0);
|
||||
|
||||
Reference in New Issue
Block a user