[ck_tile] refactor reduce kernel (#3257)

* refactor reduce kernel

- Rename Reduce kernel as per convention

- Move kept_dim and reduce_dims from runtime to compile-time parameters

- Update Reduce2dProblem template to include KeptDim, ReduceDims, and
Rank

- Remove the IsSupportedArgument validation function, as it is no longer
necessary: the tensor view and descriptor are no longer created with
GuaranteedLastDimensionVectorStride, which removes the bounds enforced
earlier. The vector size is still calculated and used.

- Update reduce example to demonstrate NCHW->NHW reduction with
non-contiguous support

- Update tests

The kernel now handles both contiguous and non-contiguous memory layouts.
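
To illustrate what moving kept_dim and reduce_dims to compile-time parameters can look like, here is a minimal host-side sketch. It is not the ck_tile implementation: `Seq`, `ToyReduceProblem`, and `toy_reduce_sum` are hypothetical names introduced only for this example, and the reduction is a plain CPU sum. It assumes the input is described by per-dimension lengths and strides, so the same code path serves contiguous and non-contiguous inputs.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a compile-time integer sequence (not the ck_tile type).
template <int... Is>
struct Seq
{
    static constexpr std::size_t size() { return sizeof...(Is); }
    static constexpr std::array<int, sizeof...(Is)> values() { return {Is...}; }
};

// Hypothetical problem descriptor: kept and reduced dimensions are types,
// so they are fixed at compile time rather than passed at runtime.
template <typename KeptDim_, typename ReduceDims_, int Rank_>
struct ToyReduceProblem
{
    using KeptDim    = KeptDim_;
    using ReduceDims = ReduceDims_;
    static constexpr int Rank         = Rank_;
    static constexpr int NumReduceDim = static_cast<int>(ReduceDims::size());
    static_assert(KeptDim::size() + ReduceDims::size() == Rank_,
                  "kept and reduced dimensions must cover every dimension");
};

// Sums x over the reduced dimensions. Input offsets come from the strides,
// so x may be non-contiguous; the output is packed row-major over the kept
// dimensions in the order they are listed.
template <typename Problem>
std::vector<float> toy_reduce_sum(const std::vector<float>& x,
                                  const std::array<int, Problem::Rank>& lengths,
                                  const std::array<int, Problem::Rank>& strides)
{
    constexpr auto kept = Problem::KeptDim::values();

    std::size_t out_size = 1;
    for(int d : kept)
        out_size *= static_cast<std::size_t>(lengths[d]);
    std::vector<float> y(out_size, 0.0f);

    std::size_t total = 1;
    for(int l : lengths)
        total *= static_cast<std::size_t>(l);

    std::array<int, Problem::Rank> idx{}; // multi-dimensional index, starts at zero
    for(std::size_t n = 0; n < total; ++n)
    {
        std::size_t in_off = 0;
        for(int d = 0; d < Problem::Rank; ++d)
            in_off += static_cast<std::size_t>(idx[d]) * static_cast<std::size_t>(strides[d]);
        assert(in_off < x.size());

        std::size_t out_off = 0;
        for(int d : kept)
            out_off = out_off * static_cast<std::size_t>(lengths[d]) +
                      static_cast<std::size_t>(idx[d]);

        y[out_off] += x[in_off];

        // Advance the index like an odometer, last dimension fastest.
        for(int d = Problem::Rank - 1; d >= 0; --d)
        {
            if(++idx[d] < lengths[d])
                break;
            idx[d] = 0;
        }
    }
    return y;
}
```

With `ToyReduceProblem<Seq<0, 2, 3>, Seq<1>, 4>`, lengths {N, C, H, W} and strides {C*H*W, H*W, W, 1} describe a contiguous NCHW input reduced to NHW. Passing strides {H*W*C, 1, W*C, C} instead reads the same logical tensor from NHWC storage, which is one way to exercise a non-contiguous case.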

* fix compile errors
Yashvardhan Agarwal
2025-12-17 21:46:08 +02:00
committed by GitHub
parent 92653168c2
commit ea10a78203
5 changed files with 89 additions and 130 deletions

@@ -12,6 +12,9 @@ template <typename XDataType_,
typename YDataType_,
typename BlockShape_,
typename ReduceOp_,
typename KeptDim_,
typename ReduceDims_,
index_t Rank_,
bool OutputIndex_ = false>
struct Reduce2dProblem
{
@@ -20,7 +23,11 @@ struct Reduce2dProblem
using YDataType = remove_cvref_t<YDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
using ReduceOp = ReduceOp_;
using KeptDim = remove_cvref_t<KeptDim_>;
using ReduceDims = remove_cvref_t<ReduceDims_>;
static constexpr index_t Rank = Rank_;
static constexpr index_t NumReduceDim = ReduceDims::size();
static constexpr bool kOutputIndex = OutputIndex_;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;