mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
Update to the Reduction API and instances (#476)
* Simplify the macros for declaring and defining the add_device_reduce_instance_xxxx() instances
* Change the types of lengths and strides from std::vector to std::array for the reduction device interfaces
* Remove DeviceSoftmaxImpl's depending on DeviceReduceMultiblock
* Split the cpp and hpp files for reduction instances to enable more parallel compiling
* Remove the using of macros for declaring reduction instances and instance references
* Update to add_device_reduce_instance_xxxx templated functions
* Use ReduceOperation+InElementwiseOp+AccElementwiseOp to repace the ReduceOpId in defining add_reduce_instance_xxxx() templates
* Change return format
[ROCm/composable_kernel commit: dda3a0a10b]
This commit is contained in:
@@ -18,57 +18,61 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <int Rank, int NumReduceDim, int ReduceOpId, bool PropagateNan, bool UseIndex>
|
||||
template <index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
ReduceTensorOp ReduceOpId,
|
||||
bool PropagateNan,
|
||||
bool UseIndex>
|
||||
struct ReduceDescription
|
||||
{
|
||||
static constexpr int Rank_ = Rank;
|
||||
static constexpr int NumReduceDim_ = NumReduceDim;
|
||||
static constexpr int ReduceOpId_ = ReduceOpId;
|
||||
static constexpr int PropagateNan_ = PropagateNan;
|
||||
static constexpr int UseIndex_ = UseIndex;
|
||||
static constexpr index_t Rank_ = Rank;
|
||||
static constexpr index_t NumReduceDim_ = NumReduceDim;
|
||||
static constexpr ReduceTensorOp ReduceOpId_ = ReduceOpId;
|
||||
static constexpr bool PropagateNan_ = PropagateNan;
|
||||
static constexpr bool UseIndex_ = UseIndex;
|
||||
};
|
||||
|
||||
using reduce_description_instances =
|
||||
std::tuple<ReduceDescription<4, 3, 0, false, false>, // for ADD
|
||||
ReduceDescription<4, 4, 0, false, false>,
|
||||
ReduceDescription<4, 1, 0, false, false>,
|
||||
ReduceDescription<2, 1, 0, false, false>,
|
||||
std::tuple<ReduceDescription<4, 3, ReduceTensorOp::ADD, false, false>, // for ADD
|
||||
ReduceDescription<4, 4, ReduceTensorOp::ADD, false, false>,
|
||||
ReduceDescription<4, 1, ReduceTensorOp::ADD, false, false>,
|
||||
ReduceDescription<2, 1, ReduceTensorOp::ADD, false, false>,
|
||||
|
||||
ReduceDescription<4, 3, 5, false, false>, // for AVG
|
||||
ReduceDescription<4, 4, 5, false, false>,
|
||||
ReduceDescription<4, 1, 5, false, false>,
|
||||
ReduceDescription<2, 1, 5, false, false>,
|
||||
ReduceDescription<4, 3, ReduceTensorOp::AVG, false, false>, // for AVG
|
||||
ReduceDescription<4, 4, ReduceTensorOp::AVG, false, false>,
|
||||
ReduceDescription<4, 1, ReduceTensorOp::AVG, false, false>,
|
||||
ReduceDescription<2, 1, ReduceTensorOp::AVG, false, false>,
|
||||
|
||||
ReduceDescription<4, 3, 7, false, false>, // for NORM2
|
||||
ReduceDescription<4, 4, 7, false, false>,
|
||||
ReduceDescription<4, 1, 7, false, false>,
|
||||
ReduceDescription<2, 1, 7, false, false>,
|
||||
ReduceDescription<4, 3, ReduceTensorOp::NORM2, false, false>, // for NORM2
|
||||
ReduceDescription<4, 4, ReduceTensorOp::NORM2, false, false>,
|
||||
ReduceDescription<4, 1, ReduceTensorOp::NORM2, false, false>,
|
||||
ReduceDescription<2, 1, ReduceTensorOp::NORM2, false, false>,
|
||||
|
||||
ReduceDescription<4, 3, 2, false, false>, // for MIN
|
||||
ReduceDescription<4, 4, 2, false, false>,
|
||||
ReduceDescription<4, 1, 2, false, false>,
|
||||
ReduceDescription<2, 1, 2, false, false>,
|
||||
ReduceDescription<4, 3, 3, false, false>, // for MAX
|
||||
ReduceDescription<4, 4, 3, false, false>,
|
||||
ReduceDescription<4, 1, 3, false, false>,
|
||||
ReduceDescription<2, 1, 3, false, false>,
|
||||
ReduceDescription<4, 3, 4, false, false>, // for AMAX
|
||||
ReduceDescription<4, 4, 4, false, false>,
|
||||
ReduceDescription<4, 1, 4, false, false>,
|
||||
ReduceDescription<2, 1, 4, false, false>,
|
||||
ReduceDescription<4, 3, ReduceTensorOp::MIN, false, false>, // for MIN
|
||||
ReduceDescription<4, 4, ReduceTensorOp::MIN, false, false>,
|
||||
ReduceDescription<4, 1, ReduceTensorOp::MIN, false, false>,
|
||||
ReduceDescription<2, 1, ReduceTensorOp::MIN, false, false>,
|
||||
ReduceDescription<4, 3, ReduceTensorOp::MAX, false, false>, // for MAX
|
||||
ReduceDescription<4, 4, ReduceTensorOp::MAX, false, false>,
|
||||
ReduceDescription<4, 1, ReduceTensorOp::MAX, false, false>,
|
||||
ReduceDescription<2, 1, ReduceTensorOp::MAX, false, false>,
|
||||
ReduceDescription<4, 3, ReduceTensorOp::AMAX, false, false>, // for AMAX
|
||||
ReduceDescription<4, 4, ReduceTensorOp::AMAX, false, false>,
|
||||
ReduceDescription<4, 1, ReduceTensorOp::AMAX, false, false>,
|
||||
ReduceDescription<2, 1, ReduceTensorOp::AMAX, false, false>,
|
||||
|
||||
ReduceDescription<4, 3, 2, false, true>, // for MIN
|
||||
ReduceDescription<4, 4, 2, false, true>,
|
||||
ReduceDescription<4, 1, 2, false, true>,
|
||||
ReduceDescription<2, 1, 2, false, true>,
|
||||
ReduceDescription<4, 3, 3, false, true>, // for MAX
|
||||
ReduceDescription<4, 4, 3, false, true>,
|
||||
ReduceDescription<4, 1, 3, false, true>,
|
||||
ReduceDescription<2, 1, 3, false, true>,
|
||||
ReduceDescription<4, 3, 4, false, true>, // for AMAX
|
||||
ReduceDescription<4, 4, 4, false, true>,
|
||||
ReduceDescription<4, 1, 4, false, true>,
|
||||
ReduceDescription<2, 1, 4, false, true>>;
|
||||
ReduceDescription<4, 3, ReduceTensorOp::MIN, false, true>, // for MIN
|
||||
ReduceDescription<4, 4, ReduceTensorOp::MIN, false, true>,
|
||||
ReduceDescription<4, 1, ReduceTensorOp::MIN, false, true>,
|
||||
ReduceDescription<2, 1, ReduceTensorOp::MIN, false, true>,
|
||||
ReduceDescription<4, 3, ReduceTensorOp::MAX, false, true>, // for MAX
|
||||
ReduceDescription<4, 4, ReduceTensorOp::MAX, false, true>,
|
||||
ReduceDescription<4, 1, ReduceTensorOp::MAX, false, true>,
|
||||
ReduceDescription<2, 1, ReduceTensorOp::MAX, false, true>,
|
||||
ReduceDescription<4, 3, ReduceTensorOp::AMAX, false, true>, // for AMAX
|
||||
ReduceDescription<4, 4, ReduceTensorOp::AMAX, false, true>,
|
||||
ReduceDescription<4, 1, ReduceTensorOp::AMAX, false, true>,
|
||||
ReduceDescription<2, 1, ReduceTensorOp::AMAX, false, true>>;
|
||||
|
||||
template <typename DescriptionType>
|
||||
bool description_match(const DescriptionType& description,
|
||||
@@ -78,9 +82,8 @@ bool description_match(const DescriptionType& description,
|
||||
bool PropagateNan,
|
||||
bool UseIndex)
|
||||
{
|
||||
if(description.Rank_ != Rank || description.ReduceOpId_ != static_cast<int>(ReduceOpId) ||
|
||||
description.PropagateNan_ != static_cast<int>(PropagateNan) ||
|
||||
description.UseIndex_ != static_cast<int>(UseIndex))
|
||||
if(description.Rank_ != Rank || description.ReduceOpId_ != ReduceOpId ||
|
||||
description.PropagateNan_ != PropagateNan || description.UseIndex_ != UseIndex)
|
||||
return (false);
|
||||
|
||||
if(DescriptionType::NumReduceDim_ != reduceDims.size())
|
||||
@@ -99,11 +102,10 @@ bool description_match(const DescriptionType& description,
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <index_t Rank, index_t NumReduceDim>
|
||||
static inline std::vector<int> get_invariant_dims(const std::vector<int>& reduceDims)
|
||||
template <int Rank, int NumReduceDim>
|
||||
static inline std::array<int, Rank - NumReduceDim>
|
||||
get_invariant_dims(const std::array<int, NumReduceDim>& reduceDims)
|
||||
{
|
||||
assert(NumReduceDim == reduceDims.size());
|
||||
|
||||
int reduceFlag = 0;
|
||||
|
||||
// flag the bits for the reduceDims
|
||||
@@ -112,13 +114,15 @@ static inline std::vector<int> get_invariant_dims(const std::vector<int>& reduce
|
||||
reduceFlag |= 1 << reduceDims[i];
|
||||
};
|
||||
|
||||
std::vector<int> invariantDims;
|
||||
std::array<int, Rank - NumReduceDim> invariantDims;
|
||||
|
||||
// collect invariant dimensions
|
||||
int dim = 0;
|
||||
for(int i = 0; i < Rank; i++)
|
||||
if((reduceFlag & (1 << i)) == 0)
|
||||
{
|
||||
invariantDims.push_back(i);
|
||||
invariantDims[dim] = i;
|
||||
dim++;
|
||||
};
|
||||
|
||||
return invariantDims;
|
||||
@@ -137,7 +141,7 @@ bool profile_reduce_impl_impl(bool do_verification,
|
||||
bool do_dumpout,
|
||||
bool time_kernel,
|
||||
const std::vector<size_t>& inLengths,
|
||||
const std::vector<int>& reduceDims,
|
||||
const std::array<int, NumReduceDim>& reduceDims,
|
||||
float alpha,
|
||||
float beta)
|
||||
{
|
||||
@@ -145,6 +149,8 @@ bool profile_reduce_impl_impl(bool do_verification,
|
||||
using namespace ck::tensor_operation::device::instance;
|
||||
using ck::host_common::dumpBufferToFile;
|
||||
|
||||
constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
|
||||
|
||||
constexpr bool op_support_indices =
|
||||
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
|
||||
ReduceOpId == ReduceTensorOp::AMAX);
|
||||
@@ -279,28 +285,32 @@ bool profile_reduce_impl_impl(bool do_verification,
|
||||
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
|
||||
static_cast<int32_t>(reduce_total_length));
|
||||
|
||||
using DeviceReduceInstPtr0 =
|
||||
DeviceReducePtr<InElementwiseOperation, AccElementwiseOperation>;
|
||||
using DeviceReduceInstPtr =
|
||||
DeviceReducePtr<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>;
|
||||
|
||||
std::vector<DeviceReduceInstPtr0> reduce0_ptrs;
|
||||
std::vector<DeviceReduceInstPtr> reduce_ptrs;
|
||||
|
||||
add_device_reduce_instance_threadwise<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceOpId,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
UseIndex>(reduce0_ptrs);
|
||||
UseIndex>(reduce_ptrs);
|
||||
|
||||
add_device_reduce_instance_blockwise<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceOpId,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
UseIndex>(reduce0_ptrs);
|
||||
UseIndex>(reduce_ptrs);
|
||||
|
||||
if constexpr(use_atomic_add)
|
||||
{
|
||||
@@ -309,12 +319,14 @@ bool profile_reduce_impl_impl(bool do_verification,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceOpId,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
UseIndex>(reduce0_ptrs);
|
||||
UseIndex>(reduce_ptrs);
|
||||
}
|
||||
|
||||
if(reduce0_ptrs.empty())
|
||||
if(reduce_ptrs.empty())
|
||||
{
|
||||
throw std::runtime_error("Wrong! No device REDUCE instance found");
|
||||
};
|
||||
@@ -342,22 +354,22 @@ bool profile_reduce_impl_impl(bool do_verification,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::vector<ck::index_t> i_inLengths;
|
||||
std::vector<ck::index_t> i_inStrides;
|
||||
std::vector<ck::index_t> i_outLengths;
|
||||
std::vector<ck::index_t> i_outStrides;
|
||||
std::array<index_t, Rank> arrInLengths;
|
||||
std::array<index_t, Rank> arrInStrides;
|
||||
std::array<index_t, NumOutDim> arrOutLengths;
|
||||
std::array<index_t, NumOutDim> arrOutStrides;
|
||||
|
||||
i_inLengths.assign(inLengths.begin(), inLengths.end());
|
||||
i_inStrides.assign(inStrides.begin(), inStrides.end());
|
||||
i_outLengths.assign(outLengths.begin(), outLengths.end());
|
||||
i_outStrides.assign(outStrides.begin(), outStrides.end());
|
||||
std::copy(inLengths.begin(), inLengths.end(), arrInLengths.begin());
|
||||
std::copy(inStrides.begin(), inStrides.end(), arrInStrides.begin());
|
||||
std::copy(outLengths.begin(), outLengths.end(), arrOutLengths.begin());
|
||||
std::copy(outStrides.begin(), outStrides.end(), arrOutStrides.begin());
|
||||
|
||||
for(auto& reduce_ptr : reduce0_ptrs)
|
||||
for(auto& reduce_ptr : reduce_ptrs)
|
||||
{
|
||||
auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths,
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
auto argument_ptr = reduce_ptr->MakeArgumentPointer(arrInLengths,
|
||||
arrInStrides,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
@@ -478,22 +490,25 @@ bool profile_reduce_impl(bool do_verification,
|
||||
descType{}, inLengths.size(), reduceDims, ReduceOpId, PropagateNan, UseIndex))
|
||||
return;
|
||||
|
||||
pass = pass &&
|
||||
profile_reduce_impl_impl<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
descType::Rank_,
|
||||
descType::NumReduceDim_,
|
||||
static_cast<ReduceTensorOp>(descType::ReduceOpId_),
|
||||
static_cast<bool>(descType::PropagateNan_),
|
||||
static_cast<bool>(descType::UseIndex_)>(do_verification,
|
||||
init_method,
|
||||
do_dumpout,
|
||||
time_kernel,
|
||||
inLengths,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta);
|
||||
std::array<ck::index_t, descType::NumReduceDim_> arrReduceDims;
|
||||
|
||||
std::copy(reduceDims.begin(), reduceDims.end(), arrReduceDims.begin());
|
||||
|
||||
pass = pass && profile_reduce_impl_impl<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
descType::Rank_,
|
||||
descType::NumReduceDim_,
|
||||
static_cast<ReduceTensorOp>(descType::ReduceOpId_),
|
||||
descType::PropagateNan_,
|
||||
descType::UseIndex_>(do_verification,
|
||||
init_method,
|
||||
do_dumpout,
|
||||
time_kernel,
|
||||
inLengths,
|
||||
arrReduceDims,
|
||||
alpha,
|
||||
beta);
|
||||
|
||||
matched = true;
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user