mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Regulate reduction accumulator operations and Element-wise operations (#274)
* Remove template from Reducton operation classes and add template to their operator() and GetIdentityValue() interfaces * Change to unary elementwise operators and the reduce_unary_operator (class for mapping) and dependent variations in all host layers * Remove the data type template parameter from reduce_binary_operator (class for mapping) and dependent variations in host layers * Add InMemoryDataOperatonSupportedOnDataType to check the matching between data type and InMemoryDataOperation * Use struct-scope operator template instantiation for binary and unary element-wise operations * Change a few more elementwise operations to use template for operator() * Tiny correction in Normalize operator * Add static_assert to check the data type appliability for some reduction accumulator and element-wise operatons * Correction in some examples with regard to using ReduceAccDataType * Use static_assert for UnaryDivide * Update to merged codes to use Element-wise operations and Reduction Accumulator operations correctly * Tiny fix with regard to SetWorkSpacePointer()
This commit is contained in:
@@ -33,11 +33,11 @@ constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
|
||||
constexpr bool PropagateNan = true;
|
||||
constexpr bool OutputIndex = false;
|
||||
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
using DeviceReduceInstance = DeviceReduceMultiBlock<InDataType,
|
||||
AccDataType,
|
||||
@@ -247,6 +247,13 @@ int main(int argc, char* argv[])
|
||||
|
||||
DeviceMem out_index_dev(indicesSizeInBytes);
|
||||
|
||||
InElementwiseOperation in_elementwise_op;
|
||||
AccElementwiseOperation acc_elementwise_op;
|
||||
|
||||
std::tie(in_elementwise_op, acc_elementwise_op) =
|
||||
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
|
||||
static_cast<int32_t>(reduce_total_length));
|
||||
|
||||
if(args.do_verification)
|
||||
{
|
||||
ReductionHost<InDataType,
|
||||
@@ -261,8 +268,13 @@ int main(int argc, char* argv[])
|
||||
OutputIndex>
|
||||
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
|
||||
|
||||
hostReduce.Run(
|
||||
alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data());
|
||||
hostReduce.Run(alpha,
|
||||
in.mData.data(),
|
||||
beta,
|
||||
out_ref.mData.data(),
|
||||
out_indices_ref.mData.data(),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::vector<ck::index_t> i_inLengths;
|
||||
@@ -277,20 +289,19 @@ int main(int argc, char* argv[])
|
||||
|
||||
auto reduce = DeviceReduceInstance{};
|
||||
|
||||
auto argument_ptr = reduce.MakeArgumentPointer(
|
||||
i_inLengths,
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
out_index_dev.GetDeviceBuffer(),
|
||||
InElementwiseOperation{static_cast<int32_t>(reduce_total_length)},
|
||||
AccElementwiseOperation{static_cast<int32_t>(reduce_total_length)});
|
||||
auto argument_ptr = reduce.MakeArgumentPointer(i_inLengths,
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
out_index_dev.GetDeviceBuffer(),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
|
||||
if(!reduce.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
|
||||
@@ -31,13 +31,13 @@ constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
|
||||
constexpr bool PropagateNan = true;
|
||||
constexpr bool OutputIndex = false;
|
||||
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<AccDataType, AccDataType>;
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using DeviceReduceInstance_1 = DeviceReduceMultiBlock<InOutDataType,
|
||||
AccDataType,
|
||||
@@ -184,6 +184,13 @@ int main(int argc, char* argv[])
|
||||
if(beta != 0.0f)
|
||||
out_dev.ToDevice(out.mData.data());
|
||||
|
||||
InElementwiseOperation in_elementwise_op;
|
||||
AccElementwiseOperation acc_elementwise_op;
|
||||
|
||||
std::tie(in_elementwise_op, acc_elementwise_op) =
|
||||
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
|
||||
static_cast<int32_t>(reduce_total_length));
|
||||
|
||||
if(do_verify)
|
||||
{
|
||||
ReductionHost<InOutDataType,
|
||||
@@ -198,7 +205,13 @@ int main(int argc, char* argv[])
|
||||
OutputIndex>
|
||||
hostReduce(in_1.mDesc, out_ref.mDesc, invariantDims, reduceDims);
|
||||
|
||||
hostReduce.Run(alpha, in_1.mData.data(), beta, out_ref.mData.data(), nullptr);
|
||||
hostReduce.Run(alpha,
|
||||
in_1.mData.data(),
|
||||
beta,
|
||||
out_ref.mData.data(),
|
||||
nullptr,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::vector<ck::index_t> i_inLengths_1;
|
||||
@@ -217,20 +230,19 @@ int main(int argc, char* argv[])
|
||||
|
||||
auto reduce_1 = DeviceReduceInstance_1{};
|
||||
|
||||
auto argument_ptr_1 = reduce_1.MakeArgumentPointer(
|
||||
i_inLengths_1,
|
||||
i_inStrides_1,
|
||||
i_inLengths_2,
|
||||
i_inStrides_2,
|
||||
reduceDims_1,
|
||||
1.0f,
|
||||
0.0f,
|
||||
in_1_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
InElementwiseOperation{static_cast<int32_t>(reduce_total_length)},
|
||||
PassThroughOp{});
|
||||
auto argument_ptr_1 = reduce_1.MakeArgumentPointer(i_inLengths_1,
|
||||
i_inStrides_1,
|
||||
i_inLengths_2,
|
||||
i_inStrides_2,
|
||||
reduceDims_1,
|
||||
1.0f,
|
||||
0.0f,
|
||||
in_1_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
in_elementwise_op,
|
||||
PassThroughOp{});
|
||||
|
||||
if(!reduce_1.IsSupportedArgument(argument_ptr_1.get()))
|
||||
{
|
||||
@@ -243,20 +255,19 @@ int main(int argc, char* argv[])
|
||||
|
||||
auto reduce_2 = DeviceReduceInstance_2{};
|
||||
|
||||
auto argument_ptr_2 = reduce_2.MakeArgumentPointer(
|
||||
i_inLengths_2,
|
||||
i_inStrides_2,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims_2,
|
||||
alpha,
|
||||
beta,
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
PassThroughOp{},
|
||||
AccElementwiseOperation{static_cast<int32_t>(reduce_total_length)});
|
||||
auto argument_ptr_2 = reduce_2.MakeArgumentPointer(i_inLengths_2,
|
||||
i_inStrides_2,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims_2,
|
||||
alpha,
|
||||
beta,
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
PassThroughOp{},
|
||||
acc_elementwise_op);
|
||||
|
||||
if(!reduce_2.IsSupportedArgument(argument_ptr_2.get()))
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user