mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Regulate reduction accumulator operations and Element-wise operations (#274)
* Remove template from Reduction operation classes and add template to their operator() and GetIdentityValue() interfaces
* Change to unary elementwise operators and the reduce_unary_operator (class for mapping) and dependent variations in all host layers
* Remove the data type template parameter from reduce_binary_operator (class for mapping) and dependent variations in host layers
* Add InMemoryDataOperationSupportedOnDataType to check the matching between data type and InMemoryDataOperation
* Use struct-scope operator template instantiation for binary and unary element-wise operations
* Change a few more elementwise operations to use template for operator()
* Tiny correction in Normalize operator
* Add static_assert to check the data type applicability for some reduction accumulator and element-wise operations
* Correction in some examples with regard to using ReduceAccDataType
* Use static_assert for UnaryDivide
* Update to merged codes to use Element-wise operations and Reduction Accumulator operations correctly
* Tiny fix with regard to SetWorkSpacePointer()
[ROCm/composable_kernel commit: 1f543bfa79]
This commit is contained in:
@@ -33,11 +33,11 @@ constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
|
||||
constexpr bool PropagateNan = true;
|
||||
constexpr bool OutputIndex = false;
|
||||
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
using DeviceReduceInstance = DeviceReduceMultiBlock<InDataType,
|
||||
AccDataType,
|
||||
@@ -247,6 +247,13 @@ int main(int argc, char* argv[])
|
||||
|
||||
DeviceMem out_index_dev(indicesSizeInBytes);
|
||||
|
||||
InElementwiseOperation in_elementwise_op;
|
||||
AccElementwiseOperation acc_elementwise_op;
|
||||
|
||||
std::tie(in_elementwise_op, acc_elementwise_op) =
|
||||
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
|
||||
static_cast<int32_t>(reduce_total_length));
|
||||
|
||||
if(args.do_verification)
|
||||
{
|
||||
ReductionHost<InDataType,
|
||||
@@ -261,8 +268,13 @@ int main(int argc, char* argv[])
|
||||
OutputIndex>
|
||||
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
|
||||
|
||||
hostReduce.Run(
|
||||
alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data());
|
||||
hostReduce.Run(alpha,
|
||||
in.mData.data(),
|
||||
beta,
|
||||
out_ref.mData.data(),
|
||||
out_indices_ref.mData.data(),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::vector<ck::index_t> i_inLengths;
|
||||
@@ -277,20 +289,19 @@ int main(int argc, char* argv[])
|
||||
|
||||
auto reduce = DeviceReduceInstance{};
|
||||
|
||||
auto argument_ptr = reduce.MakeArgumentPointer(
|
||||
i_inLengths,
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
out_index_dev.GetDeviceBuffer(),
|
||||
InElementwiseOperation{static_cast<int32_t>(reduce_total_length)},
|
||||
AccElementwiseOperation{static_cast<int32_t>(reduce_total_length)});
|
||||
auto argument_ptr = reduce.MakeArgumentPointer(i_inLengths,
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
out_index_dev.GetDeviceBuffer(),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
|
||||
if(!reduce.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
|
||||
@@ -31,13 +31,13 @@ constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
|
||||
constexpr bool PropagateNan = true;
|
||||
constexpr bool OutputIndex = false;
|
||||
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<AccDataType, AccDataType>;
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using DeviceReduceInstance_1 = DeviceReduceMultiBlock<InOutDataType,
|
||||
AccDataType,
|
||||
@@ -184,6 +184,13 @@ int main(int argc, char* argv[])
|
||||
if(beta != 0.0f)
|
||||
out_dev.ToDevice(out.mData.data());
|
||||
|
||||
InElementwiseOperation in_elementwise_op;
|
||||
AccElementwiseOperation acc_elementwise_op;
|
||||
|
||||
std::tie(in_elementwise_op, acc_elementwise_op) =
|
||||
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
|
||||
static_cast<int32_t>(reduce_total_length));
|
||||
|
||||
if(do_verify)
|
||||
{
|
||||
ReductionHost<InOutDataType,
|
||||
@@ -198,7 +205,13 @@ int main(int argc, char* argv[])
|
||||
OutputIndex>
|
||||
hostReduce(in_1.mDesc, out_ref.mDesc, invariantDims, reduceDims);
|
||||
|
||||
hostReduce.Run(alpha, in_1.mData.data(), beta, out_ref.mData.data(), nullptr);
|
||||
hostReduce.Run(alpha,
|
||||
in_1.mData.data(),
|
||||
beta,
|
||||
out_ref.mData.data(),
|
||||
nullptr,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::vector<ck::index_t> i_inLengths_1;
|
||||
@@ -217,20 +230,19 @@ int main(int argc, char* argv[])
|
||||
|
||||
auto reduce_1 = DeviceReduceInstance_1{};
|
||||
|
||||
auto argument_ptr_1 = reduce_1.MakeArgumentPointer(
|
||||
i_inLengths_1,
|
||||
i_inStrides_1,
|
||||
i_inLengths_2,
|
||||
i_inStrides_2,
|
||||
reduceDims_1,
|
||||
1.0f,
|
||||
0.0f,
|
||||
in_1_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
InElementwiseOperation{static_cast<int32_t>(reduce_total_length)},
|
||||
PassThroughOp{});
|
||||
auto argument_ptr_1 = reduce_1.MakeArgumentPointer(i_inLengths_1,
|
||||
i_inStrides_1,
|
||||
i_inLengths_2,
|
||||
i_inStrides_2,
|
||||
reduceDims_1,
|
||||
1.0f,
|
||||
0.0f,
|
||||
in_1_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
in_elementwise_op,
|
||||
PassThroughOp{});
|
||||
|
||||
if(!reduce_1.IsSupportedArgument(argument_ptr_1.get()))
|
||||
{
|
||||
@@ -243,20 +255,19 @@ int main(int argc, char* argv[])
|
||||
|
||||
auto reduce_2 = DeviceReduceInstance_2{};
|
||||
|
||||
auto argument_ptr_2 = reduce_2.MakeArgumentPointer(
|
||||
i_inLengths_2,
|
||||
i_inStrides_2,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims_2,
|
||||
alpha,
|
||||
beta,
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
PassThroughOp{},
|
||||
AccElementwiseOperation{static_cast<int32_t>(reduce_total_length)});
|
||||
auto argument_ptr_2 = reduce_2.MakeArgumentPointer(i_inLengths_2,
|
||||
i_inStrides_2,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims_2,
|
||||
alpha,
|
||||
beta,
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
PassThroughOp{},
|
||||
acc_elementwise_op);
|
||||
|
||||
if(!reduce_2.IsSupportedArgument(argument_ptr_2.get()))
|
||||
{
|
||||
|
||||
@@ -31,16 +31,15 @@ static void pool_host_verify(const Tensor<InDataType>& in,
|
||||
const std::array<ck::index_t, 2>& in_left_pads,
|
||||
const std::array<ck::index_t, 2>& /*in_right_pads*/)
|
||||
{
|
||||
const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
|
||||
const int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];
|
||||
|
||||
using ReduceOperation = typename ck::reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using InElementwiseOperation = typename ck::
|
||||
reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
|
||||
using AccElementwiseOperation = typename ck::
|
||||
reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
|
||||
|
||||
const InElementwiseOperation in_elementwise_op(divider);
|
||||
const AccElementwiseOperation acc_elementwise_op(divider);
|
||||
auto elementwise_ops =
|
||||
ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
|
||||
|
||||
auto in_elementwise_op = std::get<0>(elementwise_ops);
|
||||
auto acc_elementwise_op = std::get<1>(elementwise_ops);
|
||||
|
||||
if constexpr(!OutputIndex)
|
||||
{
|
||||
@@ -48,7 +47,7 @@ static void pool_host_verify(const Tensor<InDataType>& in,
|
||||
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
|
||||
|
||||
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
|
||||
auto accuVal = ReduceOperation::GetIdentityValue();
|
||||
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
|
||||
|
||||
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
|
||||
{
|
||||
@@ -86,7 +85,7 @@ static void pool_host_verify(const Tensor<InDataType>& in,
|
||||
AccDataType,
|
||||
IndexDataType>;
|
||||
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
|
||||
auto accuVal = ReduceOperation::GetIdentityValue();
|
||||
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
|
||||
IndexDataType accuIndex = 0;
|
||||
|
||||
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
|
||||
|
||||
@@ -41,9 +41,8 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using DsReduceOp = ck::Tuple<ck::reduce::Max<ReduceAccDataType>>;
|
||||
using DsElementOp = ck::Tuple<
|
||||
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>>;
|
||||
using DsReduceOp = ck::Tuple<ck::reduce::Max>;
|
||||
using DsElementOp = ck::Tuple<ck::tensor_operation::element_wise::PassThrough>;
|
||||
using DGlobalMemOp =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
|
||||
|
||||
@@ -236,10 +235,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue();
|
||||
ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
d_reduce_op(d_acc, c_m_n_host_result(m, n));
|
||||
{
|
||||
ReduceAccDataType curr_val =
|
||||
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
|
||||
d_reduce_op(d_acc, curr_val);
|
||||
};
|
||||
|
||||
d_m_host_result(m) = d_acc;
|
||||
}
|
||||
|
||||
@@ -41,18 +41,15 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using D0ReduceOp = ck::reduce::Add<ReduceAccDataType>;
|
||||
using D1ReduceOp = ck::reduce::Add<ReduceAccDataType>;
|
||||
using D0ReduceOp = ck::reduce::Add;
|
||||
using D1ReduceOp = ck::reduce::Add;
|
||||
using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
|
||||
|
||||
using UnaryIdenticElementOp =
|
||||
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>;
|
||||
using UnaryDivElementOp =
|
||||
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>;
|
||||
using UnarySquareElementOp =
|
||||
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
|
||||
using DGlobalMemOp =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
|
||||
@@ -261,15 +258,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
ReduceAccDataType d0_acc = d0_reduce_op.GetIdentityValue();
|
||||
ReduceAccDataType d1_acc = d1_reduce_op.GetIdentityValue();
|
||||
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
ReduceAccDataType c_val =
|
||||
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
|
||||
ReduceAccDataType d0_val = 0;
|
||||
ReduceAccDataType d1_val = 0;
|
||||
auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
|
||||
ReduceAccDataType d0_val;
|
||||
ReduceAccDataType d1_val;
|
||||
|
||||
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
|
||||
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
|
||||
|
||||
@@ -39,16 +39,14 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using D0ReduceOp = ck::reduce::Add<ReduceAccDataType>;
|
||||
using D1ReduceOp = ck::reduce::Add<ReduceAccDataType>;
|
||||
using D0ReduceOp = ck::reduce::Add;
|
||||
using D1ReduceOp = ck::reduce::Add;
|
||||
using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
|
||||
|
||||
using UnaryIdenticElementOp =
|
||||
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>;
|
||||
using UnarySquareElementOp =
|
||||
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
|
||||
|
||||
using DGlobalMemOp =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
|
||||
@@ -259,14 +257,15 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
float d0_acc = d0_reduce_op.GetIdentityValue();
|
||||
float d1_acc = d1_reduce_op.GetIdentityValue();
|
||||
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
float c_val = ck::type_convert<float>(c_g_m_n_host_result(batch, m, n));
|
||||
float d0_val = 0;
|
||||
float d1_val = 0;
|
||||
auto c_val =
|
||||
ck::type_convert<ReduceAccDataType>(c_g_m_n_host_result(batch, m, n));
|
||||
ReduceAccDataType d0_val;
|
||||
ReduceAccDataType d1_val;
|
||||
|
||||
UnaryIdenticElementOp{}(d0_val, c_val);
|
||||
UnarySquareElementOp{}(d1_val, c_val);
|
||||
|
||||
@@ -42,8 +42,7 @@ using ABDataType = F16;
|
||||
using CDataType = F16;
|
||||
using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::binary_element_wise::
|
||||
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
|
||||
@@ -17,8 +17,7 @@ using ABDataType = F16;
|
||||
using CDataType = F16;
|
||||
using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::binary_element_wise::
|
||||
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
|
||||
@@ -42,8 +42,7 @@ using ABDataType = F16;
|
||||
using CDataType = F16;
|
||||
using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::binary_element_wise::
|
||||
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
|
||||
@@ -42,8 +42,7 @@ using ABDataType = F16;
|
||||
using CDataType = F16;
|
||||
using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::binary_element_wise::
|
||||
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
|
||||
@@ -48,17 +48,14 @@ using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::Relu;
|
||||
using C1ElementOp = PassThrough;
|
||||
using ReduceSumOp = ck::reduce::Add<ReduceAccDataType>;
|
||||
using ReduceSumOp = ck::reduce::Add;
|
||||
using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>;
|
||||
|
||||
using UnaryIdenticElementOp =
|
||||
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>;
|
||||
using UnaryDivElementOp =
|
||||
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>;
|
||||
using UnarySquareElementOp =
|
||||
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
|
||||
using DxsGlobalMemOp =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
|
||||
@@ -181,8 +178,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
auto reduceSumOpInst = ReduceSumOp{};
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
AccDataType mean_acc = reduceSumOpInst.GetIdentityValue();
|
||||
AccDataType square_mean_acc = reduceSumOpInst.GetIdentityValue();
|
||||
auto mean_acc = reduceSumOpInst.GetIdentityValue<AccDataType>();
|
||||
auto square_mean_acc = reduceSumOpInst.GetIdentityValue<AccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
@@ -207,7 +204,12 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
AccDataType out_acc = 0;
|
||||
layerNormInst(out_acc, c_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n));
|
||||
layerNormInst(out_acc,
|
||||
static_cast<AccDataType>(c_m_n(m, n)),
|
||||
static_cast<AccDataType>(mean_m(m)),
|
||||
static_cast<AccDataType>(meanSquare_m(m)),
|
||||
static_cast<AccDataType>(gamma_n(n)),
|
||||
static_cast<AccDataType>(beta_n(n)));
|
||||
out_m_n(m, n) = static_cast<DDataType>(out_acc);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,17 +44,14 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using ReduceSumOp = ck::reduce::Add<ReduceAccDataType>;
|
||||
using ReduceSumOp = ck::reduce::Add;
|
||||
using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>;
|
||||
|
||||
using UnaryIdenticElementOp =
|
||||
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>;
|
||||
using UnaryDivElementOp =
|
||||
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>;
|
||||
using UnarySquareElementOp =
|
||||
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
|
||||
using DxsGlobalMemOp =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
|
||||
@@ -156,13 +153,14 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
auto reduceSumOpInst = ReduceSumOp{};
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
float mean_acc = reduceSumOpInst.GetIdentityValue();
|
||||
float square_mean_acc = reduceSumOpInst.GetIdentityValue();
|
||||
auto mean_acc = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>();
|
||||
auto square_mean_acc = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
ReduceAccDataType c_val = ck::type_convert<ReduceAccDataType>(c_m_n(m, n));
|
||||
ReduceAccDataType square_c_val = 0;
|
||||
auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n(m, n));
|
||||
auto square_c_val = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
UnarySquareElementOp{}(square_c_val, c_val);
|
||||
|
||||
reduceSumOpInst(mean_acc, c_val);
|
||||
@@ -182,7 +180,12 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
float out_f32 = 0;
|
||||
layerNormInst(out_f32, c_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n));
|
||||
layerNormInst(out_f32,
|
||||
static_cast<float>(c_m_n(m, n)),
|
||||
static_cast<float>(mean_m(m)),
|
||||
static_cast<float>(meanSquare_m(m)),
|
||||
static_cast<float>(gamma_n(n)),
|
||||
static_cast<float>(beta_n(n)));
|
||||
out_m_n(m, n) = static_cast<out_type>(out_f32);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user