mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Reduction external API and client examples (#493)
* Change the DeviceReduce base class template to include all problem description information
* Add external API for reduction
* Add client example to test the reduction external API
* Spelling correction
* Re-implement the host_reduction to follow the DeviceReduce base API format
* Change the reduce profiler to call the external API for collecting device instances
* Rename reduce client example directory from 08_reduce to 12_reduce
* Remove (void) before the function call
* Tiny update in reduce client example
* Tiny update in profile_reduce_impl.hpp
* Rename the reduce client example directory

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -12,13 +12,13 @@
|
||||
#include "ck/utility/reduction_enums.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_reduce.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/library/utility/host_reduction.hpp"
|
||||
|
||||
using namespace ck;
|
||||
using namespace ck::tensor_operation::device;
|
||||
@@ -97,8 +97,8 @@ int main(int argc, char* argv[])
|
||||
// const std::array<int, 3> invariantDims_2 = {0, 1, 2};
|
||||
|
||||
// used by the host reduction
|
||||
const std::array<int, 2> reduceDims = {3, 4};
|
||||
const std::array<int, 3> invariantDims = {0, 1, 2};
|
||||
const std::array<int, 2> reduceDims = {3, 4};
|
||||
// const std::array<int, 3> invariantDims = {0, 1, 2};
|
||||
|
||||
const std::vector<size_t> inLengths_1 = {64, 320, 80, 4, 128};
|
||||
|
||||
@@ -191,29 +191,6 @@ int main(int argc, char* argv[])
|
||||
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
|
||||
static_cast<int32_t>(reduce_total_length));
|
||||
|
||||
if(do_verify)
|
||||
{
|
||||
ReductionHost<InOutDataType,
|
||||
AccDataType,
|
||||
InOutDataType,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
5, // Rank
|
||||
2, // NumReduceDim
|
||||
PropagateNan,
|
||||
OutputIndex>
|
||||
hostReduce(in_1.mDesc, out_ref.mDesc, invariantDims, reduceDims);
|
||||
|
||||
hostReduce.Run(alpha,
|
||||
in_1.mData.data(),
|
||||
beta,
|
||||
out_ref.mData.data(),
|
||||
nullptr,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::array<index_t, 5> arrInLengths_1;
|
||||
std::array<index_t, 5> arrInStrides_1;
|
||||
std::array<index_t, 4> arrInLengths_2;
|
||||
@@ -228,6 +205,48 @@ int main(int argc, char* argv[])
|
||||
ck::ranges::copy(outLengths, arrOutLengths.begin());
|
||||
ck::ranges::copy(outStrides, arrOutStrides.begin());
|
||||
|
||||
if(do_verify)
|
||||
{
|
||||
using ReferenceReduceInstance =
|
||||
ck::tensor_operation::host::ReferenceReduce<InOutDataType,
|
||||
AccDataType,
|
||||
InOutDataType,
|
||||
5,
|
||||
2,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
OutputIndex>;
|
||||
|
||||
auto reduce_ref = ReferenceReduceInstance{};
|
||||
|
||||
auto argument_ptr_ref = reduce_ref.MakeArgumentPointer(arrInLengths_1,
|
||||
arrInStrides_1,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_1.mData.data(),
|
||||
nullptr,
|
||||
out_ref.mData.data(),
|
||||
nullptr,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
|
||||
if(!reduce_ref.IsSupportedArgument(argument_ptr_ref.get()))
|
||||
{
|
||||
std::cout << "The runtime parameters not supported by the reduce reference, exiting!"
|
||||
<< std::endl;
|
||||
return (false);
|
||||
};
|
||||
|
||||
auto invoker_ptr_ref = reduce_ref.MakeInvokerPointer();
|
||||
|
||||
invoker_ptr_ref->Run(argument_ptr_ref.get());
|
||||
};
|
||||
|
||||
auto reduce_1 = DeviceReduceInstance_1{};
|
||||
|
||||
auto argument_ptr_1 = reduce_1.MakeArgumentPointer(arrInLengths_1,
|
||||
@@ -246,9 +265,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(!reduce_1.IsSupportedArgument(argument_ptr_1.get()))
|
||||
{
|
||||
std::cout
|
||||
<< "The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
|
||||
<< std::endl;
|
||||
std::cout << "The runtime parameters seems supported by the DeviceReduce instance, exiting!"
|
||||
<< std::endl;
|
||||
};
|
||||
|
||||
auto invoker_ptr_1 = reduce_1.MakeInvokerPointer();
|
||||
|
||||
Reference in New Issue
Block a user