Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-05 14:11:29 +00:00)
Overhaul to Reduction and its dependents (#237)
* Tiny fix in dynamic_buffer.hpp to support vectorized AtomicAdd for the double type
* Update to host layer and host reduction
* Merge and remove reduction kernels
* Merge and remove reduction device interfaces and update the pooling device interface
* Merge and remove useless reduction device instances
* Update to reduction profiler and reduction ctests
* Update to reduction and pooling examples and add one reduction example
* Change the reduction examples so they are testable by ctest
* Add explicit pass checking for reduction and pooling examples
* Explicit assignment of tensor shapes in example reduce_blockwise_two_call
* Use atomic_add to replace atomicAdd and add atomic_add for the double type
* Add reduce ctest support for the double data type
* Replace to_int_vector() by using C++ std::vector::assign()
* Keep DeviceReduceThreadWise separated from DeviceReduceBlockWise
* Merge DeviceReduceBlockWise and DeviceReduceMultiBlockAtomicAdd into DeviceReduceMultiBlock
* Add GetAtomicOperationZeroValue() support for AtomicMax
* Tiny change to the reduce example README.md
* Fix some tiny issues due to branch merging
* Revoke the previous change in dynamic_buffer.hpp and add atomic_add for double2_t
* Add reduce multiblock_atomic_add instances for fp64 to verify vectorized atomic_add on fp64
* Renaming
* Clean the header includes in device_reduce instance header files
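For orientation, here is a minimal sketch of what "atomic_add for double2_t" amounts to: a two-lane double vector updated with two scalar atomicAdd calls. The type and function names below are illustrative only and are not the library's actual definitions (the real double2_t is an ext_vector_type alias inside the buffer-addressing layer), and 64-bit floating-point atomics require hardware support.

#include <hip/hip_runtime.h>

// Illustrative two-lane double vector (not the library's double2_t).
struct double2_sketch
{
    double x;
    double y;
};

// Sketch of a vectorized atomic add built from two scalar atomicAdd calls.
// Each lane is updated atomically, which is sufficient for AtomicAdd-based
// reduction output where the two lanes are accumulated independently.
__device__ inline void atomic_add_double2_sketch(double2_sketch* dst, const double2_sketch& src)
{
    atomicAdd(&dst->x, src.x); // requires hardware support for FP64 atomics
    atomicAdd(&dst->y, src.y);
}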
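The merged DeviceReduceMultiBlock operator (shown further down in the diff) splits each reduced row across several blocks when the AtomicAdd output path is used. A hedged host-side sketch of its block-group sizing loop follows; the helper name is invented for illustration, while the variable names mirror the diff.

#include <cstdint>
#include <utility>

// Grow the number of K tiles each block iterates over until no more than 128
// blocks cooperate on one reduced row. Returns {blkGroupSize, numBlockTileIteration}.
inline std::pair<int, int> choose_multiblock_split_sketch(int64_t reduce_total_length,
                                                          int k_block_tile_size)
{
    int iterations = 1;

    while(true)
    {
        const int testBlkGroupSize = static_cast<int>(
            (reduce_total_length + (k_block_tile_size * iterations) - 1) /
            (k_block_tile_size * iterations));

        // we want the blkGroupSize be not more than 128
        if(testBlkGroupSize <= 128)
            break;

        iterations++;
    }

    const int blkGroupSize = static_cast<int>(
        (reduce_total_length + (k_block_tile_size * iterations) - 1) /
        (k_block_tile_size * iterations));

    return {blkGroupSize, iterations};
}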
@@ -17,7 +17,7 @@ template <typename InDataType,
typename OutDataType,
typename AccDataType,
ck::ReduceTensorOp ReduceOpId,
bool NeedIndices,
bool OuputIndex,
ck::index_t BlockSize,
ck::index_t ReduceMThreadClusterSize,
ck::index_t ReduceKThreadClusterSize,
@@ -44,8 +44,6 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
AccElementwiseOperation;

static constexpr bool BetaIsZero = true;

static constexpr index_t InSrcOutDstVectorDim =
0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
// not reduced.
@@ -206,28 +204,28 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
false, // propagate_nan
BetaIsZero,
BlockSize,
ReduceMThreadClusterSize,
ReduceKThreadClusterSize,
ReduceMThreadSliceSize,
ReduceKThreadSliceSize,
InSrcOutDstVectorDim,
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;
using gridwise_reduce =
GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
InMemoryDataOperationEnum::Set,
false, // propagate_nan
BlockSize,
ReduceMThreadSliceSize,
ReduceKThreadSliceSize,
InSrcOutDstVectorDim,
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;

const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
NeedIndices,
OuputIndex,
false, // don't have index input
InDataType,
OutDataType,
AccDataType,
@@ -252,6 +250,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
arg.acc_element_op_,
float(1),
arg.p_in_dev_,
nullptr,
float(0),
arg.p_out_dev_,
arg.p_out_indices_dev_);
@@ -16,35 +16,18 @@ namespace device {
template <typename InElementwiseOperation, typename AccElementwiseOperation>
struct DeviceReduce : public BaseOperator
{
virtual long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
const std::vector<int> reduceDims)
{
(void)inLengths;
(void)reduceDims;

return (0);
};

virtual bool HasFurtherCall() { return (false); };

virtual std::vector<int> GetWorkspace2dLengths(const BaseArgument* argPtr)
{
(void)argPtr;
return (std::vector<int>{0, 0});
};

virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> outLengths,
const std::vector<index_t> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
const void* in_index_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
void* out_index_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) = 0;
@@ -1,374 +0,0 @@
|
||||
#ifndef DEVICE_REDUCE_BLOCKWISE_HPP
|
||||
#define DEVICE_REDUCE_BLOCKWISE_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_reduce.hpp"
|
||||
#include "device_reduce_common.hpp"
|
||||
#include "gridwise_2d_reduction_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
bool NeedIndices,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
"Invalid thread cluster size assignments!");
|
||||
|
||||
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
|
||||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
|
||||
(MThreadSliceSize % OutDstVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
static constexpr bool BetaIsZero = NeedIndices;
|
||||
|
||||
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
|
||||
|
||||
static constexpr index_t numSrcDim = Rank;
|
||||
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
|
||||
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
|
||||
|
||||
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides)
|
||||
{
|
||||
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
|
||||
|
||||
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
|
||||
const auto in_grid_desc_m_k = [&]() {
|
||||
if constexpr(reduceAllDim)
|
||||
{
|
||||
const auto one_dim_inDesc = transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(tupleSrcLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return transform_tensor_descriptor(one_dim_inDesc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
1, one_dim_inDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
|
||||
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
|
||||
|
||||
const auto reduceDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
|
||||
const auto invariantDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(reduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const auto inPad_M =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
const auto inPad_K =
|
||||
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
|
||||
|
||||
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
|
||||
in_grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
|
||||
make_right_pad_transform(reduceLength, inPad_K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (in_grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides)
|
||||
{
|
||||
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
|
||||
|
||||
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
auto out_grid_desc_m = transform_tensor_descriptor(
|
||||
outDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
|
||||
|
||||
const auto inPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
|
||||
auto out_grid_desc_m_padded = transform_tensor_descriptor(
|
||||
out_grid_desc_m,
|
||||
make_tuple(make_right_pad_transform(invariantLength, inPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return (out_grid_desc_m_padded);
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::vector<int> inLengths,
|
||||
const std::vector<int> inStrides,
|
||||
const std::vector<int> outLengths,
|
||||
const std::vector<int> outStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const InDataType* in_dev,
|
||||
OutDataType* out_dev,
|
||||
IndexDataType* out_indices_dev,
|
||||
AccDataType* workspace_dev,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const AccElementwiseOperation acc_elementwise_op)
|
||||
: outLengths_{outLengths},
|
||||
outStrides_{outStrides},
|
||||
in_dev_{in_dev},
|
||||
out_dev_{out_dev},
|
||||
out_indices_dev_{out_indices_dev},
|
||||
in_elementwise_op_{in_elementwise_op},
|
||||
acc_elementwise_op_{acc_elementwise_op}
|
||||
{
|
||||
(void)workspace_dev;
|
||||
|
||||
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
|
||||
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
|
||||
|
||||
alpha_ = type_convert<AccDataType>(alpha);
|
||||
beta_ = type_convert<AccDataType>(beta);
|
||||
|
||||
std::tie(invariant_total_length, reduce_total_length) =
|
||||
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
|
||||
|
||||
if constexpr(NumInvariantDim == 0)
|
||||
invariant_lowest_length = 1;
|
||||
else
|
||||
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
|
||||
|
||||
reduce_lowest_length = inLengths_[Rank - 1];
|
||||
|
||||
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
|
||||
M_BlockTileSize;
|
||||
}
|
||||
|
||||
std::vector<int> inLengths_;
|
||||
std::vector<int> inStrides_;
|
||||
std::vector<int> outLengths_;
|
||||
std::vector<int> outStrides_;
|
||||
|
||||
AccDataType alpha_;
|
||||
AccDataType beta_;
|
||||
|
||||
const InDataType* in_dev_;
|
||||
OutDataType* out_dev_;
|
||||
IndexDataType* out_indices_dev_;
|
||||
|
||||
InElementwiseOperation in_elementwise_op_;
|
||||
AccElementwiseOperation acc_elementwise_op_;
|
||||
|
||||
int invariant_lowest_length;
|
||||
int reduce_lowest_length;
|
||||
size_t invariant_total_length;
|
||||
size_t reduce_total_length;
|
||||
|
||||
size_t gridSize;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto in_grid_desc_m_k =
|
||||
DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
|
||||
const auto out_grid_desc_m =
|
||||
DeviceReduceBlockWise::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
|
||||
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
|
||||
using OutGridDesc_M = decltype(out_grid_desc_m);
|
||||
|
||||
using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
BetaIsZero,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize>;
|
||||
|
||||
float avg_time = 0;
|
||||
|
||||
const auto kernel = kernel_reduce_blockwise<GridwiseReduce,
|
||||
NeedIndices,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
|
||||
avg_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
arg.in_elementwise_op_,
|
||||
arg.acc_elementwise_op_,
|
||||
arg.alpha_,
|
||||
arg.in_dev_,
|
||||
arg.beta_,
|
||||
arg.out_dev_,
|
||||
nullptr,
|
||||
arg.out_indices_dev_);
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if constexpr(InSrcVectorDim == 0)
|
||||
{
|
||||
if constexpr(NumInvariantDim == 0)
|
||||
{
|
||||
return (false);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
|
||||
return (false);
|
||||
|
||||
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
|
||||
return (false);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
if(pArg->inStrides_[Rank - 1] != 1)
|
||||
return (false);
|
||||
|
||||
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
|
||||
return (false);
|
||||
};
|
||||
|
||||
// To improve
|
||||
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
|
||||
return (false);
|
||||
|
||||
// cases with very small reduce_total_length should be handled by the ThreadWise method
|
||||
if(pArg->reduce_total_length / KThreadSliceSize < 2)
|
||||
return (false);
|
||||
|
||||
return (true);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<int> inLengths,
|
||||
const std::vector<int> inStrides,
|
||||
const std::vector<int> outLengths,
|
||||
const std::vector<int> outStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
void* out_dev,
|
||||
void* out_indices_dev,
|
||||
void* workspace_dev,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const AccElementwiseOperation acc_elementwise_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(inLengths,
|
||||
inStrides,
|
||||
outLengths,
|
||||
outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
static_cast<OutDataType*>(out_dev),
|
||||
static_cast<IndexDataType*>(out_indices_dev),
|
||||
static_cast<AccDataType*>(workspace_dev),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceReduceBlockWise<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,328 +0,0 @@
|
||||
#ifndef DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP
|
||||
#define DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_reduce.hpp"
|
||||
#include "device_reduce_common.hpp"
|
||||
#include "gridwise_2d_reduction_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
bool NeedIndices,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct DeviceReduceBlockWiseSecondCall
|
||||
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
"Invalid thread cluster size assignments!");
|
||||
|
||||
static_assert((InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0) &&
|
||||
(MThreadSliceSize % OutDstVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
static constexpr bool BetaIsZero = NeedIndices;
|
||||
|
||||
static_assert(
|
||||
std::is_same<InDataType, AccDataType>::value,
|
||||
"InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!");
|
||||
|
||||
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
|
||||
|
||||
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
|
||||
|
||||
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides)
|
||||
{
|
||||
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<2>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<2>{});
|
||||
|
||||
const auto in_grid_desc_m_k =
|
||||
make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
|
||||
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const auto inPad_M =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
const auto inPad_K =
|
||||
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
|
||||
|
||||
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
|
||||
in_grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
|
||||
make_right_pad_transform(reduceLength, inPad_K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (in_grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides)
|
||||
{
|
||||
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
|
||||
|
||||
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
auto out_grid_desc_m = transform_tensor_descriptor(
|
||||
outDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
|
||||
|
||||
const auto outPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
|
||||
auto out_grid_desc_m_padded = transform_tensor_descriptor(
|
||||
out_grid_desc_m,
|
||||
make_tuple(make_right_pad_transform(invariantLength, outPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return (out_grid_desc_m_padded);
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
float alpha,
|
||||
float beta,
|
||||
const InDataType* in_dev,
|
||||
OutDataType* out_dev,
|
||||
IndexDataType* out_indices_dev,
|
||||
AccDataType* workspace_dev,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op)
|
||||
: inLengths_(inLengths),
|
||||
inStrides_(inStrides),
|
||||
outLengths_(outLengths),
|
||||
outStrides_(outStrides),
|
||||
in_dev_{in_dev},
|
||||
out_dev_{out_dev},
|
||||
out_indices_dev_{out_indices_dev},
|
||||
in_elementwise_op_(in_elementwise_op),
|
||||
acc_elementwise_op_(acc_elementwise_op)
|
||||
{
|
||||
alpha_ = type_convert<AccDataType>(alpha);
|
||||
beta_ = type_convert<AccDataType>(beta);
|
||||
|
||||
invariant_total_length = inLengths[0];
|
||||
reduce_total_length = inLengths[1];
|
||||
|
||||
invariant_lowest_length = inLengths[0];
|
||||
reduce_lowest_length = inLengths[1];
|
||||
|
||||
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
|
||||
M_BlockTileSize;
|
||||
|
||||
size_t ws_buf2_bytes_offset = math::integer_least_multiple(
|
||||
invariant_total_length * reduce_total_length * sizeof(AccDataType), 64);
|
||||
|
||||
if constexpr(NeedIndices)
|
||||
workspace_indices_dev_ = reinterpret_cast<index_t*>(
|
||||
reinterpret_cast<char*>(workspace_dev) + ws_buf2_bytes_offset);
|
||||
else
|
||||
workspace_indices_dev_ = nullptr;
|
||||
}
|
||||
|
||||
std::vector<int> inLengths_;
|
||||
std::vector<int> inStrides_;
|
||||
std::vector<int> outLengths_;
|
||||
std::vector<int> outStrides_;
|
||||
|
||||
AccDataType alpha_;
|
||||
AccDataType beta_;
|
||||
|
||||
const InDataType* in_dev_;
|
||||
OutDataType* out_dev_;
|
||||
IndexDataType* out_indices_dev_;
|
||||
IndexDataType* workspace_indices_dev_;
|
||||
|
||||
InElementwiseOperation in_elementwise_op_;
|
||||
AccElementwiseOperation acc_elementwise_op_;
|
||||
|
||||
int invariant_lowest_length;
|
||||
int reduce_lowest_length;
|
||||
size_t invariant_total_length;
|
||||
size_t reduce_total_length;
|
||||
|
||||
size_t gridSize;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor(
|
||||
arg.inLengths_, arg.inStrides_);
|
||||
const auto out_grid_desc_m = DeviceReduceBlockWiseSecondCall::MakeDst1dDescriptor(
|
||||
arg.outLengths_, arg.outStrides_);
|
||||
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
|
||||
using OutGridDesc_M = decltype(out_grid_desc_m);
|
||||
|
||||
using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
BetaIsZero,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize>;
|
||||
|
||||
float avg_time = 0;
|
||||
|
||||
const auto kernel = kernel_reduce_blockwise_second_call<GridwiseReduce,
|
||||
NeedIndices,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
|
||||
avg_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
arg.in_elementwise_op_,
|
||||
arg.acc_elementwise_op_,
|
||||
arg.alpha_,
|
||||
arg.in_dev_,
|
||||
arg.beta_,
|
||||
arg.out_dev_,
|
||||
arg.workspace_indices_dev_,
|
||||
arg.out_indices_dev_);
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if constexpr(InSrcVectorDim == 0)
|
||||
return (false);
|
||||
|
||||
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
|
||||
return (false);
|
||||
|
||||
// To improve
|
||||
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
|
||||
return (false);
|
||||
|
||||
// cases with very small reduce_total_length should be handled by the ThreadWise method
|
||||
if(pArg->reduce_total_length / KThreadSliceSize < 2)
|
||||
return (false);
|
||||
|
||||
return (true);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<int> inLengths,
|
||||
const std::vector<int> inStrides,
|
||||
const std::vector<int> outLengths,
|
||||
const std::vector<int> outStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
void* out_dev,
|
||||
void* out_indices_dev,
|
||||
void* workspace_dev,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const AccElementwiseOperation acc_elementwise_op) override
|
||||
{
|
||||
(void)reduceDims;
|
||||
|
||||
return std::make_unique<Argument>(inLengths,
|
||||
inStrides,
|
||||
outLengths,
|
||||
outStrides,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
static_cast<OutDataType*>(out_dev),
|
||||
static_cast<IndexDataType*>(out_indices_dev),
|
||||
static_cast<AccDataType*>(workspace_dev),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceReduceBlockWiseSecondCall<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -14,13 +14,13 @@ namespace device {

// here, inLengths[] is already shuffled so that lengths of invariant dims are included before those
// of reduce dims
template <int Rank, int NumReduceDim>
std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
template <index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& inLengths)
{
static_assert(Rank <= 6, "bigger Rank size not supported!");

size_t invariant_total_length = 1;
size_t reduce_total_length = 1;
long_index_t invariant_total_length = 1;
long_index_t reduce_total_length = 1;

constexpr int NumInvariantDim = Rank - NumReduceDim;

@@ -35,13 +35,13 @@ std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)

// helper functions using variadic template arguments
template <index_t... Ns>
auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
auto make_tuple_from_array_and_index_seq(const std::vector<index_t>& lengths, Sequence<Ns...>)
{
return make_tuple(static_cast<index_t>(lengths[Ns])...);
};

template <index_t arraySize>
static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arraySize>)
auto make_tuple_from_array(const std::vector<index_t>& lengths, Number<arraySize>)
{
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");

@@ -51,10 +51,10 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS
};

template <index_t Rank, index_t NumReduceDim>
std::vector<int> shuffle_tensor_dimensions(const std::vector<int>& origLengthsStrides,
const std::vector<int>& reduceDims)
std::vector<index_t> shuffle_tensor_dimensions(const std::vector<index_t>& origLengthsStrides,
const std::vector<int>& reduceDims)
{
std::vector<int> newLengthsStrides;
std::vector<index_t> newLengthsStrides;

assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size());
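Since the hunk above only shows fragments of get_2d_lengths, a self-contained host-side sketch of the same computation may help: with dimensions already shuffled so that invariant dims precede reduce dims, the two totals are plain products over the two halves of the length array. The helper name is illustrative, not the library's.

#include <cstdint>
#include <utility>
#include <vector>

template <int Rank, int NumReduceDim>
std::pair<int64_t, int64_t> get_2d_lengths_sketch(const std::vector<int>& inLengths)
{
    static_assert(Rank <= 6, "bigger Rank size not supported!");

    constexpr int NumInvariantDim = Rank - NumReduceDim;

    int64_t invariant_total_length = 1;
    int64_t reduce_total_length    = 1;

    // product of the leading (invariant) dims
    for(int i = 0; i < NumInvariantDim; ++i)
        invariant_total_length *= inLengths[i];

    // product of the trailing (reduce) dims
    for(int i = NumInvariantDim; i < Rank; ++i)
        reduce_total_length *= inLengths[i];

    return {invariant_total_length, reduce_total_length};
}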
@@ -1,5 +1,5 @@
#ifndef DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP
#define DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP
#ifndef DEVICE_REDUCE_MULTIBLOCK_HPP
#define DEVICE_REDUCE_MULTIBLOCK_HPP

#include <iostream>
#include <sstream>
@@ -7,8 +7,9 @@
#include "device_base.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_multiblock_atomic_add.hpp"
#include "gridwise_2d_reduction_multiblock.hpp"
#include "gridwise_set_buffer_value.hpp"
#include "reduction_operator.hpp"

namespace ck {
namespace tensor_operation {
@@ -22,8 +23,10 @@ template <typename InDataType,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan,
bool NeedIndices,
bool OutputIndex,
bool HaveIndexInputIfOutputIndex,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
@@ -32,8 +35,7 @@ template <typename InDataType,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceMultiBlockAtomicAdd
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
@@ -46,26 +48,40 @@ struct DeviceReduceMultiBlockAtomicAdd

using IndexDataType = int32_t;

static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;

static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);

static constexpr bool support_AtomicAdd =
// So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
// later
static constexpr bool use_multiblock =
(OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);

static constexpr bool out_type_compatible_with_atomic_op =
std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;

static_assert(!NeedIndices && support_AtomicAdd,
"MultiBlockAtomicAdd method can only be used with non-indiced operation and when "
"having float/double output type!");
static_assert(
!use_multiblock || (use_multiblock && out_type_compatible_with_atomic_op),
"The OutDataType must support the atomic operation for using MultiBlock reduction");

static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
"MultiBlock reduction can only be used when outputing index is not required");

static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
static_assert(
ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation),
"The reduction accumulation operation must be compatible with the OutMemoryDataOperation!");

static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides,
int blkGroupSize,
int kBlockTileIterations)
int numBlockTileIteration)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
@@ -109,7 +125,7 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
|
||||
const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
|
||||
const auto inPad_M =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
|
||||
@@ -124,8 +140,8 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
return (in_grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides)
|
||||
static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
|
||||
const std::vector<index_t>& outStrides)
|
||||
{
|
||||
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
|
||||
@@ -151,31 +167,56 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
return (out_grid_desc_m_padded);
|
||||
};
|
||||
|
||||
static auto MakeDst1dDescriptorForBufferSet(const std::vector<index_t>& outLengths,
|
||||
const std::vector<index_t>& outStrides)
|
||||
{
|
||||
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
|
||||
|
||||
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
auto out_grid_desc_m = transform_tensor_descriptor(
|
||||
outDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto length = out_grid_desc_m.GetLength(Number<0>{});
|
||||
|
||||
const auto pad = math::integer_least_multiple(length, BlockSize) - length;
|
||||
|
||||
auto out_grid_desc_m_padded =
|
||||
transform_tensor_descriptor(out_grid_desc_m,
|
||||
make_tuple(make_right_pad_transform(length, pad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return (out_grid_desc_m_padded);
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::vector<int> inLengths,
|
||||
const std::vector<int> inStrides,
|
||||
const std::vector<int> outLengths,
|
||||
const std::vector<int> outStrides,
|
||||
Argument(const std::vector<index_t> inLengths,
|
||||
const std::vector<index_t> inStrides,
|
||||
const std::vector<index_t> outLengths,
|
||||
const std::vector<index_t> outStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const InDataType* in_dev,
|
||||
const IndexDataType* in_index_dev,
|
||||
OutDataType* out_dev,
|
||||
IndexDataType* out_indices_dev,
|
||||
AccDataType* workspace_dev,
|
||||
IndexDataType* out_index_dev,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const AccElementwiseOperation acc_elementwise_op)
|
||||
: outLengths_{outLengths},
|
||||
outStrides_{outStrides},
|
||||
in_dev_{in_dev},
|
||||
in_index_dev_{in_index_dev},
|
||||
out_dev_{out_dev},
|
||||
out_index_dev_{out_index_dev},
|
||||
in_elementwise_op_{in_elementwise_op},
|
||||
acc_elementwise_op_{acc_elementwise_op}
|
||||
{
|
||||
(void)out_indices_dev;
|
||||
(void)workspace_dev;
|
||||
|
||||
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
|
||||
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
|
||||
|
||||
@@ -192,24 +233,35 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
|
||||
reduce_lowest_length = inLengths_[Rank - 1];
|
||||
|
||||
int iterations = 1;
|
||||
while(true)
|
||||
if constexpr(use_multiblock)
|
||||
{
|
||||
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
// we want the blkGroupSize be not more than 128
|
||||
if(testBlkGroupSize <= 128)
|
||||
break;
|
||||
int iterations = 1;
|
||||
while(true)
|
||||
{
|
||||
int testBlkGroupSize =
|
||||
(reduce_total_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
iterations++;
|
||||
// we want the blkGroupSize be not more than 128
|
||||
if(testBlkGroupSize <= 128)
|
||||
break;
|
||||
|
||||
iterations++;
|
||||
};
|
||||
|
||||
blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
numBlockTileIteration = iterations;
|
||||
}
|
||||
else
|
||||
{
|
||||
blkGroupSize = 1;
|
||||
numBlockTileIteration =
|
||||
(reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
|
||||
};
|
||||
|
||||
blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
kBlockTileIterations = iterations;
|
||||
|
||||
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
|
||||
M_BlockTileSize * blkGroupSize;
|
||||
|
||||
@@ -217,27 +269,29 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
|
||||
}
|
||||
|
||||
std::vector<int> inLengths_;
|
||||
std::vector<int> inStrides_;
|
||||
std::vector<int> outLengths_;
|
||||
std::vector<int> outStrides_;
|
||||
std::vector<index_t> inLengths_;
|
||||
std::vector<index_t> inStrides_;
|
||||
std::vector<index_t> outLengths_;
|
||||
std::vector<index_t> outStrides_;
|
||||
|
||||
AccDataType alpha_;
|
||||
AccDataType beta_;
|
||||
|
||||
const InDataType* in_dev_;
|
||||
const IndexDataType* in_index_dev_;
|
||||
OutDataType* out_dev_;
|
||||
IndexDataType* out_index_dev_;
|
||||
|
||||
InElementwiseOperation in_elementwise_op_;
|
||||
AccElementwiseOperation acc_elementwise_op_;
|
||||
|
||||
int invariant_lowest_length;
|
||||
int reduce_lowest_length;
|
||||
size_t invariant_total_length;
|
||||
size_t reduce_total_length;
|
||||
index_t invariant_lowest_length;
|
||||
index_t reduce_lowest_length;
|
||||
long_index_t invariant_total_length;
|
||||
long_index_t reduce_total_length;
|
||||
|
||||
index_t blkGroupSize;
|
||||
index_t kBlockTileIterations;
|
||||
int blkGroupSize;
|
||||
int numBlockTileIteration;
|
||||
size_t gridSize;
|
||||
|
||||
size_t gridSize_pre;
|
||||
@@ -247,52 +301,69 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor(
|
||||
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
|
||||
const auto out_grid_desc_m = DeviceReduceMultiBlockAtomicAdd::MakeDst1dDescriptor(
|
||||
const auto in_grid_desc_m_k = DeviceReduceMultiBlock::MakeSrc2dDescriptor(
|
||||
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
|
||||
const auto out_grid_desc_m =
|
||||
DeviceReduceMultiBlock::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
|
||||
const auto out_grid_desc_m_2 = DeviceReduceMultiBlock::MakeDst1dDescriptorForBufferSet(
|
||||
arg.outLengths_, arg.outStrides_);
|
||||
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
|
||||
using OutGridDesc_M = decltype(out_grid_desc_m);
|
||||
|
||||
using GridwiseReduce =
|
||||
GridwiseReduction_mk_to_m_multiblock_atomic_add<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize>;
|
||||
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
|
||||
using OutGridDesc_M = decltype(out_grid_desc_m);
|
||||
using OutGridDesc_M_2 = decltype(out_grid_desc_m_2);
|
||||
|
||||
using GridwiseReduce = GridwiseReduction_mk_to_m_multiblock<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
OutMemoryDataOperation,
|
||||
PropagateNan,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize>;
|
||||
|
||||
const auto kernel_main = kernel_reduce_multiblock<GridwiseReduce,
|
||||
OutputIndex,
|
||||
HaveIndexInput,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
int32_t,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
|
||||
float avg_time = 0;
|
||||
|
||||
const auto kernel_pre = kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M>;
|
||||
const auto kernel_main = kernel_reduce_multiblock_atocmi_add<GridwiseReduce,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
if constexpr(use_multiblock)
|
||||
{
|
||||
const auto zeroVal =
|
||||
ck::reduce::GetReductionZeroValueForInMemoryDataOperation<OutDataType>(
|
||||
OutMemoryDataOperation);
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_pre,
|
||||
dim3(arg.gridSize_pre),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
out_grid_desc_m,
|
||||
arg.out_dev_,
|
||||
static_cast<OutDataType>(0.0f));
|
||||
const auto kernel_pre =
|
||||
kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M_2>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_pre,
|
||||
dim3(arg.gridSize_pre),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
out_grid_desc_m_2,
|
||||
arg.out_dev_,
|
||||
zeroVal);
|
||||
};
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_main,
|
||||
@@ -304,25 +375,34 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
arg.in_elementwise_op_,
|
||||
arg.acc_elementwise_op_,
|
||||
arg.blkGroupSize,
|
||||
arg.kBlockTileIterations,
|
||||
arg.numBlockTileIteration,
|
||||
arg.alpha_,
|
||||
arg.in_dev_,
|
||||
arg.out_dev_);
|
||||
arg.in_index_dev_,
|
||||
arg.beta_,
|
||||
arg.out_dev_,
|
||||
arg.out_index_dev_);
|
||||
|
||||
return avg_time;
|
||||
}
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if constexpr(use_multiblock)
|
||||
{
|
||||
if(static_cast<float>(pArg->beta_) != 0.0f)
|
||||
return (false);
|
||||
};
|
||||
|
||||
if constexpr(InSrcVectorDim == 0)
|
||||
{
|
||||
if constexpr(NumInvariantDim == 0)
|
||||
@@ -347,36 +427,43 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
return (false);
|
||||
};
|
||||
|
||||
if(static_cast<float>(pArg->beta_) != 0.0f)
|
||||
return (false);
|
||||
|
||||
// To improve
|
||||
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
|
||||
return (false);
|
||||
|
||||
// cases with small reduce_total_length should be handled by the BlockWise method
|
||||
if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize)
|
||||
return (false);
|
||||
if constexpr(use_multiblock)
|
||||
{
|
||||
// blkGroupSize of 1 should be handled by Blockwise path using
|
||||
// InMemoryDataOperationEnum::Set
|
||||
if(pArg->blkGroupSize == 1)
|
||||
return (false);
|
||||
|
||||
// This is very strong restriction, but needed to avoid some failure
|
||||
if(pArg->invariant_lowest_length % M_BlockTileSize != 0)
|
||||
return (false);
|
||||
// This is very strong restriction, but needed to avoid some failure
|
||||
if(pArg->invariant_lowest_length % M_BlockTileSize != 0)
|
||||
return (false);
|
||||
}
|
||||
else
|
||||
{
|
||||
// cases with very small reduce_total_length should be handled by ThreadWise kernel
|
||||
if(pArg->reduce_total_length / KThreadSliceSize < 2)
|
||||
return (false);
|
||||
};
|
||||
|
||||
return (true);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<int> inLengths,
|
||||
const std::vector<int> inStrides,
|
||||
const std::vector<int> outLengths,
|
||||
const std::vector<int> outStrides,
|
||||
MakeArgumentPointer(const std::vector<index_t> inLengths,
|
||||
const std::vector<index_t> inStrides,
|
||||
const std::vector<index_t> outLengths,
|
||||
const std::vector<index_t> outStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
const void* in_index_dev,
|
||||
void* out_dev,
|
||||
void* out_indices_dev,
|
||||
void* workspace_dev,
|
||||
void* out_index_dev,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const AccElementwiseOperation acc_elementwise_op) override
|
||||
{
|
||||
@@ -388,9 +475,9 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
static_cast<const IndexDataType*>(in_index_dev),
|
||||
static_cast<OutDataType*>(out_dev),
|
||||
static_cast<IndexDataType*>(out_indices_dev),
|
||||
static_cast<AccDataType*>(workspace_dev),
|
||||
static_cast<IndexDataType*>(out_index_dev),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
@@ -1,440 +0,0 @@
|
||||
#ifndef DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
|
||||
#define DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_reduce.hpp"
|
||||
#include "device_reduce_common.hpp"
|
||||
#include "gridwise_2d_reduction_multiblock_partial_reduce.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
bool NeedIndices,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct DeviceReduceMultiBlockPartialReduce
|
||||
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
"Invalid thread cluster size assignments!");
|
||||
|
||||
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
|
||||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
|
||||
|
||||
static constexpr index_t numSrcDim = Rank;
|
||||
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
|
||||
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
|
||||
|
||||
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static constexpr int MaxBlockGroupSize = 256;
|
||||
|
||||
long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
|
||||
const std::vector<int> reduceDims) override
|
||||
{
|
||||
size_t invariant_total_length;
|
||||
size_t reduce_total_length;
|
||||
|
||||
auto inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
|
||||
|
||||
std::tie(invariant_total_length, reduce_total_length) =
|
||||
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
|
||||
|
||||
int iterations = 1;
|
||||
while(true)
|
||||
{
|
||||
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
if(testBlkGroupSize <= MaxBlockGroupSize)
|
||||
break;
|
||||
|
||||
iterations++;
|
||||
};
|
||||
|
||||
int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
long_index_t workspace_size = invariant_total_length * blkGroupSize;
|
||||
|
||||
long_index_t wsSizeInBytes =
|
||||
!NeedIndices
|
||||
? workspace_size * sizeof(AccDataType)
|
||||
: workspace_size * (sizeof(AccDataType) + sizeof(int32_t)) + 64 + sizeof(int);
|
||||
|
||||
return (wsSizeInBytes);
|
||||
};
|
||||
|
||||
bool HasFurtherCall() override { return (true); };
|
||||
|
||||
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides,
|
||||
int blkGroupSize,
|
||||
int kBlockTileIterations)
|
||||
{
|
||||
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
|
||||
|
||||
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
|
||||
const auto in_grid_desc_m_k = [&]() {
|
||||
if constexpr(reduceAllDim)
|
||||
{
|
||||
const auto one_dim_inDesc = transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(tupleSrcLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return transform_tensor_descriptor(one_dim_inDesc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
1, one_dim_inDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
|
||||
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
|
||||
|
||||
const auto reduceDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
|
||||
const auto invariantDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(reduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
|
||||
const auto inPad_M =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
|
||||
|
||||
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
|
||||
in_grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
|
||||
make_right_pad_transform(reduceLength, inPad_K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (in_grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeWorkspace2dDescriptor(int invariantLength, int blkGroupSize)
|
||||
{
|
||||
auto ws_desc_m_k =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
|
||||
|
||||
const auto wsPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
|
||||
auto ws_desc_m_k_padded =
|
||||
transform_tensor_descriptor(ws_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(invariantLength, wsPad),
|
||||
make_pass_through_transform(blkGroupSize)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (ws_desc_m_k_padded);
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::vector<int> inLengths,
|
||||
const std::vector<int> inStrides,
|
||||
const std::vector<int> outLengths,
|
||||
const std::vector<int> outStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const InDataType* in_dev,
|
||||
OutDataType* out_dev,
|
||||
IndexDataType* out_indices_dev,
|
||||
AccDataType* workspace_dev,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const AccElementwiseOperation acc_elementwise_op)
|
||||
: outLengths_{outLengths},
|
||||
outStrides_{outStrides},
|
||||
in_dev_{in_dev},
|
||||
out_dev_{out_dev},
|
||||
out_indices_dev_{out_indices_dev},
|
||||
workspace_dev_{workspace_dev},
|
||||
in_elementwise_op_{in_elementwise_op},
|
||||
acc_elementwise_op_{acc_elementwise_op}
|
||||
{
|
||||
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
|
||||
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
|
||||
|
||||
alpha_ = type_convert<AccDataType>(alpha);
|
||||
beta_ = type_convert<AccDataType>(beta);
|
||||
|
||||
std::tie(invariant_total_length, reduce_total_length) =
|
||||
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
|
||||
|
||||
if constexpr(NumInvariantDim == 0)
|
||||
invariant_lowest_length = 1;
|
||||
else
|
||||
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
|
||||
|
||||
reduce_lowest_length = inLengths_[Rank - 1];
|
||||
|
||||
int iterations = 1;
|
||||
while(true)
|
||||
{
|
||||
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
if(testBlkGroupSize <= MaxBlockGroupSize)
|
||||
break;
|
||||
|
||||
iterations++;
|
||||
};
|
||||
|
||||
blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
kBlockTileIterations = iterations;
|
||||
|
||||
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
|
||||
M_BlockTileSize * blkGroupSize;
|
||||
|
||||
size_t ws_buf2_bytes_offset = math::integer_least_multiple(
|
||||
invariant_total_length * blkGroupSize * sizeof(AccDataType), 64);
|
||||
|
||||
if constexpr(NeedIndices)
|
||||
workspace_indices_dev_ = reinterpret_cast<int*>(
|
||||
reinterpret_cast<char*>(workspace_dev_) + ws_buf2_bytes_offset);
|
||||
else
|
||||
workspace_indices_dev_ = nullptr;
|
||||
}
|
||||
|
||||
std::vector<int> inLengths_;
|
||||
std::vector<int> inStrides_;
|
||||
std::vector<int> outLengths_;
|
||||
std::vector<int> outStrides_;
|
||||
|
||||
AccDataType alpha_;
|
||||
AccDataType beta_;
|
||||
|
||||
const InDataType* in_dev_;
|
||||
OutDataType* out_dev_;
|
||||
IndexDataType* out_indices_dev_;
|
||||
AccDataType* workspace_dev_;
|
||||
IndexDataType* workspace_indices_dev_;
|
||||
|
||||
InElementwiseOperation in_elementwise_op_;
|
||||
AccElementwiseOperation acc_elementwise_op_;
|
||||
|
||||
int invariant_lowest_length;
|
||||
int reduce_lowest_length;
|
||||
size_t invariant_total_length;
|
||||
size_t reduce_total_length;
|
||||
|
||||
index_t blkGroupSize;
|
||||
index_t kBlockTileIterations;
|
||||
size_t gridSize;
|
||||
};
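
    // Editor's sketch of the grid sizing performed in the Argument constructor above
    // (MaxBlockGroupSize is the existing limit it checks against; the helper name is
    // an assumption for illustration).
    static size_t SketchGridSize(size_t invariant_total_length,
                                 size_t reduce_total_length,
                                 size_t m_block_tile_size,
                                 size_t k_block_tile_size,
                                 size_t max_block_group_size)
    {
        // Grow the per-block iteration count until the K dimension fits into at most
        // max_block_group_size block groups (ceiling divisions throughout).
        size_t iterations = 1;
        auto blk_groups = [&] {
            const size_t tile = k_block_tile_size * iterations;
            return (reduce_total_length + tile - 1) / tile;
        };
        while(blk_groups() > max_block_group_size)
            ++iterations;

        const size_t m_tiles =
            (invariant_total_length + m_block_tile_size - 1) / m_block_tile_size;

        return m_tiles * blk_groups(); // one workgroup per (M tile, K block group) pair
    }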
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor(
|
||||
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
|
||||
const auto ws_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeWorkspace2dDescriptor(
|
||||
arg.invariant_total_length, arg.blkGroupSize);
|
||||
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
|
||||
using WorkspaceDesc_M_K = decltype(ws_desc_m_k);
|
||||
|
||||
using GridwiseReduce =
|
||||
GridwiseReduction_mk_to_mk_multiblock_partial_reduce<InDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
WorkspaceDesc_M_K,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize>;
|
||||
|
||||
float avg_time = 0;
|
||||
|
||||
const auto kernel = kernel_partial_reduce_multiblock<GridwiseReduce,
|
||||
NeedIndices,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
WorkspaceDesc_M_K,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
|
||||
avg_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
in_grid_desc_m_k,
|
||||
ws_desc_m_k,
|
||||
arg.in_elementwise_op_,
|
||||
arg.acc_elementwise_op_,
|
||||
arg.blkGroupSize,
|
||||
arg.kBlockTileIterations,
|
||||
arg.in_dev_,
|
||||
arg.workspace_dev_,
|
||||
arg.workspace_indices_dev_);
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if constexpr(OutDstVectorSize != 1)
|
||||
return (false);
|
||||
|
||||
if constexpr(InSrcVectorDim == 0)
|
||||
{
|
||||
if constexpr(NumInvariantDim == 0)
|
||||
{
|
||||
return (false);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
|
||||
return (false);
|
||||
|
||||
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
|
||||
return (false);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
if(pArg->inStrides_[Rank - 1] != 1)
|
||||
return (false);
|
||||
|
||||
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
|
||||
return (false);
|
||||
};
|
||||
|
||||
// cases with small reduce_total_length should be handled by the BlockWise method
|
||||
if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize)
|
||||
return (false);
|
||||
|
||||
return (true);
|
||||
};
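
    // Editor's note: the vectorization requirement checked above boils down to a simple
    // rule, sketched here with a hypothetical helper (not a library API).
    static bool SketchVectorLoadOk(int lowest_stride, int lowest_length, int vector_size)
    {
        // the vectorized dimension must be contiguous and its length divisible by the vector width
        return lowest_stride == 1 && (lowest_length % vector_size == 0);
    }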
|
||||
|
||||
std::vector<int> GetWorkspace2dLengths(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
return (
|
||||
std::vector<int>{static_cast<int>(pArg->invariant_total_length), pArg->blkGroupSize});
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<int> inLengths,
|
||||
const std::vector<int> inStrides,
|
||||
const std::vector<int> outLengths,
|
||||
const std::vector<int> outStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
void* out_dev,
|
||||
void* out_indices_dev,
|
||||
void* workspace_dev,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const AccElementwiseOperation acc_elementwise_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(inLengths,
|
||||
inStrides,
|
||||
outLengths,
|
||||
outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
static_cast<OutDataType*>(out_dev),
|
||||
static_cast<IndexDataType*>(out_indices_dev),
|
||||
static_cast<AccDataType*>(workspace_dev),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceReduceMultiBlockPartialReduce<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "device.hpp"
|
||||
#include "device_reduce.hpp"
|
||||
#include "device_reduce_common.hpp"
|
||||
#include "gridwise_2d_reduction_multiblock.hpp"
|
||||
#include "gridwise_2d_reduction_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -19,22 +20,19 @@ template <typename InDataType,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
bool NeedIndices,
|
||||
bool OutputIndex,
|
||||
bool HaveIndexInputIfOutputIndex,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutElementwiseOperation>
|
||||
struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
    static_assert((BlockSize == MThreadClusterSize) && (KThreadClusterSize == 1),
                  "Threadwise can only be called with KThreadClusterSize equal to 1!");
|
||||
|
||||
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
|
||||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
|
||||
@@ -43,7 +41,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
static constexpr bool BetaIsZero = NeedIndices;
|
||||
static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;
|
||||
|
||||
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
|
||||
|
||||
@@ -51,11 +49,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
|
||||
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
|
||||
|
||||
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
|
||||
|
||||
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides)
|
||||
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
|
||||
const std::vector<index_t>& inStrides)
|
||||
{
|
||||
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
|
||||
@@ -114,8 +112,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
return (in_grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides)
|
||||
static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
|
||||
const std::vector<index_t>& outStrides)
|
||||
{
|
||||
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
|
||||
@@ -143,30 +141,26 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::vector<int> inLengths,
|
||||
const std::vector<int> inStrides,
|
||||
const std::vector<int> outLengths,
|
||||
const std::vector<int> outStrides,
|
||||
Argument(const std::vector<index_t> inLengths,
|
||||
const std::vector<index_t> inStrides,
|
||||
const std::vector<index_t> outLengths,
|
||||
const std::vector<index_t> outStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const InDataType* in_dev,
|
||||
OutDataType* out_dev,
|
||||
IndexDataType* out_indices_dev,
|
||||
AccDataType* workspace_dev,
|
||||
IndexDataType* out_index_dev,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const OutElementwiseOperation acc_elementwise_op)
|
||||
const AccElementwiseOperation acc_elementwise_op)
|
||||
: outLengths_{outLengths},
|
||||
outStrides_{outStrides},
|
||||
in_dev_{in_dev},
|
||||
out_dev_{out_dev},
|
||||
out_indices_dev_{out_indices_dev},
|
||||
out_index_dev_{out_index_dev},
|
||||
in_elementwise_op_{in_elementwise_op},
|
||||
acc_elementwise_op_{acc_elementwise_op}
|
||||
|
||||
{
|
||||
(void)workspace_dev;
|
||||
|
||||
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
|
||||
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
|
||||
|
||||
@@ -183,30 +177,33 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
|
||||
reduce_lowest_length = inLengths_[Rank - 1];
|
||||
|
||||
numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
|
||||
|
||||
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
|
||||
M_BlockTileSize;
|
||||
}
|
||||
|
||||
std::vector<int> inLengths_;
|
||||
std::vector<int> inStrides_;
|
||||
std::vector<int> outLengths_;
|
||||
std::vector<int> outStrides_;
|
||||
std::vector<index_t> inLengths_;
|
||||
std::vector<index_t> inStrides_;
|
||||
std::vector<index_t> outLengths_;
|
||||
std::vector<index_t> outStrides_;
|
||||
|
||||
AccDataType alpha_;
|
||||
AccDataType beta_;
|
||||
|
||||
const InDataType* in_dev_;
|
||||
OutDataType* out_dev_;
|
||||
IndexDataType* out_indices_dev_;
|
||||
IndexDataType* out_index_dev_;
|
||||
|
||||
InElementwiseOperation in_elementwise_op_;
|
||||
OutElementwiseOperation acc_elementwise_op_;
|
||||
AccElementwiseOperation acc_elementwise_op_;
|
||||
|
||||
int invariant_lowest_length;
|
||||
int reduce_lowest_length;
|
||||
size_t invariant_total_length;
|
||||
size_t reduce_total_length;
|
||||
index_t invariant_lowest_length;
|
||||
index_t reduce_lowest_length;
|
||||
long_index_t invariant_total_length;
|
||||
long_index_t reduce_total_length;
|
||||
|
||||
int numBlockTileIteration;
|
||||
size_t gridSize;
|
||||
};
|
||||
|
||||
@@ -221,30 +218,30 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
|
||||
using OutGridDesc_M = decltype(out_grid_desc_m);
|
||||
|
||||
using GridwiseReduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
PropagateNan,
|
||||
BetaIsZero,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize>;
|
||||
|
||||
float avg_time = 0;
|
||||
|
||||
using GridwiseReduce =
|
||||
GridwiseReduction_mk_to_m_threadwise<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
PropagateNan,
|
||||
BlockSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize>;
|
||||
|
||||
const auto kernel = kernel_reduce_threadwise<GridwiseReduce,
|
||||
NeedIndices,
|
||||
OutputIndex,
|
||||
HaveIndexInput,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
@@ -252,7 +249,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
InElementwiseOperation,
|
||||
OutElementwiseOperation>;
|
||||
AccElementwiseOperation>;
|
||||
|
||||
avg_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
@@ -265,9 +262,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
arg.acc_elementwise_op_,
|
||||
arg.alpha_,
|
||||
arg.in_dev_,
|
||||
nullptr,
|
||||
arg.beta_,
|
||||
arg.out_dev_,
|
||||
arg.out_indices_dev_);
|
||||
arg.out_index_dev_);
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
@@ -276,7 +274,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
@@ -311,9 +309,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
|
||||
return (false);
|
||||
|
||||
// TODO: remove this. Should return true, as long as this DeviceOP instance support this
|
||||
// case for bigger reduce_total_length size, we are supposed to use BlockWise method for
|
||||
// better performance
|
||||
// cases with big reduce_total_length should be handled by Blockwise kernel
|
||||
if(pArg->reduce_total_length / KThreadSliceSize >= 32)
|
||||
return (false);
|
||||
|
||||
@@ -321,20 +317,22 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
};
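
    // Editor's sketch of the acceptance rule above (the constant 32 is the one used
    // there): the threadwise instance only takes short reductions, anything longer is
    // expected to be handled by the blockwise/multiblock kernels.
    static bool SketchThreadwiseAccepts(size_t reduce_total_length, size_t k_thread_slice_size)
    {
        return (reduce_total_length / k_thread_slice_size) < 32;
    }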
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<int> inLengths,
|
||||
const std::vector<int> inStrides,
|
||||
const std::vector<int> outLengths,
|
||||
const std::vector<int> outStrides,
|
||||
MakeArgumentPointer(const std::vector<index_t> inLengths,
|
||||
const std::vector<index_t> inStrides,
|
||||
const std::vector<index_t> outLengths,
|
||||
const std::vector<index_t> outStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
const void* in_index_dev,
|
||||
void* out_dev,
|
||||
void* out_indices_dev,
|
||||
void* workspace_dev,
|
||||
void* out_index_dev,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const OutElementwiseOperation acc_elementwise_op) override
|
||||
const AccElementwiseOperation acc_elementwise_op) override
|
||||
{
|
||||
(void)in_index_dev;
|
||||
|
||||
return std::make_unique<Argument>(inLengths,
|
||||
inStrides,
|
||||
outLengths,
|
||||
@@ -344,8 +342,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
static_cast<OutDataType*>(out_dev),
|
||||
static_cast<IndexDataType*>(out_indices_dev),
|
||||
static_cast<AccDataType*>(workspace_dev),
|
||||
static_cast<IndexDataType*>(out_index_dev),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
@@ -360,9 +357,9 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceReducceThreadWise<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "DeviceReduceThreadWise<" << BlockSize << ",";
|
||||
str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
|
||||
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -1,886 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP
|
||||
#define CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
#include "reduction_functions_blockwise.hpp"
|
||||
#include "reduction_functions_threadwise.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "cluster_descriptor.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseReduction,
|
||||
bool NeedIndices,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
typename InGridDesc_M_K,
|
||||
typename OutGridDesc_M,
|
||||
typename InElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
__global__ void kernel_reduce_blockwise(const InGridDesc_M_K in_grid_desc_m_k,
|
||||
const OutGridDesc_M out_grid_desc_m,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const OutElementwiseOperation acc_elementwise_op,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_global,
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_global,
|
||||
const IndexDataType* const __restrict__ p_ws_indices_global,
|
||||
IndexDataType* const __restrict__ p_indices_global)
|
||||
{
|
||||
if constexpr(!NeedIndices)
|
||||
{
|
||||
constexpr bool IsSecondCall = false;
|
||||
|
||||
GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
alpha,
|
||||
p_in_global,
|
||||
beta,
|
||||
p_out_global,
|
||||
p_ws_indices_global,
|
||||
p_indices_global);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseReduction::RunWithIndex(in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
alpha,
|
||||
p_in_global,
|
||||
beta,
|
||||
p_out_global,
|
||||
p_ws_indices_global,
|
||||
p_indices_global);
|
||||
};
|
||||
};
|
||||
|
||||
template <typename GridwiseReduction,
|
||||
bool NeedIndices,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
typename InGridDesc_M_K,
|
||||
typename OutGridDesc_M,
|
||||
typename InElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
__global__ void
|
||||
kernel_reduce_blockwise_second_call(const InGridDesc_M_K in_grid_desc_m_k,
|
||||
const OutGridDesc_M out_grid_desc_m,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const OutElementwiseOperation acc_elementwise_op,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_global,
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_global,
|
||||
const IndexDataType* const __restrict__ p_ws_indices_global,
|
||||
IndexDataType* const __restrict__ p_indices_global)
|
||||
{
|
||||
if constexpr(!NeedIndices)
|
||||
{
|
||||
constexpr bool IsSecondCall = true;
|
||||
|
||||
GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
alpha,
|
||||
p_in_global,
|
||||
beta,
|
||||
p_out_global,
|
||||
p_ws_indices_global,
|
||||
p_indices_global);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseReduction::RunSecondCallWithIndex(in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
alpha,
|
||||
p_in_global,
|
||||
beta,
|
||||
p_out_global,
|
||||
p_ws_indices_global,
|
||||
p_indices_global);
|
||||
};
|
||||
};
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
typename InGridDesc_M_K,
|
||||
typename OutGridDesc_M,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
bool BetaIsZero,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct GridwiseReduction_mk_to_m_blockwise
|
||||
{
|
||||
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
|
||||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
|
||||
(MThreadSliceSize % OutDstVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
|
||||
|
||||
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
|
||||
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
using ThreadClusterArrangeOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
template <bool IsSecondCall>
|
||||
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
|
||||
const OutGridDesc_M& out_grid_desc_m,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const OutElementwiseOperation& acc_elementwise_op,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_global,
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_global,
|
||||
const IndexDataType* const __restrict__ p_ws_indices_global,
|
||||
IndexDataType* const __restrict__ p_indices_global)
|
||||
{
|
||||
if constexpr(IsSecondCall)
|
||||
{
|
||||
static_assert(InSrcVectorDim == 1,
|
||||
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
|
||||
};
|
||||
|
||||
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
|
||||
ThreadReduceSrcDesc_M_K,
|
||||
ThreadReduceDstDesc_M,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
(void)p_ws_indices_global;
|
||||
(void)p_indices_global;
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
|
||||
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
auto reduce_work_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
in_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
|
||||
|
||||
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_1d_id = get_block_1d_id();
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
|
||||
|
||||
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
|
||||
|
||||
index_t reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op(in_thread_buf(Number<offset>{}),
|
||||
in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < toReduceTiles);
|
||||
|
||||
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}(
|
||||
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
|
||||
|
||||
accu_value_buf(I) *= alpha;
|
||||
}
|
||||
});
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
if constexpr(!BetaIsZero)
|
||||
{
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
|
||||
priorDstValueBuf;
|
||||
|
||||
auto threadwise_dst_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<OutDataType,
|
||||
OutDataType,
|
||||
OutGridDesc_M,
|
||||
decltype(reduced_data_desc),
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
1,
|
||||
false>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize));
|
||||
|
||||
threadwise_dst_load.Run(out_grid_desc_m,
|
||||
out_global_buf,
|
||||
reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
priorDstValueBuf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
auto threadwise_dst_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
OutDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThroughOp,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
threadwise_dst_store.Run(
|
||||
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
|
||||
}
|
||||
};
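
    // Editor's sketch: what one output row of Run() above computes, written as a host
    // reference for a max reduction with pass-through elementwise ops (assumptions:
    // AccDataType = float, in/acc elementwise ops do nothing).
    //   out[m] = alpha * acc_op(reduce_k(in_op(in[m][k]))) + beta * out_prev[m]
    static float SketchReduceRow(const float* row, int k, float alpha, float beta, float out_prev)
    {
        float acc = -3.402823466e+38f; // zeroVal of a float max reduction (lowest float)
        for(int i = 0; i < k; ++i)
            acc = acc > row[i] ? acc : row[i];
        return alpha * acc + beta * out_prev; // the beta term is skipped when BetaIsZero
    }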
|
||||
|
||||
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
|
||||
const OutGridDesc_M& out_grid_desc_m,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const OutElementwiseOperation& acc_elementwise_op,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_global,
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_global,
|
||||
const IndexDataType* const __restrict__ p_ws_indices_global,
|
||||
IndexDataType* const __restrict__ p_indices_global)
|
||||
{
|
||||
using BlockwiseReduceWithIndex =
|
||||
PartitionedBlockwiseReductionWithIndex<AccDataType,
|
||||
IndexDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
|
||||
ReduceOperation,
|
||||
AccDataType,
|
||||
IndexDataType>;
|
||||
|
||||
(void)p_ws_indices_global;
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
|
||||
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
|
||||
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
auto reduce_work_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
|
||||
auto reduce_work_idx_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
in_thread_val_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
IndexDataType,
|
||||
MThreadSliceSize * KThreadSliceSize,
|
||||
true>
|
||||
in_thread_idx_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
|
||||
|
||||
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_1d_id = get_block_1d_id();
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
index_t indexOffset = 0;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) = zeroVal;
|
||||
accu_index_buf(I) = 0;
|
||||
});
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
|
||||
|
||||
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
|
||||
|
||||
index_t reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
// load the thread slice
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_val_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
// initialize the indices for the per-thread to-reduce values
|
||||
in_thread_idx_buf(Number<offset>{}) =
|
||||
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
|
||||
|
||||
// do element-wise pre-reduction operation
|
||||
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
|
||||
in_thread_val_buf(Number<offset>{}));
|
||||
});
|
||||
|
||||
AccDataType tmpValue = zeroVal;
|
||||
IndexDataType tmpIndex = 0;
|
||||
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
AccumulationWithIndex::Calculate(tmpValue,
|
||||
in_thread_val_buf[Number<offset>{}],
|
||||
tmpIndex,
|
||||
in_thread_idx_buf[Number<offset>{}]);
|
||||
});
|
||||
|
||||
BlockwiseReduceWithIndex::Reduce(
|
||||
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
|
||||
|
||||
AccumulationWithIndex::Calculate(
|
||||
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
|
||||
});
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
indexOffset += K_BlockTileSize;
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < toReduceTiles);
|
||||
|
||||
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
                // for an index-producing reduction, acc_elementwise_op should do nothing
|
||||
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
|
||||
|
||||
accu_value_buf(I) *= alpha;
|
||||
}
|
||||
});
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
if constexpr(!BetaIsZero)
|
||||
{
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
|
||||
priorDstValueBuf;
|
||||
|
||||
auto threadwise_dst_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<OutDataType,
|
||||
OutDataType,
|
||||
OutGridDesc_M,
|
||||
decltype(reduced_data_desc),
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
1,
|
||||
false>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize));
|
||||
|
||||
threadwise_dst_load.Run(out_grid_desc_m,
|
||||
out_global_val_buf,
|
||||
reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
priorDstValueBuf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
auto threadwise_dst_val_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
OutDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThroughOp,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
false>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
auto threadwise_dst_idx_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
|
||||
IndexDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThroughOp,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
false>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
threadwise_dst_val_store.Run(reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
accu_value_buf,
|
||||
out_grid_desc_m,
|
||||
out_global_val_buf);
|
||||
threadwise_dst_idx_store.Run(reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
accu_index_buf,
|
||||
out_grid_desc_m,
|
||||
out_global_idx_buf);
|
||||
}
|
||||
};
|
||||
|
||||
__device__ static void
|
||||
RunSecondCallWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
|
||||
const OutGridDesc_M& out_grid_desc_m,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const OutElementwiseOperation acc_elementwise_op,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_ws_values_global,
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_global,
|
||||
const IndexDataType* const __restrict__ p_ws_indices_global,
|
||||
IndexDataType* const __restrict__ p_indices_global)
|
||||
{
|
||||
static_assert(InSrcVectorDim == 1,
|
||||
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
|
||||
|
||||
using BlockwiseReduceWithIndex =
|
||||
PartitionedBlockwiseReductionWithIndex<AccDataType,
|
||||
IndexDataType,
|
||||
BlockSize,
|
||||
Sequence<MThreadClusterSize, KThreadClusterSize>,
|
||||
ThreadClusterArrangeOrder,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
|
||||
ReduceOperation,
|
||||
AccDataType,
|
||||
IndexDataType>;
|
||||
|
||||
(void)in_elementwise_op;
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
|
||||
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
const auto src_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_ws_values_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_ws_indices_global, in_grid_desc_m_k.GetElementSpaceSize());
|
||||
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
auto reduce_work_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
|
||||
auto reduce_work_idx_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
in_thread_val_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
IndexDataType,
|
||||
MThreadSliceSize * KThreadSliceSize,
|
||||
true>
|
||||
in_thread_idx_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
|
||||
|
||||
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_1d_id = get_block_1d_id();
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_val_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
auto threadwise_src_idx_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) = zeroVal;
|
||||
accu_index_buf(I) = 0;
|
||||
});
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
|
||||
|
||||
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
|
||||
|
||||
index_t reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
// load the thread slice
|
||||
threadwise_src_val_load.Run(in_grid_desc_m_k,
|
||||
src_global_val_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_val_buf);
|
||||
threadwise_src_idx_load.Run(in_grid_desc_m_k,
|
||||
src_global_idx_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_idx_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
AccDataType tmpValue = zeroVal;
|
||||
IndexDataType tmpIndex = 0;
|
||||
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
AccumulationWithIndex::Calculate(tmpValue,
|
||||
in_thread_val_buf[Number<offset>{}],
|
||||
tmpIndex,
|
||||
in_thread_idx_buf[Number<offset>{}]);
|
||||
});
|
||||
|
||||
BlockwiseReduceWithIndex::Reduce(
|
||||
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
|
||||
|
||||
AccumulationWithIndex::Calculate(
|
||||
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
|
||||
});
|
||||
|
||||
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < toReduceTiles);
|
||||
|
||||
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
                // for an index-producing reduction, acc_elementwise_op should do nothing
|
||||
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
|
||||
|
||||
accu_value_buf(I) *= alpha;
|
||||
}
|
||||
});
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
if constexpr(!BetaIsZero)
|
||||
{
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
|
||||
priorDstValueBuf;
|
||||
|
||||
auto threadwise_dst_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<OutDataType,
|
||||
OutDataType,
|
||||
OutGridDesc_M,
|
||||
decltype(reduced_data_desc),
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
1,
|
||||
true>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize));
|
||||
|
||||
threadwise_dst_load.Run(out_grid_desc_m,
|
||||
out_global_val_buf,
|
||||
reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
priorDstValueBuf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
auto threadwise_dst_val_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
OutDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThroughOp,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
auto threadwise_dst_idx_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
|
||||
IndexDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThroughOp,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
threadwise_dst_val_store.Run(reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
accu_value_buf,
|
||||
out_grid_desc_m,
|
||||
out_global_val_buf);
|
||||
threadwise_dst_idx_store.Run(reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
accu_index_buf,
|
||||
out_grid_desc_m,
|
||||
out_global_idx_buf);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,638 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP
|
||||
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP
|
||||
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
#include "reduction_functions_blockwise.hpp"
|
||||
#include "reduction_functions_threadwise.hpp"
|
||||
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseReduction,
|
||||
bool OutputIndex,
|
||||
bool HaveIndexInput,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
typename InGridDesc_M_K,
|
||||
typename OutGridDesc_M,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation>
|
||||
__global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
|
||||
const OutGridDesc_M out_grid_desc_m,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const AccElementwiseOperation acc_elementwise_op,
|
||||
index_t block_group_size,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_value_global,
|
||||
const IndexDataType* const __restrict__ p_in_index_global,
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_value_global,
|
||||
IndexDataType* const __restrict__ p_out_index_global)
|
||||
{
|
||||
if constexpr(!OutputIndex)
|
||||
{
|
||||
(void)p_in_index_global;
|
||||
(void)p_out_index_global;
|
||||
|
||||
GridwiseReduction::Run(in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
block_group_size,
|
||||
num_k_block_tile_iteration,
|
||||
alpha,
|
||||
p_in_value_global,
|
||||
beta,
|
||||
p_out_value_global);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
num_k_block_tile_iteration,
|
||||
alpha,
|
||||
p_in_value_global,
|
||||
p_in_index_global,
|
||||
beta,
|
||||
p_out_value_global,
|
||||
p_out_index_global);
|
||||
};
|
||||
};
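
// Editor's sketch of the block-to-work mapping used below: consecutive groups of
// block_group_size workgroups share one M tile and split the K dimension between them
// (see blkgroup_id / block_local_id in Run). Names are illustrative assumptions.
struct SketchBlockWork
{
    int m_tile_id;  // which M tile this workgroup contributes to
    int k_group_id; // which K slice of that tile it reduces
};

static SketchBlockWork MakeSketchBlockWork(int block_global_id, int block_group_size)
{
    return {block_global_id / block_group_size, block_global_id % block_group_size};
}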
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
typename InGridDesc_M_K,
|
||||
typename OutGridDesc_M,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
InMemoryDataOperationEnum OutMemoryDataOperation,
|
||||
bool PropagateNan,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct GridwiseReduction_mk_to_m_multiblock
|
||||
{
|
||||
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
|
||||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
|
||||
(MThreadSliceSize % OutDstVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
|
||||
|
||||
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
|
||||
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
using ThreadClusterArrangeOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
|
||||
ThreadReduceSrcDesc_M_K,
|
||||
ThreadReduceDstDesc_M,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
|
||||
|
||||
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
|
||||
const OutGridDesc_M& out_grid_desc_m,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op,
|
||||
index_t block_group_size,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_value_global,
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_value_global)
|
||||
{
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
|
||||
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
auto reduce_work_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
in_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
const index_t blkgroup_id = block_global_id / block_group_size;
|
||||
const index_t block_local_id = block_global_id % block_group_size;
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
|
||||
|
||||
index_t reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op(in_thread_buf(Number<offset>{}),
|
||||
in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
|
||||
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}(
|
||||
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
|
||||
|
||||
accu_value_buf(I) *= alpha;
|
||||
}
|
||||
});
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
if(block_group_size == 1 && !float_equal_zero{}(beta))
|
||||
{
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
|
||||
priorDstValueBuf;
|
||||
|
||||
auto threadwise_dst_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<OutDataType,
|
||||
OutDataType,
|
||||
OutGridDesc_M,
|
||||
decltype(reduced_data_desc),
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
1,
|
||||
false>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize));
|
||||
|
||||
threadwise_dst_load.Run(out_grid_desc_m,
|
||||
out_global_val_buf,
|
||||
reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
priorDstValueBuf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
|
||||
});
|
||||
};
|
||||
|
||||
auto threadwise_dst_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
OutDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThroughOp,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
OutMemoryDataOperation,
|
||||
1,
|
||||
true>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
threadwise_dst_store.Run(reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
accu_value_buf,
|
||||
out_grid_desc_m,
|
||||
out_global_val_buf);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool HaveIndexInput>
|
||||
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
|
||||
const OutGridDesc_M& out_grid_desc_m,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const AccElementwiseOperation acc_elementwise_op,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_value_global,
|
||||
const IndexDataType* const __restrict__ p_in_index_global,
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_value_global,
|
||||
IndexDataType* const __restrict__ p_out_index_global)
|
||||
{
|
||||
using BlockwiseReduceWithIndex =
|
||||
PartitionedBlockwiseReductionWithIndex<AccDataType,
|
||||
IndexDataType,
|
||||
BlockSize,
|
||||
Sequence<MThreadClusterSize, KThreadClusterSize>,
|
||||
ThreadClusterArrangeOrder,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
|
||||
ReduceOperation,
|
||||
AccDataType,
|
||||
IndexDataType>;
|
||||
|
||||
(void)in_elementwise_op;
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
|
||||
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
|
||||
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_index_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
auto reduce_work_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
|
||||
auto reduce_work_idx_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
in_thread_val_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
IndexDataType,
|
||||
MThreadSliceSize * KThreadSliceSize,
|
||||
true>
|
||||
in_thread_idx_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_1d_id = get_block_1d_id();
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_val_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) = zeroVal;
|
||||
accu_index_buf(I) = 0;
|
||||
});
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
|
||||
|
||||
index_t reducedTiles = 0;
|
||||
|
||||
if constexpr(HaveIndexInput)
|
||||
{
|
||||
auto threadwise_src_idx_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
do
|
||||
{
|
||||
// load the thread slice
|
||||
threadwise_src_val_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_val_buf);
|
||||
threadwise_src_idx_load.Run(in_grid_desc_m_k,
|
||||
in_global_idx_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_idx_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
AccDataType tmpValue = zeroVal;
|
||||
IndexDataType tmpIndex = 0;
|
||||
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
AccumulationWithIndex::Calculate(tmpValue,
|
||||
in_thread_val_buf[Number<offset>{}],
|
||||
tmpIndex,
|
||||
in_thread_idx_buf[Number<offset>{}]);
|
||||
});
|
||||
|
||||
BlockwiseReduceWithIndex::Reduce(
|
||||
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
|
||||
|
||||
AccumulationWithIndex::Calculate(
|
||||
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
|
||||
});
|
||||
|
||||
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t indexOffset = 0;
|
||||
|
||||
do
|
||||
{
|
||||
// load the thread slice
|
||||
threadwise_src_val_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_val_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
// initialize the indices for the per-thread to-reduce values
|
||||
in_thread_idx_buf(Number<offset>{}) =
|
||||
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
|
||||
|
||||
// do element-wise pre-reduction operation
|
||||
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
|
||||
in_thread_val_buf(Number<offset>{}));
|
||||
});
|
||||
|
||||
AccDataType tmpValue = zeroVal;
|
||||
IndexDataType tmpIndex = 0;
|
||||
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
AccumulationWithIndex::Calculate(tmpValue,
|
||||
in_thread_val_buf[Number<offset>{}],
|
||||
tmpIndex,
|
||||
in_thread_idx_buf[Number<offset>{}]);
|
||||
});
|
||||
|
||||
BlockwiseReduceWithIndex::Reduce(
|
||||
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
|
||||
|
||||
AccumulationWithIndex::Calculate(
|
||||
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
|
||||
});
|
||||
|
||||
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
indexOffset += K_BlockTileSize;
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
};
|
||||
|
||||
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
// for indexed operations, acc_elementwise_op should do nothing
|
||||
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
|
||||
|
||||
accu_value_buf(I) *= alpha;
|
||||
}
|
||||
});
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
|
||||
priorDstValueBuf;
|
||||
|
||||
auto threadwise_dst_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<OutDataType,
|
||||
OutDataType,
|
||||
OutGridDesc_M,
|
||||
decltype(reduced_data_desc),
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
1,
|
||||
true>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize));
|
||||
|
||||
threadwise_dst_load.Run(out_grid_desc_m,
|
||||
out_global_val_buf,
|
||||
reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
priorDstValueBuf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
|
||||
});
|
||||
};
|
||||
|
||||
auto threadwise_dst_val_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
OutDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThroughOp,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
auto threadwise_dst_idx_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
|
||||
IndexDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThroughOp,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(block_global_1d_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
threadwise_dst_val_store.Run(reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
accu_value_buf,
|
||||
out_grid_desc_m,
|
||||
out_global_val_buf);
|
||||
threadwise_dst_idx_store.Run(reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
accu_index_buf,
|
||||
out_grid_desc_m,
|
||||
out_global_idx_buf);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
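
For reference, the launch parameters consumed by the merged multiblock kernel above are related as follows: the grid is organized into groups of block_group_size blocks per M tile, and each block walks num_k_block_tile_iteration K tiles of size KThreadClusterSize * KThreadSliceSize along the reduced dimension. Below is a minimal host-side sketch of that relationship; the helper name and signature are assumptions for illustration only and are not part of this header.

#include <algorithm>
#include <cstdint>

struct ReductionLaunchParams
{
    int64_t block_group_size;           // blocks cooperating on one M tile along K
    int64_t num_k_block_tile_iteration; // K tiles each block iterates over
    int64_t grid_size;                  // total number of blocks to launch
};

// Hypothetical helper: splits the reduced length into K tiles and spreads them
// over at most max_blocks_per_group blocks, mirroring how blkgroup_id and
// block_local_id are decoded inside kernel_reduce_multiblock.
inline ReductionLaunchParams make_reduction_launch_params(int64_t invariant_total_length,
                                                          int64_t reduce_total_length,
                                                          int64_t m_block_tile_size,
                                                          int64_t k_block_tile_size,
                                                          int64_t max_blocks_per_group)
{
    const int64_t m_groups = (invariant_total_length + m_block_tile_size - 1) / m_block_tile_size;
    const int64_t k_tiles  = (reduce_total_length + k_block_tile_size - 1) / k_block_tile_size;

    const int64_t block_group_size = std::min(k_tiles, max_blocks_per_group);
    const int64_t iterations       = (k_tiles + block_group_size - 1) / block_group_size;

    return {block_group_size, iterations, m_groups * block_group_size};
}

With block_group_size equal to 1 the kernel degenerates to a blockwise reduction and the Set/beta path applies; with a larger group the AtomicAdd output operation is expected instead.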
|
||||
@@ -1,269 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
|
||||
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
|
||||
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
#include "reduction_functions_blockwise.hpp"
|
||||
#include "reduction_functions_threadwise.hpp"
|
||||
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseReduction,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InGridDesc_M_K,
|
||||
typename OutGridDesc_M,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation>
|
||||
__global__ void
|
||||
kernel_reduce_multiblock_atocmi_add(const InGridDesc_M_K in_grid_desc_m_k,
|
||||
const OutGridDesc_M out_grid_desc_m,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const AccElementwiseOperation acc_elementwise_op,
|
||||
index_t block_group_size,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_global,
|
||||
OutDataType* const __restrict__ p_out_global)
|
||||
{
|
||||
GridwiseReduction::Run(in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
block_group_size,
|
||||
num_k_block_tile_iteration,
|
||||
alpha,
|
||||
p_in_global,
|
||||
p_out_global);
|
||||
};
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InGridDesc_M_K,
|
||||
typename OutGridDesc_M,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct GridwiseReduction_mk_to_m_multiblock_atomic_add
|
||||
{
|
||||
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
|
||||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
|
||||
(MThreadSliceSize % OutDstVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
|
||||
|
||||
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
|
||||
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
using ThreadClusterArrangeOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
|
||||
ThreadReduceSrcDesc_M_K,
|
||||
ThreadReduceDstDesc_M,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
|
||||
|
||||
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
|
||||
const OutGridDesc_M& out_grid_desc_m,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op,
|
||||
index_t block_group_size,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_global,
|
||||
OutDataType* const __restrict__ p_out_global)
|
||||
{
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
|
||||
|
||||
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
|
||||
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
auto reduce_work_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
in_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
const index_t blkgroup_id = block_global_id / block_group_size;
|
||||
const index_t block_local_id = block_global_id % block_group_size;
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
|
||||
|
||||
index_t reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op(in_thread_buf(Number<offset>{}),
|
||||
in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
|
||||
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
|
||||
|
||||
// Each block executes multiple parallel reductions on the LDS, and by atomic-adding its
// reduced output to the global location of each invariant dimension it obtains a
// consistent reduced result for that invariant dimension. Due to the use of vector_load,
// each block/thread is involved in multiple invariant dimensions.
|
||||
static_for<0, MThreadSliceSize, 1>{}(
|
||||
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
|
||||
|
||||
accu_value_buf(I) *= alpha;
|
||||
}
|
||||
});
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
auto threadwise_dst_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
OutDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThroughOp,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
1,
|
||||
true>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
threadwise_dst_store.Run(
|
||||
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
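
The header above (removed by this commit) relied on the fact that a reduction over K can be decomposed into per-block partial reductions whose results are then accumulated into the output, which is what the AtomicAdd store realizes. Below is a standalone plain C++ sketch of that decomposition, independent of the HIP kernel; all names are illustrative only.

#include <cassert>
#include <numeric>
#include <vector>

int main()
{
    const int K = 64, block_group_size = 4;
    std::vector<float> row(K);
    std::iota(row.begin(), row.end(), 1.0f);

    // Reference: one pass over the whole reduced dimension.
    const float full = std::accumulate(row.begin(), row.end(), 0.0f);

    // Multiblock: each "block" reduces a contiguous K slice, then the partial
    // results are accumulated into the output (AtomicAdd in the real kernel).
    const int slice = K / block_group_size;
    float out = 0.0f; // the output must start from the operation's zero value
    for(int b = 0; b < block_group_size; ++b)
        out += std::accumulate(row.begin() + b * slice, row.begin() + (b + 1) * slice, 0.0f);

    assert(out == full);
    return 0;
}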
|
||||
@@ -1,487 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP
|
||||
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP
|
||||
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
#include "reduction_functions_blockwise.hpp"
|
||||
#include "reduction_functions_threadwise.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "cluster_descriptor.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseReduction,
|
||||
bool NeedIndices,
|
||||
typename InDataType,
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
typename InGridDesc_M_K,
|
||||
typename WorkspaceDesc_M_K,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation>
|
||||
__global__ void
|
||||
kernel_partial_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
|
||||
const WorkspaceDesc_M_K workspace_desc_m_k,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const AccElementwiseOperation acc_elementwise_op,
|
||||
index_t block_group_size,
|
||||
index_t num_k_block_tile_iteration,
|
||||
const InDataType* const __restrict__ p_src_global,
|
||||
AccDataType* const __restrict__ p_ws_values_global,
|
||||
IndexDataType* const __restrict__ p_ws_indices_global)
|
||||
|
||||
{
|
||||
if constexpr(!NeedIndices)
|
||||
{
|
||||
GridwiseReduction::Run(in_grid_desc_m_k,
|
||||
workspace_desc_m_k,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
block_group_size,
|
||||
num_k_block_tile_iteration,
|
||||
p_src_global,
|
||||
p_ws_values_global,
|
||||
p_ws_indices_global);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseReduction::RunWithIndex(in_grid_desc_m_k,
|
||||
workspace_desc_m_k,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
block_group_size,
|
||||
num_k_block_tile_iteration,
|
||||
p_src_global,
|
||||
p_ws_values_global,
|
||||
p_ws_indices_global);
|
||||
};
|
||||
};
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
typename InGridDesc_M_K,
|
||||
typename WorkspaceDesc_M_K,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
|
||||
{
|
||||
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
|
||||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
|
||||
|
||||
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
|
||||
|
||||
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
|
||||
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
using ThreadClusterArrangeOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
|
||||
const WorkspaceDesc_M_K& workspace_desc_m_k,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op,
|
||||
index_t block_group_size,
|
||||
index_t num_k_block_tile_iteration,
|
||||
const InDataType* const __restrict__ p_src_global,
|
||||
AccDataType* const __restrict__ p_ws_values_global,
|
||||
IndexDataType* const __restrict__ p_ws_indices_global)
|
||||
{
|
||||
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
|
||||
ThreadReduceSrcDesc_M_K,
|
||||
ThreadReduceDstDesc_M,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
(void)p_ws_indices_global;
|
||||
(void)acc_elementwise_op;
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
|
||||
|
||||
const auto in_global_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
auto reduce_work_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
in_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
const index_t blkgroup_id = block_global_id / block_group_size;
|
||||
const index_t block_local_id = block_global_id % block_group_size;
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
|
||||
|
||||
index_t reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op(in_thread_buf(Number<offset>{}),
|
||||
in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
|
||||
// Each block executes multiple parallel reductions on the LDS, and due to the use of
// vector_load, each block/thread is involved in multiple invariant dimensions.
|
||||
static_for<0, MThreadSliceSize, 1>{}(
|
||||
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
|
||||
|
||||
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
auto threadwise_workspace_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
AccDataType,
|
||||
decltype(reduced_data_desc),
|
||||
WorkspaceDesc_M_K,
|
||||
PassThroughOp,
|
||||
Sequence<MThreadSliceSize, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
workspace_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id),
|
||||
PassThroughOp{});
|
||||
|
||||
threadwise_workspace_store.Run(reduced_data_desc,
|
||||
make_tuple(I0, I0),
|
||||
accu_value_buf,
|
||||
workspace_desc_m_k,
|
||||
workspace_global_buf);
|
||||
}
|
||||
};
|
||||
|
||||
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
|
||||
const WorkspaceDesc_M_K& workspace_desc_m_k,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op,
|
||||
index_t block_group_size,
|
||||
index_t num_k_block_tile_iteration,
|
||||
const InDataType* const __restrict__ p_src_global,
|
||||
AccDataType* const __restrict__ p_ws_values_global,
|
||||
IndexDataType* const __restrict__ p_ws_indices_global)
|
||||
{
|
||||
using BlockwiseReduceWithIndex =
|
||||
PartitionedBlockwiseReductionWithIndex<AccDataType,
|
||||
IndexDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
|
||||
ReduceOperation,
|
||||
AccDataType,
|
||||
IndexDataType>;
|
||||
|
||||
(void)acc_elementwise_op;
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
|
||||
__shared__ index_t p_reduce_work_idx_buffer[BlockSize];
|
||||
|
||||
const auto in_global_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
|
||||
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
auto reduce_work_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
|
||||
auto reduce_work_idx_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
in_thread_val_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
IndexDataType,
|
||||
MThreadSliceSize * KThreadSliceSize,
|
||||
true>
|
||||
in_thread_idx_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
const index_t blkgroup_id = block_global_id / block_group_size;
|
||||
const index_t block_local_id = block_global_id % block_group_size;
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
|
||||
|
||||
index_t indexOffset = block_local_id * reduceSizePerBlock;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) = zeroVal;
|
||||
accu_index_buf(I) = 0;
|
||||
});
|
||||
|
||||
index_t reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
// load the thread slice
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_val_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
// initialize the indices for the per-thread to-reduce values
|
||||
in_thread_idx_buf(Number<offset>{}) =
|
||||
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
|
||||
|
||||
// do element-wise pre-reduction operation
|
||||
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
|
||||
in_thread_val_buf(Number<offset>{}));
|
||||
});
|
||||
|
||||
AccDataType tmpValue = zeroVal;
|
||||
IndexDataType tmpIndex = 0;
|
||||
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
AccumulationWithIndex::Calculate(tmpValue,
|
||||
in_thread_val_buf[Number<offset>{}],
|
||||
tmpIndex,
|
||||
in_thread_idx_buf[Number<offset>{}]);
|
||||
});
|
||||
|
||||
BlockwiseReduceWithIndex::Reduce(
|
||||
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
|
||||
|
||||
AccumulationWithIndex::Calculate(
|
||||
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
|
||||
});
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
indexOffset += K_BlockTileSize;
|
||||
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
|
||||
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
auto threadwise_workspace_val_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
AccDataType,
|
||||
decltype(reduced_data_desc),
|
||||
WorkspaceDesc_M_K,
|
||||
PassThroughOp,
|
||||
Sequence<MThreadSliceSize, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
workspace_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id),
|
||||
PassThroughOp{});
|
||||
|
||||
auto threadwise_workspace_idx_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
|
||||
IndexDataType,
|
||||
decltype(reduced_data_desc),
|
||||
WorkspaceDesc_M_K,
|
||||
PassThroughOp,
|
||||
Sequence<MThreadSliceSize, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
workspace_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id),
|
||||
PassThroughOp{});
|
||||
|
||||
threadwise_workspace_val_store.Run(reduced_data_desc,
|
||||
make_tuple(I0, I0),
|
||||
accu_value_buf,
|
||||
workspace_desc_m_k,
|
||||
workspace_global_val_buf);
|
||||
threadwise_workspace_idx_store.Run(reduced_data_desc,
|
||||
make_tuple(I0, I0),
|
||||
accu_index_buf,
|
||||
workspace_desc_m_k,
|
||||
workspace_global_idx_buf);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
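
The removed MultiBlockPartialReduce path handled operations that also produce an index (for example argmax) by writing one partial (value, index) pair per block into an M x block_group_size workspace and reducing that workspace in a second kernel. Below is a plain C++ sketch of the two-pass idea, with purely illustrative names and data.

#include <cassert>
#include <utility>
#include <vector>

int main()
{
    const int K = 12, block_group_size = 3, slice = K / block_group_size;
    const std::vector<float> row = {3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8};

    // Pass 1: each "block" computes an argmax over its K slice, keeping global indices.
    std::vector<std::pair<float, int>> workspace(block_group_size);
    for(int b = 0; b < block_group_size; ++b)
    {
        int best = b * slice;
        for(int k = b * slice; k < (b + 1) * slice; ++k)
            if(row[k] > row[best])
                best = k;
        workspace[b] = {row[best], best};
    }

    // Pass 2: reduce the workspace entries to the final (value, index) result.
    std::pair<float, int> result = workspace[0];
    for(int b = 1; b < block_group_size; ++b)
        if(workspace[b].first > result.first)
            result = workspace[b];

    assert(result.first == 9.0f && result.second == 5);
    return 0;
}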
|
||||
@@ -37,7 +37,8 @@
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseReduction,
|
||||
bool NeedIndices,
|
||||
bool OutputIndex,
|
||||
bool HaveIndexInput,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
@@ -51,34 +52,35 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const AccElementwiseOperation acc_elementwise_op,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_global,
|
||||
const InDataType* const __restrict__ p_in_value_global,
|
||||
const IndexDataType* const __restrict__ p_in_index_global,
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_global,
|
||||
IndexDataType* const __restrict__ p_indices_global)
|
||||
OutDataType* const __restrict__ p_out_value_global,
|
||||
IndexDataType* const __restrict__ p_out_index_global)
|
||||
{
|
||||
if constexpr(!NeedIndices)
|
||||
if constexpr(!OutputIndex)
|
||||
{
|
||||
GridwiseReduction::Run(in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
alpha,
|
||||
p_in_global,
|
||||
p_in_value_global,
|
||||
beta,
|
||||
p_out_global,
|
||||
p_indices_global);
|
||||
p_out_value_global);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseReduction::RunWithIndices(in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
alpha,
|
||||
p_in_global,
|
||||
beta,
|
||||
p_out_global,
|
||||
p_indices_global);
|
||||
GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
alpha,
|
||||
p_in_value_global,
|
||||
p_in_index_global,
|
||||
beta,
|
||||
p_out_value_global,
|
||||
p_out_index_global);
|
||||
};
|
||||
};
|
||||
|
||||
@@ -91,11 +93,9 @@ template <typename InDataType,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
InMemoryDataOperationEnum OutMemoryDataOperation,
|
||||
bool PropagateNan,
|
||||
bool BetaIsZero,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
@@ -125,10 +125,9 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_global,
|
||||
const InDataType* const __restrict__ p_in_value_global,
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_global,
|
||||
IndexDataType* const __restrict__ p_indices_global)
|
||||
OutDataType* const __restrict__ p_out_value_global)
|
||||
{
|
||||
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
|
||||
ThreadReduceSrcDesc_M_K,
|
||||
@@ -136,14 +135,14 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
(void)p_indices_global;
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
in_thread_buf;
|
||||
@@ -160,28 +159,29 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
|
||||
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
|
||||
auto threadwise_src_val_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
|
||||
|
||||
index_t reducedLength = 0;
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
threadwise_src_val_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// do element-wise pre-reduction operation
|
||||
@@ -194,7 +194,7 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
|
||||
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
reducedLength += KThreadSliceSize;
|
||||
} while(reducedLength < toReduceLength);
|
||||
@@ -207,68 +207,65 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
|
||||
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
|
||||
|
||||
if constexpr(!BetaIsZero)
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
auto threadwise_dst_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<OutDataType,
|
||||
OutDataType,
|
||||
OutGridDesc_M,
|
||||
decltype(reduced_data_desc),
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
true>(
|
||||
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
|
||||
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
|
||||
OutDataType,
|
||||
OutGridDesc_M,
|
||||
decltype(reduced_data_desc),
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
true>(
|
||||
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
|
||||
priorDstValue_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
|
||||
priorDstValue_buf;
|
||||
|
||||
threadwise_dst_load.Run(out_grid_desc_m,
|
||||
dst_global_buf,
|
||||
reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
priorDstValue_buf);
|
||||
threadwise_dst_load.Run(out_grid_desc_m,
|
||||
dst_global_buf,
|
||||
reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
priorDstValue_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
|
||||
});
|
||||
};
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
|
||||
});
|
||||
};
|
||||
|
||||
auto threadwise_dst_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
OutDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThroughOp,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
false>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(thread_global_1d_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
OutDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThroughOp,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
OutMemoryDataOperation,
|
||||
1,
|
||||
false>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(thread_global_1d_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
threadwise_dst_store.Run(
|
||||
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
|
||||
};
|
||||
|
||||
__device__ static void RunWithIndices(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global)
template <bool HaveIndexInput>
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_value_global,
const IndexDataType* const __restrict__ p_in_index_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_out_index_global)
{
using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
IndexDataType,
@@ -281,12 +278,17 @@ struct GridwiseReduction_mk_to_m_threadwise

const auto zeroVal = ReduceOperation::GetReductionZeroVal();

const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());

auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
p_out_index_global, out_grid_desc_m.GetElementSpaceSize());

StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
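The old RunWithIndices becomes a single RunWithIndex<HaveIndexInput> that can either consume an index buffer produced by a previous pass or generate indices itself. Below is a compact sketch (not part of the diff), assuming a Max reduction, of the value-plus-index accumulation that ThreadwiseReductionWithIndex stands for; arg_reduce_max and its types are illustrative, not CK symbols.

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Fold one K slice into the running (value, index) pair. With HaveIndexInput == true the idxs
// come from the previous reduction pass; otherwise they are generated as indexStart + iK.
std::pair<float, int32_t> arg_reduce_max(std::pair<float, int32_t> accu,
                                         const std::vector<float>& vals,
                                         const std::vector<int32_t>& idxs)
{
    for(std::size_t k = 0; k < vals.size(); ++k)
        if(vals[k] > accu.first)       // the Max reduce operation's comparison
            accu = {vals[k], idxs[k]}; // keep the source index alongside the winning value
    return accu;
}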
@@ -313,50 +315,105 @@ struct GridwiseReduction_mk_to_m_threadwise

index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
auto threadwise_src_val_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

index_t indexStart = 0;
index_t reducedLength = 0;
do
if constexpr(HaveIndexInput)
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
auto threadwise_src_idx_load =
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
IndexDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
do
{
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);

in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
threadwise_src_idx_load.Run(in_grid_desc_m_k,
in_global_idx_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_idx_buf);

in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
});
});

ThreadwiseReduceWithIndex::Reduce(
in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
ThreadwiseReduceWithIndex::Reduce(
in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);

threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

indexStart += KThreadSliceSize;
reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength);
indexStart += KThreadSliceSize;
reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength);
}
else
{
do
{
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);

static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

in_thread_idx_buf(Number<offset>{}) = indexStart + iK();

in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
});

ThreadwiseReduceWithIndex::Reduce(
in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);

threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

indexStart += KThreadSliceSize;
reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength);
};

// for indexed operation, acc_elementwise_op should do nothing
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
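In the HaveIndexInput == false branch above, the source indices are synthesized on the fly rather than loaded. A small illustration (not part of the diff), assuming a 1-D K window, of how indexStart keeps those indices global across loop iterations; make_slice_indices is a hypothetical helper, not a CK function.

#include <cstdint>
#include <vector>

// Indices for one K slice: in_thread_idx_buf(offset) = indexStart + iK in the kernel.
std::vector<int32_t> make_slice_indices(int32_t indexStart, int32_t KThreadSliceSize)
{
    std::vector<int32_t> idx(KThreadSliceSize);
    for(int32_t iK = 0; iK < KThreadSliceSize; ++iK)
        idx[iK] = indexStart + iK;
    return idx;
}

// After each slice the kernel advances the source window and does "indexStart += KThreadSliceSize",
// so the next call continues the global numbering without gaps.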
@@ -367,36 +424,32 @@ struct GridwiseReduction_mk_to_m_threadwise

constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

if constexpr(!BetaIsZero)
if(!float_equal_zero{}(beta))
{
if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
1,
1,
false>(
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
1,
1,
false>(
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValue_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValue_buf;

threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValue_buf);
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValue_buf);

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
};

auto threadwise_dst_val_store =
@@ -409,7 +462,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
OutMemoryDataOperation,
1,
false>(
out_grid_desc_m,
@@ -426,7 +479,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
OutMemoryDataOperation,
1,
false>(
out_grid_desc_m,
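The last two hunks swap the hard-coded InMemoryDataOperationEnum::Set for the OutMemoryDataOperation template parameter, so the same gridwise kernel can either overwrite its output element or atomically accumulate into it (the multi-block atomic-add path this commit consolidates). A hedged host-side sketch of the difference between the two destination operations; OutOp and store_result are illustrative stand-ins, not CK symbols.

#include <atomic>

enum class OutOp
{
    Set,      // exclusive writer: plain store, as in the single-pass threadwise kernel
    AtomicAdd // shared destination: partial results from several workgroups are accumulated
};

void store_result(std::atomic<float>& dst, float value, OutOp op)
{
    if(op == OutOp::Set)
    {
        dst.store(value);
    }
    else
    {
        // atomic add via a CAS loop so the sketch also compiles before C++20 fetch_add on floats
        float old = dst.load();
        while(!dst.compare_exchange_weak(old, old + value)) {}
    }
}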