mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
* refactor reduce kernel - Rename Reduce kernel as per convention - Move kept_dim and reduce_dims from runtime to compile-time parameters - Update Reduce2dProblem template to include KeptDim, ReduceDims, and Rank - Remove IsSupportedArgument validation function as it's unnecessary. Not using the GuaranteedLastDimensionVectorStride while making tensor view or descriptor which removes the bounds enforced earlier. We still calculate and use vector size. - Update reduce example to demonstrate NCHW->NHW reduction with non-contiguous support - Update tests Kernel now handles both contiguous and non-contiguous memory layout. * fix compile errors
37 lines
1.2 KiB
C++
37 lines
1.2 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include "ck_tile/core.hpp"
|
|
|
|
namespace ck_tile {
|
|
|
|
template <typename XDataType_,
|
|
typename ComputeDataType_,
|
|
typename YDataType_,
|
|
typename BlockShape_,
|
|
typename ReduceOp_,
|
|
typename KeptDim_,
|
|
typename ReduceDims_,
|
|
index_t Rank_,
|
|
bool OutputIndex_ = false>
|
|
struct Reduce2dProblem
|
|
{
|
|
using XDataType = remove_cvref_t<XDataType_>;
|
|
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
|
using YDataType = remove_cvref_t<YDataType_>;
|
|
using BlockShape = remove_cvref_t<BlockShape_>;
|
|
using ReduceOp = ReduceOp_;
|
|
using KeptDim = remove_cvref_t<KeptDim_>;
|
|
using ReduceDims = remove_cvref_t<ReduceDims_>;
|
|
|
|
static constexpr index_t Rank = Rank_;
|
|
static constexpr index_t NumReduceDim = ReduceDims::size();
|
|
static constexpr bool kOutputIndex = OutputIndex_;
|
|
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
|
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
|
};
|
|
|
|
} // namespace ck_tile
|