mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[ck_tile] refactor reduce kernel (#3257)
* refactor reduce kernel - Rename Reduce kernel as per convention - Move kept_dim and reduce_dims from runtime to compile-time parameters - Update Reduce2dProblem template to include KeptDim, ReduceDims, and Rank - Remove IsSupportedArgument validation function as it's unnecessary. Not using the GuaranteedLastDimensionVectorStride while making tensor view or descriptor which removes the bounds enforced earlier. We still calculate and use vector size. - Update reduce example to demonstrate NCHW->NHW reduction with non-contiguous support - Update tests Kernel now handles both contiguous and non-contiguous memory layout. * fix compile errors
This commit is contained in:
committed by
GitHub
parent
92653168c2
commit
ea10a78203
@@ -12,6 +12,9 @@ template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename BlockShape_,
|
||||
typename ReduceOp_,
|
||||
typename KeptDim_,
|
||||
typename ReduceDims_,
|
||||
index_t Rank_,
|
||||
bool OutputIndex_ = false>
|
||||
struct Reduce2dProblem
|
||||
{
|
||||
@@ -20,7 +23,11 @@ struct Reduce2dProblem
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using ReduceOp = ReduceOp_;
|
||||
using KeptDim = remove_cvref_t<KeptDim_>;
|
||||
using ReduceDims = remove_cvref_t<ReduceDims_>;
|
||||
|
||||
static constexpr index_t Rank = Rank_;
|
||||
static constexpr index_t NumReduceDim = ReduceDims::size();
|
||||
static constexpr bool kOutputIndex = OutputIndex_;
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
|
||||
Reference in New Issue
Block a user