mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
Regulate reduction accumulator operations and Element-wise operations (#274)
* Remove template from Reducton operation classes and add template to their operator() and GetIdentityValue() interfaces
* Change to unary elementwise operators and the reduce_unary_operator (class for mapping) and dependent variations in all host layers
* Remove the data type template parameter from reduce_binary_operator (class for mapping) and dependent variations in host layers
* Add InMemoryDataOperatonSupportedOnDataType to check the matching between data type and InMemoryDataOperation
* Use struct-scope operator template instantiation for binary and unary element-wise operations
* Change a few more elementwise operations to use template for operator()
* Tiny correction in Normalize operator
* Add static_assert to check the data type appliability for some reduction accumulator and element-wise operatons
* Correction in some examples with regard to using ReduceAccDataType
* Use static_assert for UnaryDivide
* Update to merged codes to use Element-wise operations and Reduction Accumulator operations correctly
* Tiny fix with regard to SetWorkSpacePointer()
[ROCm/composable_kernel commit: 1f543bfa79]
This commit is contained in:
@@ -20,8 +20,8 @@ namespace device_gemm_instance {
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
|
||||
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
|
||||
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
|
||||
using Identity = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Square = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DInElementOps = ck::Tuple<Identity, Square>;
|
||||
using DOutElementOps = ck::Tuple<Identity, Identity>;
|
||||
|
||||
@@ -128,17 +128,15 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
|
||||
b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using D0ReduceOp = ck::reduce::Add<float>;
|
||||
using D1ReduceOp = ck::reduce::Add<float>;
|
||||
using UnaryIdenticElementOp =
|
||||
ck::tensor_operation::element_wise::UnaryIdentic<float, float, false>;
|
||||
using UnarySquareElementOp =
|
||||
ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using D0ReduceOp = ck::reduce::Add;
|
||||
using D1ReduceOp = ck::reduce::Add;
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
@@ -170,8 +168,8 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
|
||||
{
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
float d0_acc = d0_reduce_op.GetIdentityValue();
|
||||
float d1_acc = d1_reduce_op.GetIdentityValue();
|
||||
float d0_acc = d0_reduce_op.GetIdentityValue<float>();
|
||||
float d1_acc = d1_reduce_op.GetIdentityValue<float>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
|
||||
@@ -20,9 +20,9 @@ namespace device_gemm_instance {
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
|
||||
using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
|
||||
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
|
||||
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
|
||||
using Div = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
using Identity = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Square = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DInElementOps = ck::Tuple<Identity, Square>;
|
||||
using DOutElementOps = ck::Tuple<Div, Div>;
|
||||
|
||||
@@ -136,20 +136,18 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
|
||||
c1_m_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
using C1ElementOp = PassThrough;
|
||||
using D0ReduceOp = ck::reduce::Add<float>;
|
||||
using D1ReduceOp = ck::reduce::Add<float>;
|
||||
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryIdentic<float, float, true>;
|
||||
using UnaryIdenticElementOp =
|
||||
ck::tensor_operation::element_wise::UnaryIdentic<float, float, false>;
|
||||
using UnarySquareElementOp =
|
||||
ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
using C1ElementOp = PassThrough;
|
||||
using D0ReduceOp = ck::reduce::Add;
|
||||
using D1ReduceOp = ck::reduce::Add;
|
||||
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
@@ -196,15 +194,15 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
ReduceAccDataType d0_acc = d0_reduce_op.GetIdentityValue();
|
||||
ReduceAccDataType d1_acc = d1_reduce_op.GetIdentityValue();
|
||||
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
ReduceAccDataType c_val =
|
||||
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
|
||||
ReduceAccDataType d0_val = 0;
|
||||
ReduceAccDataType d1_val = 0;
|
||||
ReduceAccDataType d0_val;
|
||||
ReduceAccDataType d1_val;
|
||||
|
||||
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
|
||||
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
|
||||
|
||||
@@ -20,9 +20,9 @@ namespace device_gemm_instance {
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
|
||||
using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
|
||||
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
|
||||
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
|
||||
using Div = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
using Identity = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Square = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DInElementOps = ck::Tuple<Identity, Square>;
|
||||
using DOutElementOps = ck::Tuple<Div, Div>;
|
||||
|
||||
@@ -123,18 +123,16 @@ bool profile_gemm_reduce_impl(int do_verification,
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using D0ReduceOp = ck::reduce::Add<float>;
|
||||
using D1ReduceOp = ck::reduce::Add<float>;
|
||||
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryIdentic<float, float, true>;
|
||||
using UnaryIdenticElementOp =
|
||||
ck::tensor_operation::element_wise::UnaryIdentic<float, float, false>;
|
||||
using UnarySquareElementOp =
|
||||
ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using D0ReduceOp = ck::reduce::Add;
|
||||
using D1ReduceOp = ck::reduce::Add;
|
||||
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
@@ -167,15 +165,15 @@ bool profile_gemm_reduce_impl(int do_verification,
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
ReduceAccDataType d0_acc = d0_reduce_op.GetIdentityValue();
|
||||
ReduceAccDataType d1_acc = d1_reduce_op.GetIdentityValue();
|
||||
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
ReduceAccDataType c_val =
|
||||
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
|
||||
ReduceAccDataType d0_val = 0;
|
||||
ReduceAccDataType d1_val = 0;
|
||||
ReduceAccDataType d0_val;
|
||||
ReduceAccDataType d1_val;
|
||||
|
||||
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
|
||||
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
|
||||
|
||||
@@ -261,13 +261,18 @@ bool profile_reduce_impl_impl(bool do_verification,
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
|
||||
InElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
|
||||
AccElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
|
||||
|
||||
InElementwiseOperation in_elementwise_op;
|
||||
AccElementwiseOperation acc_elementwise_op;
|
||||
|
||||
std::tie(in_elementwise_op, acc_elementwise_op) =
|
||||
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
|
||||
static_cast<int32_t>(reduce_total_length));
|
||||
|
||||
using DeviceReduceInstPtr0 =
|
||||
DeviceReducePtr<InElementwiseOperation, AccElementwiseOperation>;
|
||||
@@ -323,8 +328,13 @@ bool profile_reduce_impl_impl(bool do_verification,
|
||||
OutputIndex>
|
||||
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
|
||||
|
||||
hostReduce.Run(
|
||||
alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data());
|
||||
hostReduce.Run(alpha,
|
||||
in.mData.data(),
|
||||
beta,
|
||||
out_ref.mData.data(),
|
||||
out_indices_ref.mData.data(),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::vector<ck::index_t> i_inLengths;
|
||||
@@ -339,10 +349,6 @@ bool profile_reduce_impl_impl(bool do_verification,
|
||||
|
||||
for(auto& reduce_ptr : reduce0_ptrs)
|
||||
{
|
||||
|
||||
InElementwiseOperation in_elementwise_op(static_cast<int32_t>(reduce_total_length));
|
||||
AccElementwiseOperation acc_elementwise_op(static_cast<int32_t>(reduce_total_length));
|
||||
|
||||
auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths,
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
|
||||
Reference in New Issue
Block a user