Regulate reduction accumulator operations and Element-wise operations (#274)

* Remove template from Reducton operation classes and add template to their operator() and GetIdentityValue() interfaces

* Change to unary elementwise operators and the reduce_unary_operator (class for mapping) and dependent variations in all host layers

* Remove the data type template parameter from reduce_binary_operator (class for mapping) and dependent variations in host layers

* Add InMemoryDataOperatonSupportedOnDataType to check the matching between data type and InMemoryDataOperation

* Use struct-scope operator template instantiation for binary and unary element-wise operations

* Change a few more elementwise operations to use template for operator()

* Tiny correction in Normalize operator

* Add static_assert to check the data type appliability for some reduction accumulator and element-wise operatons

* Correction in some examples with regard to using ReduceAccDataType

* Use static_assert for UnaryDivide

* Update to merged codes to use Element-wise operations and Reduction Accumulator operations correctly

* Tiny fix with regard to SetWorkSpacePointer()

[ROCm/composable_kernel commit: 1f543bfa79]
This commit is contained in:
Qianfeng
2022-06-18 04:10:25 +08:00
committed by GitHub
parent 06d1305373
commit f92a3a4622
48 changed files with 891 additions and 837 deletions

View File

@@ -44,7 +44,7 @@ struct BaseOperator
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const final
virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
{
assert(p_arg);
p_arg->p_workspace_ = p_workspace;

View File

@@ -557,11 +557,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
float ave_time = 0;
using Add =
ck::tensor_operation::binary_element_wise::Add<CDataType, CDataType, CDataType>;
using Substract = ck::tensor_operation::binary_element_wise::
Substract<CDataType, CDataType, CDataType>;
using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
using Add = ck::tensor_operation::element_wise::Add;
using Subtract = ck::tensor_operation::element_wise::Subtract;
using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
CDataType,
CDataType,
CDataType,
@@ -573,19 +571,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
AScalarPerVector,
BScalarPerVector,
CScalarPerVector>;
using GridwiseBinSubstract = GridwiseBinaryElementwise_1D<CDataType,
CDataType,
CDataType,
CDataType,
CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Substract,
MPerThread,
AScalarPerVector,
BScalarPerVector,
CScalarPerVector>;
const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd,
using GridwiseBinSubtract = GridwiseBinaryElementwise_1D<CDataType,
CDataType,
CDataType,
CDataType,
CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Subtract,
MPerThread,
AScalarPerVector,
BScalarPerVector,
CScalarPerVector>;
const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd,
CDataType,
CDataType,
CDataType,
@@ -593,14 +591,14 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CGridDesc_M,
CGridDesc_M,
Add>;
const auto substract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubstract,
CDataType,
CDataType,
CDataType,
CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Substract>;
const auto subtract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubtract,
CDataType,
CDataType,
CDataType,
CGridDesc_M,
CGridDesc_M,
CGridDesc_M,
Subtract>;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
@@ -653,7 +651,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// c_real = aux - aux_2
ave_time += launch_and_time_kernel(stream_config,
substract_kernel,
subtract_kernel,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -663,7 +661,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
Substract{});
Subtract{});
ave_time +=
launch_and_time_kernel(stream_config,
@@ -764,7 +762,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// c_real = aux - aux_2
ave_time += launch_and_time_kernel(stream_config,
substract_kernel,
subtract_kernel,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -774,7 +772,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
arg.c_grid_desc_m_,
Substract{});
Subtract{});
ave_time +=
launch_and_time_kernel(stream_config,

View File

@@ -35,14 +35,13 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
using IndexDataType = int32_t;
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
AccElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
static constexpr index_t InSrcOutDstVectorDim =
0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
@@ -178,13 +177,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
invariant_lowest_length_ = C;
reduce_lowest_length_ = window_spatial_lengths[1];
// TODO: is this correct?
if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG)
{
ck::index_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
in_element_op_ = InElementwiseOperation{divider};
acc_element_op_ = AccElementwiseOperation{divider};
}
int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];
std::tie(in_element_op_, acc_element_op_) =
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
}
const InDataType* p_in_dev_;

View File

@@ -61,12 +61,9 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
static constexpr bool use_multiblock =
(OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
static constexpr bool out_type_compatible_with_atomic_op =
std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;
static_assert(
!use_multiblock || (use_multiblock && out_type_compatible_with_atomic_op),
"The OutDataType must support the atomic operation for using MultiBlock reduction");
static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
OutDataType>::value,
"The OutDataType must support the specified OutMemoryDataOperation!");
static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
"MultiBlock reduction can only be used when outputing index is not required");
@@ -349,7 +346,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
if constexpr(use_multiblock)
{
const auto identityVal =
ck::reduce::GetIdentityValueueForInMemoryDataOperation<OutDataType>(
ck::reduce::GetIdentityValueForInMemoryDataOperation<OutDataType>(
OutMemoryDataOperation);
const auto kernel_pre =
@@ -492,7 +489,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceMultiBlockAtomicAdd<" << BlockSize << ",";
str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set? "DeviceReduceBlockWise<" : "DeviceReduceMultiBlock<") << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";

View File

@@ -29,6 +29,7 @@
#include "reduction_operator.hpp"
#include "reduction_enums.hpp"
#include "element_wise_operation.hpp"
#include <tuple>
namespace ck {
@@ -37,77 +38,69 @@ namespace ck {
// The boolean member "indexable" are also provided in reduce_binary_operactor for
// easier checking by the upper-layer codes in the kernels.
template <typename T, ReduceTensorOp Op>
template <ReduceTensorOp Op>
struct reduce_binary_operator;
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::ADD>
template <>
struct reduce_binary_operator<ReduceTensorOp::ADD>
{
using opType = reduce::Add<T>;
using dataType = T;
using opType = reduce::Add;
static constexpr bool indexable = false;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::MUL>
template <>
struct reduce_binary_operator<ReduceTensorOp::MUL>
{
using opType = reduce::Mul<T>;
using dataType = T;
using opType = reduce::Mul;
static constexpr bool indexable = false;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::MIN>
template <>
struct reduce_binary_operator<ReduceTensorOp::MIN>
{
using opType = reduce::Min<T>;
using dataType = T;
using opType = reduce::Min;
static constexpr bool indexable = true;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::MAX>
template <>
struct reduce_binary_operator<ReduceTensorOp::MAX>
{
using opType = reduce::Max<T>;
using dataType = T;
using opType = reduce::Max;
static constexpr bool indexable = true;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::AMAX>
template <>
struct reduce_binary_operator<ReduceTensorOp::AMAX>
{
using opType = reduce::AMax<T>;
using dataType = T;
using opType = reduce::AMax;
static constexpr bool indexable = true;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::AVG>
template <>
struct reduce_binary_operator<ReduceTensorOp::AVG>
{
using opType = reduce::Add<T>;
using dataType = T;
using opType = reduce::Add;
static constexpr bool indexable = false;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::NORM1>
template <>
struct reduce_binary_operator<ReduceTensorOp::NORM1>
{
using opType = reduce::Add<T>;
using dataType = T;
using opType = reduce::Add;
static constexpr bool indexable = false;
};
template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp::NORM2>
template <>
struct reduce_binary_operator<ReduceTensorOp::NORM2>
{
using opType = reduce::Add<T>;
using dataType = T;
using opType = reduce::Add;
static constexpr bool indexable = false;
};
@@ -115,53 +108,101 @@ struct reduce_binary_operator<T, ReduceTensorOp::NORM2>
// The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary
// functor classes.
// The two unary functors are called before and afer the Reduction is executed respectively
template <typename T, ReduceTensorOp Op, bool IsFirstReduce, bool IsLastReduce>
template <ReduceTensorOp Op, bool IsFirstReduce, bool IsLastReduce>
struct reduce_unary_operator
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using InElementwiseOperation = tensor_operation::element_wise::PassThrough;
using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
GetElementwiseOperator(int32_t reduceLength)
{
(void)reduceLength;
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
};
};
template <typename T, bool IsFirstReduce>
struct reduce_unary_operator<T, ReduceTensorOp::AVG, IsFirstReduce, true>
template <bool IsFirstReduce>
struct reduce_unary_operator<ReduceTensorOp::AVG, IsFirstReduce, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T, true>;
using InElementwiseOperation = tensor_operation::element_wise::PassThrough;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryDivide;
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
GetElementwiseOperator(int32_t reduceLength)
{
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{reduceLength});
};
};
template <typename T, bool IsLastReduce>
struct reduce_unary_operator<T, ReduceTensorOp::NORM1, true, IsLastReduce>
template <bool IsLastReduce>
struct reduce_unary_operator<ReduceTensorOp::NORM1, true, IsLastReduce>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs;
using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
GetElementwiseOperator(int32_t reduceLength)
{
(void)reduceLength;
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
};
};
template <typename T, bool IsLastReduce>
struct reduce_unary_operator<T, ReduceTensorOp::AMAX, true, IsLastReduce>
template <bool IsLastReduce>
struct reduce_unary_operator<ReduceTensorOp::AMAX, true, IsLastReduce>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs;
using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
GetElementwiseOperator(int32_t reduceLength)
{
(void)reduceLength;
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
};
};
template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, false>
template <>
struct reduce_unary_operator<ReduceTensorOp::NORM2, true, false>
{
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare;
using AccElementwiseOperation = tensor_operation::element_wise::PassThrough;
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
GetElementwiseOperator(int32_t reduceLength)
{
(void)reduceLength;
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
};
};
template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, true>
template <>
struct reduce_unary_operator<ReduceTensorOp::NORM2, true, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare;
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt;
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
GetElementwiseOperator(int32_t reduceLength)
{
(void)reduceLength;
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
};
};
template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, false, true>
template <>
struct reduce_unary_operator<ReduceTensorOp::NORM2, false, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
using InElementwiseOperation = tensor_operation::element_wise::PassThrough;
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt;
static std::tuple<InElementwiseOperation, AccElementwiseOperation>
GetElementwiseOperator(int32_t reduceLength)
{
(void)reduceLength;
return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{});
};
};
} // end of namespace ck

View File

@@ -28,100 +28,189 @@
namespace ck {
namespace tensor_operation {
namespace binary_element_wise {
template <typename Y, typename X1, typename X2>
struct Add;
namespace element_wise {
template <>
struct Add<double, double, double>
struct Add
{
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
template <>
__host__ __device__ constexpr void
operator()(double& dst, const double& src1, const double& src2) const
operator()<float>(float& y, const float& x0, const float& x1) const
{
dst = src1 + src2;
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
{
y = x0 + x1;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
y = x0 + x1;
};
// Question: should bhalf_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x0);
const float x2_tmp = ck::type_convert<float>(x1);
const float y_tmp = x1_tmp + x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
};
template <>
struct Add<float, float, float>
struct Subtract
{
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
template <>
__host__ __device__ constexpr void
operator()(float& dst, const float& src1, const float& src2) const
operator()<float>(float& y, const float& x0, const float& x1) const
{
dst = src1 + src2;
y = x0 - x1;
};
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
{
y = x0 - x1;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
y = x0 - x1;
};
// Question: should bhalf_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x0);
const float x2_tmp = ck::type_convert<float>(x1);
const float y_tmp = x1_tmp - x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
};
template <>
struct Add<half_t, half_t, half_t>
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
template <>
__host__ __device__ constexpr void
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
operator()<float>(float& y, const float& x0, const float& x1) const
{
dst = src1 + src2;
}
y = alpha_ * x0 + beta_ * x1;
};
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
{
y = static_cast<double>(alpha_) * x0 + static_cast<double>(beta_) * x1;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
y = static_cast<half_t>(alpha_ * static_cast<float>(x0) + beta_ * static_cast<float>(x1));
};
float alpha_;
float beta_;
};
template <>
struct Add<bhalf_t, bhalf_t, bhalf_t>
struct AddRelu
{
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
template <>
__host__ __device__ constexpr void
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
operator()<float>(float& y, const float& x0, const float& x1) const
{
const float x1 = ck::type_convert<float>(src1);
const float x2 = ck::type_convert<float>(src2);
const float y = x1 + x2;
dst = ck::type_convert<bhalf_t>(y);
}
const float a = x0 + x1;
y = a > 0.0f ? a : 0.0f;
};
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
{
const double a = x0 + x1;
y = a > 0.0 ? a : 0.0;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
const half_t a = x0 + x1;
y = a > static_cast<half_t>(0.0f) ? a : static_cast<half_t>(0.0f);
};
};
template <typename Y, typename X1, typename X2>
struct Substract;
template <>
struct Substract<double, double, double>
struct AddHardswish
{
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
template <>
__host__ __device__ constexpr void
operator()(double& dst, const double& src1, const double& src2) const
operator()<float>(float& y, const float& x0, const float& x1) const
{
dst = src1 - src2;
}
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
y = c;
};
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
{
double a = x0 + x1;
double b = a + 3.0;
double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667;
y = c;
};
// Question: should half_t be supported ?
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
float a = x0 + x1;
float b = a + 3.0f;
float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
y = c;
};
};
template <>
struct Substract<float, float, float>
{
__host__ __device__ constexpr void
operator()(float& dst, const float& src1, const float& src2) const
{
dst = src1 - src2;
}
};
} // namespace element_wise
template <>
struct Substract<half_t, half_t, half_t>
{
__host__ __device__ constexpr void
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
{
dst = src1 - src2;
}
};
template <>
struct Substract<bhalf_t, bhalf_t, bhalf_t>
{
__host__ __device__ constexpr void
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
{
const float x1 = ck::type_convert<float>(src1);
const float x2 = ck::type_convert<float>(src2);
const float y = x1 - x2;
dst = ck::type_convert<bhalf_t>(y);
}
};
} // namespace binary_element_wise
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,97 +1,13 @@
#pragma once
#include "data_type.hpp"
#include "math_v2.hpp"
#include "unary_element_wise_operation.hpp"
#include "binary_element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
struct PassThrough
{
__host__ __device__ void operator()(float& y, const float& x) const { y = x; }
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; }
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { y = x; }
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; }
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; }
__host__ __device__ void operator()(double& y, const double& x) const { y = x; }
};
struct Add
{
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
{
y = x0 + x1;
}
__host__ __device__ constexpr void
operator()(half_t& y, const half_t& x0, const half_t& x1) const
{
// FIXME - Use float (acc type) bias in the future.
y = x0 + x1;
}
};
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
{
y = alpha_ * x0 + beta_ * x1;
}
__host__ __device__ constexpr void
operator()(half_t& y, const half_t& x0, const half_t& x1) const
{
// FIXME - Let x0 be acc type
y = static_cast<half_t>(alpha_ * static_cast<float>(x0) + beta_ * static_cast<float>(x1));
}
float alpha_;
float beta_;
};
struct AddRelu
{
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
{
const float a = x0 + x1;
y = a > 0 ? a : 0;
}
__host__ __device__ constexpr void
operator()(half_t& y, const half_t& x0, const half_t& x1) const
{
const half_t a = x0 + x1;
y = a > 0 ? a : 0;
}
};
struct AddHardswish
{
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
y = c;
}
__host__ __device__ constexpr void
operator()(half_t& y, const half_t& x0, const half_t& x1) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
y = c;
}
};
struct AddReluAdd
{
__host__ __device__ constexpr void
@@ -167,204 +83,41 @@ struct Relu
struct Normalize
{
Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {}
Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}
__host__ __device__ constexpr void operator()(float& y,
const float& x,
const float& mean,
const float& mean_square,
const float& gamma,
const float& beta) const
template <typename T>
__host__ __device__ constexpr void operator()(
T& y, const T& x, const T& mean, const T& mean_square, const T& gamma, const T& beta) const;
template <>
__host__ __device__ constexpr void operator()<float>(float& y,
const float& x,
const float& mean,
const float& mean_square,
const float& gamma,
const float& beta) const
{
using ck::math::sqrt;
float variance = mean_square - (mean * mean);
y = ((x - mean) / sqrtf(variance + epsilon_)) * gamma + beta;
}
float epsilon_;
};
// Unary operators are usually called element-wisely before/after the reduction is executed on the
// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
template <typename Y, typename X, bool HasDividing = false>
struct UnaryIdentic;
template <>
struct UnaryIdentic<float, float, false>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = x; };
};
template <>
struct UnaryIdentic<float, float, true>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
__host__ __device__ void operator()(float& y, const float& x) const
{
y = x / type_convert<float>(divider_);
y = ((x - mean) / sqrt(variance + static_cast<float>(epsilon_))) * gamma + beta;
};
int32_t divider_ = 1;
};
template <>
struct UnaryIdentic<half_t, half_t, false>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; };
};
template <>
struct UnaryIdentic<double, double, false>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = x; };
};
template <>
struct UnaryIdentic<double, double, true>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
__host__ __device__ void operator()(double& y, const double& x) const
template <>
__host__ __device__ constexpr void operator()<double>(double& y,
const double& x,
const double& mean,
const double& mean_square,
const double& gamma,
const double& beta) const
{
y = x / type_convert<double>(divider_);
using ck::math::sqrt;
double variance = mean_square - (mean * mean);
y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta;
};
int32_t divider_ = 1;
};
template <>
struct UnaryIdentic<int32_t, int32_t, false>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; };
};
template <>
struct UnaryIdentic<int32_t, int32_t, true>
{
__host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x / divider_; };
int32_t divider_ = 1;
};
template <>
struct UnaryIdentic<int8_t, int8_t, false>
{
__host__ __device__ UnaryIdentic(const int8_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; };
};
template <typename Y, typename X, bool HasDividing = false>
struct UnarySquare;
template <>
struct UnarySquare<float, float, false>
{
__host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = x * x; };
};
template <>
struct UnarySquare<float, float, true>
{
__host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; };
__host__ __device__ void operator()(float& y, const float& x) const
{
y = x * x / type_convert<float>(divider_);
};
int32_t divider_ = 1;
};
template <>
struct UnarySquare<double, double, false>
{
__host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = x * x; };
};
template <>
struct UnarySquare<double, double, true>
{
__host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; };
__host__ __device__ void operator()(double& y, const double& x) const
{
y = x * x / type_convert<double>(divider_);
};
int32_t divider_ = 1;
};
template <typename Y, typename X>
struct UnaryAbs;
template <>
struct UnaryAbs<float, float>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
};
template <>
struct UnaryAbs<half_t, half_t>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
};
template <>
struct UnaryAbs<double, double>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
};
template <>
struct UnaryAbs<int8_t, int8_t>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
};
template <typename Y, typename X>
struct UnarySqrt;
template <>
struct UnarySqrt<float, float>
{
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
};
template <>
struct UnarySqrt<double, double>
{
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const
{
y = ck::math::sqrt(x);
};
double epsilon_;
};
template <typename Y, typename X>

View File

@@ -0,0 +1,80 @@
#pragma once
#include "data_type.hpp"
#include "math_v2.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
struct PassThrough
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, bhalf_t>::value ||
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = x;
};
};
struct UnaryDivide
{
__host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider){};
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = x / type_convert<T>(divider_);
};
int32_t divider_ = 1;
};
struct UnarySquare
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value,
"Data type is not supported by this operation!");
y = x * x;
};
};
struct UnaryAbs
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = ck::math::abs(x);
};
};
struct UnarySqrt
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value,
"Data type is not supported by this operation!");
y = ck::math::sqrt(x);
};
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck

View File

@@ -171,15 +171,15 @@ struct GridwiseReduction_mk_to_m_multiblock
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global)
{
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(identityVal));
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(identityVal));
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(

View File

@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
ReduceOperation,
PropagateNan>;
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(identityVal));
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
(void)acc_elementwise_op;
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(identityVal));
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());

View File

@@ -927,7 +927,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
false>;
// Global write Gemm shuffle + reduction
const auto d_zeroVal = DReduceOperation::GetIdentityValue();
const auto d_zeroVal =
DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_zeroVal; });

View File

@@ -816,7 +816,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
false>;
// Global write Gemm shuffle + reduction
const auto d_identityVal = DReduceOperation::GetIdentityValue();
const auto d_identityVal =
DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_identityVal; });

View File

@@ -37,7 +37,7 @@ __global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffe
{
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<DataType, DataType>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
constexpr auto I0 = Number<0>{};

View File

@@ -28,6 +28,7 @@
#include "config.hpp"
#include "data_type.hpp"
#include "type.hpp"
namespace ck {
@@ -54,64 +55,92 @@ namespace reduce {
// accumulated index also need be
// changed.
template <class T>
struct Add
{
using dataType = T;
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return type_convert<T>(0.0f);
};
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
__device__ static constexpr bool
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
return operation == InMemoryDataOperationEnum::AtomicAdd ||
operation == InMemoryDataOperationEnum::Set;
};
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
"The data type is not supported by the Add accumulator!");
a = a + b;
}
};
template <class T>
struct Mul
{
using dataType = T;
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return type_convert<T>(1.0f);
};
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(1.0f); };
__device__ static constexpr bool
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
return operation == InMemoryDataOperationEnum::Set;
};
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
"The data type is not supported by the Mul accumulator!");
a = a * b;
}
};
template <class T>
struct Max
{
using dataType = T;
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return NumericLimits<T>::Lowest();
};
__device__ static constexpr bool
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
// ToChange: atomic_max to be added
return operation == InMemoryDataOperationEnum::Set;
};
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(a < b)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(a < b)
{
a = b;
@@ -120,28 +149,41 @@ struct Max
}
};
template <class T>
struct Min
{
using dataType = T;
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return NumericLimits<T>::Max();
};
__host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits<T>::Max(); };
__device__ static constexpr bool
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
// ToChange: atomic_min to be added
return operation == InMemoryDataOperationEnum::Set;
};
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Min accumulator!");
if(a > b)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the Min accumulator!");
if(a > b)
{
a = b;
@@ -150,28 +192,41 @@ struct Min
}
};
template <class T>
struct AMax
{
using dataType = T;
template <typename T>
__host__ __device__ static constexpr T GetIdentityValue()
{
return type_convert<T>(0.0f);
};
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
__device__ static constexpr bool
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
// ToChange: atomic_max to be added
return operation == InMemoryDataOperationEnum::Set;
};
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the AMax accumulator!");
if(a < b)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"The data type is not supported by the AMax accumulator!");
if(a < b)
{
a = b;
@@ -181,7 +236,7 @@ struct AMax
};
template <typename T>
T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
T result = ck::type_convert<T>(0.0f);
@@ -191,6 +246,44 @@ T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation
return (result);
};
template <InMemoryDataOperationEnum Operation, typename DataType>
struct InMemoryDataOperatonSupportedOnDataType
{
static constexpr bool value = false;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value ||
is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
};
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
{
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value ||
is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
is_same<DataType, int32_t>::value;
};
}; // end of namespace reduce
} // end of namespace ck