mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Regulate reduction accumulator operations and Element-wise operations (#274)
* Remove template from Reduction operation classes and add template to their operator() and GetIdentityValue() interfaces * Change to unary elementwise operators and the reduce_unary_operator (class for mapping) and dependent variations in all host layers * Remove the data type template parameter from reduce_binary_operator (class for mapping) and dependent variations in host layers * Add InMemoryDataOperatonSupportedOnDataType to check the matching between data type and InMemoryDataOperation * Use struct-scope operator template instantiation for binary and unary element-wise operations * Change a few more elementwise operations to use template for operator() * Tiny correction in Normalize operator * Add static_assert to check the data type applicability for some reduction accumulator and element-wise operations * Correction in some examples with regard to using ReduceAccDataType * Use static_assert for UnaryDivide * Update to merged codes to use Element-wise operations and Reduction Accumulator operations correctly * Tiny fix with regard to SetWorkSpacePointer()
This commit is contained in:
@@ -44,7 +44,7 @@ struct BaseOperator
|
||||
|
||||
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
|
||||
|
||||
virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const final
|
||||
virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
|
||||
{
|
||||
assert(p_arg);
|
||||
p_arg->p_workspace_ = p_workspace;
|
||||
|
||||
@@ -557,11 +557,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
using Add =
|
||||
ck::tensor_operation::binary_element_wise::Add<CDataType, CDataType, CDataType>;
|
||||
using Substract = ck::tensor_operation::binary_element_wise::
|
||||
Substract<CDataType, CDataType, CDataType>;
|
||||
using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
using Subtract = ck::tensor_operation::element_wise::Subtract;
|
||||
using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
@@ -573,19 +571,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
AScalarPerVector,
|
||||
BScalarPerVector,
|
||||
CScalarPerVector>;
|
||||
using GridwiseBinSubstract = GridwiseBinaryElementwise_1D<CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
Substract,
|
||||
MPerThread,
|
||||
AScalarPerVector,
|
||||
BScalarPerVector,
|
||||
CScalarPerVector>;
|
||||
const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd,
|
||||
using GridwiseBinSubtract = GridwiseBinaryElementwise_1D<CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
Subtract,
|
||||
MPerThread,
|
||||
AScalarPerVector,
|
||||
BScalarPerVector,
|
||||
CScalarPerVector>;
|
||||
const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
@@ -593,14 +591,14 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
Add>;
|
||||
const auto substract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubstract,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
Substract>;
|
||||
const auto subtract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubtract,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
Subtract>;
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
@@ -653,7 +651,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
|
||||
// c_real = aux - aux_2
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
substract_kernel,
|
||||
subtract_kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -663,7 +661,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
Substract{});
|
||||
Subtract{});
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
@@ -764,7 +762,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
|
||||
// c_real = aux - aux_2
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
substract_kernel,
|
||||
subtract_kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -774,7 +772,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
Substract{});
|
||||
Subtract{});
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
|
||||
@@ -35,14 +35,13 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
|
||||
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
|
||||
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
|
||||
AccElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
static constexpr index_t InSrcOutDstVectorDim =
|
||||
0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
|
||||
@@ -178,13 +177,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
|
||||
invariant_lowest_length_ = C;
|
||||
reduce_lowest_length_ = window_spatial_lengths[1];
|
||||
|
||||
// TODO: is this correct?
|
||||
if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG)
|
||||
{
|
||||
ck::index_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
|
||||
in_element_op_ = InElementwiseOperation{divider};
|
||||
acc_element_op_ = AccElementwiseOperation{divider};
|
||||
}
|
||||
int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];
|
||||
|
||||
std::tie(in_element_op_, acc_element_op_) =
|
||||
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
|
||||
}
|
||||
|
||||
const InDataType* p_in_dev_;
|
||||
|
||||
@@ -61,12 +61,9 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
|
||||
static constexpr bool use_multiblock =
|
||||
(OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
|
||||
|
||||
static constexpr bool out_type_compatible_with_atomic_op =
|
||||
std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;
|
||||
|
||||
static_assert(
|
||||
!use_multiblock || (use_multiblock && out_type_compatible_with_atomic_op),
|
||||
"The OutDataType must support the atomic operation for using MultiBlock reduction");
|
||||
static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
|
||||
OutDataType>::value,
|
||||
"The OutDataType must support the specified OutMemoryDataOperation!");
|
||||
|
||||
static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
|
||||
"MultiBlock reduction can only be used when outputing index is not required");
|
||||
@@ -349,7 +346,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
|
||||
if constexpr(use_multiblock)
|
||||
{
|
||||
const auto identityVal =
|
||||
ck::reduce::GetIdentityValueueForInMemoryDataOperation<OutDataType>(
|
||||
ck::reduce::GetIdentityValueForInMemoryDataOperation<OutDataType>(
|
||||
OutMemoryDataOperation);
|
||||
|
||||
const auto kernel_pre =
|
||||
@@ -492,7 +489,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceReduceMultiBlockAtomicAdd<" << BlockSize << ",";
|
||||
str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set? "DeviceReduceBlockWise<" : "DeviceReduceMultiBlock<") << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_enums.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include <tuple>
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -37,77 +38,69 @@ namespace ck {
|
||||
// The boolean member "indexable" is also provided in reduce_binary_operator for
|
||||
// easier checking by the upper-layer codes in the kernels.
|
||||
|
||||
template <typename T, ReduceTensorOp Op>
|
||||
template <ReduceTensorOp Op>
|
||||
struct reduce_binary_operator;
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp::ADD>
|
||||
template <>
|
||||
struct reduce_binary_operator<ReduceTensorOp::ADD>
|
||||
{
|
||||
using opType = reduce::Add<T>;
|
||||
using dataType = T;
|
||||
using opType = reduce::Add;
|
||||
|
||||
static constexpr bool indexable = false;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp::MUL>
|
||||
template <>
|
||||
struct reduce_binary_operator<ReduceTensorOp::MUL>
|
||||
{
|
||||
using opType = reduce::Mul<T>;
|
||||
using dataType = T;
|
||||
using opType = reduce::Mul;
|
||||
|
||||
static constexpr bool indexable = false;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp::MIN>
|
||||
template <>
|
||||
struct reduce_binary_operator<ReduceTensorOp::MIN>
|
||||
{
|
||||
using opType = reduce::Min<T>;
|
||||
using dataType = T;
|
||||
using opType = reduce::Min;
|
||||
|
||||
static constexpr bool indexable = true;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp::MAX>
|
||||
template <>
|
||||
struct reduce_binary_operator<ReduceTensorOp::MAX>
|
||||
{
|
||||
using opType = reduce::Max<T>;
|
||||
using dataType = T;
|
||||
using opType = reduce::Max;
|
||||
|
||||
static constexpr bool indexable = true;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp::AMAX>
|
||||
template <>
|
||||
struct reduce_binary_operator<ReduceTensorOp::AMAX>
|
||||
{
|
||||
using opType = reduce::AMax<T>;
|
||||
using dataType = T;
|
||||
using opType = reduce::AMax;
|
||||
|
||||
static constexpr bool indexable = true;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp::AVG>
|
||||
template <>
|
||||
struct reduce_binary_operator<ReduceTensorOp::AVG>
|
||||
{
|
||||
using opType = reduce::Add<T>;
|
||||
using dataType = T;
|
||||
using opType = reduce::Add;
|
||||
|
||||
static constexpr bool indexable = false;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp::NORM1>
|
||||
template <>
|
||||
struct reduce_binary_operator<ReduceTensorOp::NORM1>
|
||||
{
|
||||
using opType = reduce::Add<T>;
|
||||
using dataType = T;
|
||||
using opType = reduce::Add;
|
||||
|
||||
static constexpr bool indexable = false;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp::NORM2>
|
||||
template <>
|
||||
struct reduce_binary_operator<ReduceTensorOp::NORM2>
|
||||
{
|
||||
using opType = reduce::Add<T>;
|
||||
using dataType = T;
|
||||
using opType = reduce::Add;
|
||||
|
||||
static constexpr bool indexable = false;
|
||||
};
|
||||
@@ -115,53 +108,101 @@ struct reduce_binary_operator<T, ReduceTensorOp::NORM2>
|
||||
// The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary
|
||||
// functor classes.
|
||||
// The two unary functors are called before and after the Reduction is executed respectively
|
||||
template <typename T, ReduceTensorOp Op, bool IsFirstReduce, bool IsLastReduce>
|
||||
template <ReduceTensorOp Op, bool IsFirstReduce, bool IsLastReduce>
|
||||
struct reduce_unary_operator
|
||||
{
|
||||
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
|
||||
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
|
||||
using InElementwiseOperation = tensor_operation::element_wise::PassThrough;
|
||||
using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
|
||||
GetElementwiseOperator(int32_t reduceLength)
|
||||
{
|
||||
(void)reduceLength;
|
||||
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T, bool IsFirstReduce>
|
||||
struct reduce_unary_operator<T, ReduceTensorOp::AVG, IsFirstReduce, true>
|
||||
template <bool IsFirstReduce>
|
||||
struct reduce_unary_operator<ReduceTensorOp::AVG, IsFirstReduce, true>
|
||||
{
|
||||
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
|
||||
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T, true>;
|
||||
using InElementwiseOperation = tensor_operation::element_wise::PassThrough;
|
||||
using AccElementwiseOperation = tensor_operation::element_wise::UnaryDivide;
|
||||
|
||||
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
|
||||
GetElementwiseOperator(int32_t reduceLength)
|
||||
{
|
||||
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{reduceLength});
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T, bool IsLastReduce>
|
||||
struct reduce_unary_operator<T, ReduceTensorOp::NORM1, true, IsLastReduce>
|
||||
template <bool IsLastReduce>
|
||||
struct reduce_unary_operator<ReduceTensorOp::NORM1, true, IsLastReduce>
|
||||
{
|
||||
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs<T, T>;
|
||||
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
|
||||
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs;
|
||||
using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
|
||||
GetElementwiseOperator(int32_t reduceLength)
|
||||
{
|
||||
(void)reduceLength;
|
||||
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T, bool IsLastReduce>
|
||||
struct reduce_unary_operator<T, ReduceTensorOp::AMAX, true, IsLastReduce>
|
||||
template <bool IsLastReduce>
|
||||
struct reduce_unary_operator<ReduceTensorOp::AMAX, true, IsLastReduce>
|
||||
{
|
||||
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs<T, T>;
|
||||
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
|
||||
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs;
|
||||
using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
|
||||
GetElementwiseOperator(int32_t reduceLength)
|
||||
{
|
||||
(void)reduceLength;
|
||||
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, false>
|
||||
template <>
|
||||
struct reduce_unary_operator<ReduceTensorOp::NORM2, true, false>
|
||||
{
|
||||
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare<T, T>;
|
||||
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
|
||||
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare;
|
||||
using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
|
||||
GetElementwiseOperator(int32_t reduceLength)
|
||||
{
|
||||
(void)reduceLength;
|
||||
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, true>
|
||||
template <>
|
||||
struct reduce_unary_operator<ReduceTensorOp::NORM2, true, true>
|
||||
{
|
||||
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare<T, T>;
|
||||
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
|
||||
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare;
|
||||
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt;
|
||||
|
||||
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
|
||||
GetElementwiseOperator(int32_t reduceLength)
|
||||
{
|
||||
(void)reduceLength;
|
||||
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, false, true>
|
||||
template <>
|
||||
struct reduce_unary_operator<ReduceTensorOp::NORM2, false, true>
|
||||
{
|
||||
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
|
||||
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
|
||||
using InElementwiseOperation = tensor_operation::element_wise::PassThrough;
|
||||
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt;
|
||||
|
||||
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
|
||||
GetElementwiseOperator(int32_t reduceLength)
|
||||
{
|
||||
(void)reduceLength;
|
||||
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
|
||||
};
|
||||
};
|
||||
|
||||
} // end of namespace ck
|
||||
|
||||
Reference in New Issue
Block a user