mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
Regulate reduction accumulator operations and Element-wise operations (#274)
* Remove template from Reduction operation classes and add template to their operator() and GetIdentityValue() interfaces
* Change to unary elementwise operators and the reduce_unary_operator (class for mapping) and dependent variations in all host layers
* Remove the data type template parameter from reduce_binary_operator (class for mapping) and dependent variations in host layers
* Add InMemoryDataOperatonSupportedOnDataType to check the matching between data type and InMemoryDataOperation
* Use struct-scope operator template instantiation for binary and unary element-wise operations
* Change a few more elementwise operations to use template for operator()
* Tiny correction in Normalize operator
* Add static_assert to check the data type applicability for some reduction accumulator and element-wise operations
* Correction in some examples with regard to using ReduceAccDataType
* Use static_assert for UnaryDivide
* Update merged code to use Element-wise operations and Reduction Accumulator operations correctly
* Tiny fix with regard to SetWorkSpacePointer()
[ROCm/composable_kernel commit: 1f543bfa79]
This commit is contained in:
@@ -174,15 +174,18 @@ struct ReductionHost
|
||||
const InDataType* in_data,
|
||||
float beta,
|
||||
OutDataType* out_data,
|
||||
IndexDataType* out_indices)
|
||||
IndexDataType* out_indices,
|
||||
InElementwiseOperation in_elementwise_op,
|
||||
AccElementwiseOperation acc_elementwise_op)
|
||||
{
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
RunImpl_with_index(alpha, in_data, beta, out_data, out_indices);
|
||||
RunImpl_with_index(
|
||||
alpha, in_data, beta, out_data, out_indices, in_elementwise_op, acc_elementwise_op);
|
||||
}
|
||||
else
|
||||
{
|
||||
RunImpl_no_index(alpha, in_data, beta, out_data);
|
||||
RunImpl_no_index(alpha, in_data, beta, out_data, in_elementwise_op, acc_elementwise_op);
|
||||
};
|
||||
};
|
||||
|
||||
@@ -190,7 +193,9 @@ struct ReductionHost
|
||||
const InDataType* in_data,
|
||||
float beta,
|
||||
OutDataType* out_data,
|
||||
IndexDataType* out_indices)
|
||||
IndexDataType* out_indices,
|
||||
InElementwiseOperation in_elementwise_op,
|
||||
AccElementwiseOperation acc_elementwise_op)
|
||||
{
|
||||
using ck::float_equal_one;
|
||||
using ck::float_equal_zero;
|
||||
@@ -200,12 +205,10 @@ struct ReductionHost
|
||||
ReduceOperation,
|
||||
AccDataType,
|
||||
IndexDataType>;
|
||||
InElementwiseOperation in_elementwise_op(divider);
|
||||
AccElementwiseOperation acc_elementwise_op(divider);
|
||||
|
||||
if constexpr(NumInvariantDim == 0)
|
||||
{
|
||||
AccDataType accuVal = ReduceOperation::GetIdentityValue();
|
||||
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
|
||||
IndexDataType accuIndex = 0;
|
||||
|
||||
for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
|
||||
@@ -236,7 +239,7 @@ struct ReductionHost
|
||||
else
|
||||
{
|
||||
auto thread_reduce_func = [&](auto invariant_index) {
|
||||
AccDataType accuVal = ReduceOperation::GetIdentityValue();
|
||||
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
|
||||
IndexDataType accuIndex = 0;
|
||||
|
||||
auto offset_invariant =
|
||||
@@ -297,7 +300,12 @@ struct ReductionHost
|
||||
};
|
||||
};
|
||||
|
||||
void RunImpl_no_index(float alpha, const InDataType* in_data, float beta, OutDataType* out_data)
|
||||
void RunImpl_no_index(float alpha,
|
||||
const InDataType* in_data,
|
||||
float beta,
|
||||
OutDataType* out_data,
|
||||
InElementwiseOperation in_elementwise_op,
|
||||
AccElementwiseOperation acc_elementwise_op)
|
||||
{
|
||||
using ck::float_equal_one;
|
||||
using ck::float_equal_zero;
|
||||
@@ -306,12 +314,9 @@ struct ReductionHost
|
||||
using Accumulation =
|
||||
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
|
||||
|
||||
InElementwiseOperation in_elementwise_op(divider);
|
||||
AccElementwiseOperation acc_elementwise_op(divider);
|
||||
|
||||
if constexpr(NumInvariantDim == 0)
|
||||
{
|
||||
AccDataType accuVal = ReduceOperation::GetIdentityValue();
|
||||
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
|
||||
|
||||
for(const auto& reduce_index : reduce_dim_indexes)
|
||||
{
|
||||
@@ -338,7 +343,7 @@ struct ReductionHost
|
||||
else
|
||||
{
|
||||
auto thread_reduce_func = [&](auto invariant_index) {
|
||||
AccDataType accuVal = ReduceOperation::GetIdentityValue();
|
||||
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
|
||||
|
||||
auto offset_invariant =
|
||||
get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
|
||||
|
||||
@@ -106,9 +106,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
}
|
||||
}
|
||||
|
||||
float v_in;
|
||||
arg.in_element_op_(v_in, v_acc);
|
||||
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_in);
|
||||
arg.in_element_op_(v_acc, v_acc);
|
||||
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ncw,
|
||||
|
||||
@@ -66,8 +66,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
arg.a_element_op_(a, arg.a_m_k_(m, k));
|
||||
arg.b_element_op_(b, arg.b_k_n_(k, n));
|
||||
arg.a_element_op_(a, static_cast<AccDataType>(arg.a_m_k_(m, k)));
|
||||
arg.b_element_op_(b, static_cast<AccDataType>(arg.b_k_n_(k, n)));
|
||||
acc += a * b;
|
||||
}
|
||||
|
||||
|
||||
@@ -61,10 +61,10 @@ using reduce_configuration_2_instances_blockwise = std::tuple<
|
||||
>;
|
||||
#endif
|
||||
|
||||
template <typename AccDataType, ReduceTensorOp ReduceOpId>
|
||||
template <ReduceTensorOp ReduceOpId>
|
||||
using deviceReduceBlockWisePtrType = DeviceReducePtr<
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation,
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation>;
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
@@ -75,14 +75,13 @@ template <typename InDataType,
|
||||
bool PropagateNan,
|
||||
bool UseIndex>
|
||||
void add_device_reduce_instance_blockwise(
|
||||
std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
|
||||
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>>& device_op_instances)
|
||||
{
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
|
||||
AccElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
constexpr bool Indexable =
|
||||
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
|
||||
@@ -137,7 +136,7 @@ void add_device_reduce_instance_blockwise(
|
||||
ReduceOpId, \
|
||||
PropagateNan, \
|
||||
UseIndex>( \
|
||||
std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
|
||||
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
|
||||
|
||||
#define ADD_BLOCKWISE_INST_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
@@ -150,21 +149,17 @@ void add_device_reduce_instance_blockwise(
|
||||
Rank, \
|
||||
NumReduceDim)
|
||||
|
||||
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
|
||||
extern template void add_device_reduce_instance_blockwise<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
PropagateNan, \
|
||||
UseIndex>( \
|
||||
std::vector<DeviceReducePtr< \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
|
||||
AccElementwiseOperation>> & \
|
||||
device_op_instances)
|
||||
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
|
||||
extern template void add_device_reduce_instance_blockwise<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
PropagateNan, \
|
||||
UseIndex>( \
|
||||
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
|
||||
|
||||
#define ADD_BLOCKWISE_INST_REF_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
|
||||
@@ -61,12 +61,10 @@ using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple<
|
||||
>;
|
||||
#endif
|
||||
|
||||
template <typename AccDataType, ReduceTensorOp ReduceOperation>
|
||||
using deviceReduceMultiBlockAtomicAddPtrType =
|
||||
DeviceReducePtr<typename reduce_unary_operator<AccDataType, ReduceOperation, true, true>::
|
||||
InElementwiseOperation,
|
||||
typename reduce_unary_operator<AccDataType, ReduceOperation, true, true>::
|
||||
AccElementwiseOperation>;
|
||||
template <ReduceTensorOp ReduceOperation>
|
||||
using deviceReduceMultiBlockAtomicAddPtrType = DeviceReducePtr<
|
||||
typename reduce_unary_operator<ReduceOperation, true, true>::InElementwiseOperation,
|
||||
typename reduce_unary_operator<ReduceOperation, true, true>::AccElementwiseOperation>;
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
@@ -77,15 +75,13 @@ template <typename InDataType,
|
||||
bool PropagateNan,
|
||||
bool UseIndex>
|
||||
void add_device_reduce_instance_multiblock_atomic_add(
|
||||
std::vector<deviceReduceMultiBlockAtomicAddPtrType<AccDataType, ReduceOpId>>&
|
||||
device_op_instances)
|
||||
std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>>& device_op_instances)
|
||||
{
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
|
||||
AccElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
constexpr bool Indexable =
|
||||
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
|
||||
@@ -158,8 +154,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
|
||||
ReduceOpId, \
|
||||
PropagateNan, \
|
||||
UseIndex>( \
|
||||
std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> & \
|
||||
device_op_instances)
|
||||
std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>> & device_op_instances)
|
||||
|
||||
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
@@ -172,21 +167,17 @@ void add_device_reduce_instance_multiblock_atomic_add(
|
||||
Rank, \
|
||||
NumReduceDim)
|
||||
|
||||
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
|
||||
extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
PropagateNan, \
|
||||
UseIndex>( \
|
||||
std::vector<DeviceReducePtr< \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
|
||||
AccElementwiseOperation>> & \
|
||||
device_op_instances)
|
||||
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
|
||||
extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
PropagateNan, \
|
||||
UseIndex>( \
|
||||
std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>> & device_op_instances)
|
||||
|
||||
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
|
||||
@@ -47,10 +47,10 @@ using reduce_configuration_2_instances_threadwise = std::tuple<
|
||||
>;
|
||||
#endif
|
||||
|
||||
template <typename AccDataType, ReduceTensorOp ReduceOpId>
|
||||
template <ReduceTensorOp ReduceOpId>
|
||||
using deviceReduceThreadWisePtrType = DeviceReducePtr<
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation,
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation>;
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
@@ -61,14 +61,13 @@ template <typename InDataType,
|
||||
bool PropagateNan,
|
||||
bool UseIndex>
|
||||
void add_device_reduce_instance_threadwise(
|
||||
std::vector<deviceReduceThreadWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
|
||||
std::vector<deviceReduceThreadWisePtrType<ReduceOpId>>& device_op_instances)
|
||||
{
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
|
||||
AccElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
constexpr bool Indexable =
|
||||
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
|
||||
@@ -114,7 +113,7 @@ void add_device_reduce_instance_threadwise(
|
||||
ReduceOpId, \
|
||||
PropagateNan, \
|
||||
UseIndex>( \
|
||||
std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances)
|
||||
std::vector<deviceReduceThreadWisePtrType<ReduceOpId>> & device_op_instances)
|
||||
|
||||
#define ADD_THREADWISE_INST_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
@@ -127,21 +126,17 @@ void add_device_reduce_instance_threadwise(
|
||||
Rank, \
|
||||
NumReduceDim)
|
||||
|
||||
#define ADD_THREADWISE_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
|
||||
extern template void add_device_reduce_instance_threadwise<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
PropagateNan, \
|
||||
UseIndex>( \
|
||||
std::vector<DeviceReducePtr< \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
|
||||
AccElementwiseOperation>> & \
|
||||
device_op_instances)
|
||||
#define ADD_THREADWISE_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
|
||||
extern template void add_device_reduce_instance_threadwise<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
NumReduceDim, \
|
||||
ReduceOpId, \
|
||||
PropagateNan, \
|
||||
UseIndex>( \
|
||||
std::vector<deviceReduceThreadWisePtrType<ReduceOpId>> & device_op_instances)
|
||||
|
||||
#define ADD_THREADWISE_INST_REF_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
|
||||
|
||||
Reference in New Issue
Block a user