mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Pr82 followup (#115)
* Use thread cluster descriptor and explicit M_K 2d descriptor to simply Blockwise Reduction * Change by replacing ReduceDims by NumReduceDims as Device Reduce interface template parameter * Rename the folder name for the pool2d and reduce examples * Update to reduction test scripts * Add Readme for pool2d_fwd and reduce_blockwise examples * Tiny fix in reduce profiler and tiny update in reduce testing scripts * Tiny fix in testing script profile_reduce_no_index.sh * Tiny change in script/profile_reduce_with_index.sh * Renaming and refining in Reduction profiler/device layer/examples * Renaming and refining in Reduction profiler/device layer/examples * Renaming all NumReduceDims to NumReduceDim
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#include "device_reduce_blockwise.hpp"
|
||||
#include "host_reduce_util.hpp"
|
||||
#include "host_generic_reduction.hpp"
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
|
||||
@@ -28,8 +29,8 @@ using kInDataType = ck::half_t;
|
||||
using kOutDataType = ck::half_t;
|
||||
using kAccDataType = float;
|
||||
|
||||
constexpr int Rank = 4;
|
||||
using ReduceDims_ = ck::Sequence<0, 1, 2>;
|
||||
constexpr int Rank = 4;
|
||||
constexpr int NumReduceDim = 3;
|
||||
|
||||
constexpr ReduceTensorOp_t ReduceOpId = ReduceTensorOp_t::NORM2;
|
||||
constexpr NanPropagation_t NanOpt = NanPropagation_t::PROPAGATE_NAN;
|
||||
@@ -46,7 +47,7 @@ using DeviceReduceInstance = DeviceReduceBlockWise<kInDataType,
|
||||
kAccDataType,
|
||||
kOutDataType,
|
||||
Rank,
|
||||
ReduceDims_,
|
||||
NumReduceDim,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
@@ -192,39 +193,13 @@ class SimpleAppArgs
|
||||
};
|
||||
};
|
||||
|
||||
template <int Rank, typename ReduceDims>
|
||||
static std::vector<int> get_reduce_dims()
|
||||
{
|
||||
std::vector<int> resDims;
|
||||
|
||||
static_for<0, ReduceDims::Size(), 1>{}([&](auto i) { resDims.push_back(ReduceDims::At(i)); });
|
||||
|
||||
return (resDims);
|
||||
};
|
||||
|
||||
template <int Rank, typename ReduceDims>
|
||||
static std::vector<int> get_invariant_dims()
|
||||
{
|
||||
std::vector<int> resDims;
|
||||
unsigned int incFlag = 0;
|
||||
|
||||
static_for<0, ReduceDims::Size(), 1>{}(
|
||||
[&](auto i) { incFlag = incFlag | (0x1 << ReduceDims::At(i)); });
|
||||
|
||||
for(int dim = 0; dim < Rank; dim++)
|
||||
{
|
||||
if(incFlag & (0x1 << dim))
|
||||
continue;
|
||||
resDims.push_back(dim);
|
||||
};
|
||||
|
||||
return (resDims);
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck::host_reduce;
|
||||
|
||||
const std::vector<int> reduceDims{0, 1, 2};
|
||||
const std::vector<int> invariantDims{3};
|
||||
|
||||
SimpleAppArgs args;
|
||||
|
||||
if(args.processArgs(argc, argv) < 0)
|
||||
@@ -260,15 +235,12 @@ int main(int argc, char* argv[])
|
||||
|
||||
Tensor<InDataType> in(args.inLengths);
|
||||
|
||||
const std::vector<int> InvariantDims = get_invariant_dims<Rank, ReduceDims_>();
|
||||
const std::vector<int> ReduceDims = get_reduce_dims<Rank, ReduceDims_>();
|
||||
|
||||
std::vector<size_t> outLengths;
|
||||
|
||||
if(InvariantDims.empty())
|
||||
if(invariantDims.empty())
|
||||
outLengths.push_back(1);
|
||||
else
|
||||
for(auto dim : InvariantDims)
|
||||
for(auto dim : invariantDims)
|
||||
outLengths.push_back(args.inLengths[dim]);
|
||||
|
||||
Tensor<OutDataType> out_ref(outLengths);
|
||||
@@ -328,7 +300,7 @@ int main(int argc, char* argv[])
|
||||
if(args.do_verification)
|
||||
{
|
||||
ReductionHost<InDataType, AccDataType, OutDataType, ReduceOpId, PropagateNan, NeedIndices>
|
||||
hostReduce(in.mDesc, out_ref.mDesc, InvariantDims, ReduceDims);
|
||||
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
|
||||
|
||||
hostReduce.Run(
|
||||
alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data());
|
||||
@@ -350,6 +322,7 @@ int main(int argc, char* argv[])
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
|
||||
Reference in New Issue
Block a user