mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Update to the Reduction API and instances (#476)
* Simplify the macros for declaring and defining the add_device_reduce_instance_xxxx() instances
* Change the types of lengths and strides from std::vector to std::array in the reduction device interfaces (see the caller-side sketch below)
* Remove DeviceSoftmaxImpl's dependency on DeviceReduceMultiBlock
* Split the .cpp and .hpp files for the reduction instances to enable more parallel compilation
* Remove the use of macros for declaring reduction instances and instance references
* Update the add_device_reduce_instance_xxxx() templated functions
* Use ReduceOperation + InElementwiseOp + AccElementwiseOp to replace ReduceOpId when defining the add_reduce_instance_xxxx() templates
* Change the return format
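To make the std::vector-to-std::array interface change concrete, here is a minimal caller-side sketch. All shapes, strides, and names below are hypothetical; DeviceReduceOp stands for any concrete DeviceReduce<4, 2, ...> implementation, and the trailing argument order follows the call site visible in the softmax diff further down.

#include <array>
#include <memory>
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"

using ck::index_t;

// Reduce dimensions 2 and 3 of a hypothetical 64x32x8x4 tensor. Rank (4) and
// NumReduceDim (2) are now compile-time template arguments, so passing a
// wrongly sized extent array fails to compile instead of failing at runtime.
template <typename DeviceReduceOp, typename InOp, typename AccOp>
std::unique_ptr<ck::tensor_operation::device::BaseArgument>
make_rank4_reduce_argument(DeviceReduceOp& reduce_op,
                           const void* in_dev,
                           void* out_dev,
                           InOp in_op,
                           AccOp acc_op)
{
    std::array<index_t, 4> inLengths{64, 32, 8, 4};
    std::array<index_t, 4> inStrides{1024, 32, 4, 1}; // packed, row-major
    std::array<index_t, 2> outLengths{64, 32};
    std::array<index_t, 2> outStrides{32, 1};
    std::array<int, 2>     reduceDims{2, 3};

    return reduce_op.MakeArgumentPointer(inLengths, inStrides, outLengths, outStrides,
                                         reduceDims,
                                         1.0f, // alpha
                                         0.0f, // beta
                                         in_dev, nullptr, out_dev, nullptr,
                                         in_op, acc_op);
}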
@@ -3,27 +3,30 @@
 #pragma once
 
-#include <vector>
+#include <array>
 #include <memory>
 #include <iostream>
 
+#include "ck/utility/common_header.hpp"
 #include "ck/utility/reduction_enums.hpp"
-#include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
 
 namespace ck {
 namespace tensor_operation {
 namespace device {
 
-template <typename InElementwiseOperation, typename AccElementwiseOperation>
+template <index_t Rank,
+          index_t NumReduceDim,
+          typename InElementwiseOperation,
+          typename AccElementwiseOperation>
 struct DeviceReduce : public BaseOperator
 {
+    static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
 
     virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<index_t> inLengths,
-                        const std::vector<index_t> inStrides,
-                        const std::vector<index_t> outLengths,
-                        const std::vector<index_t> outStrides,
-                        const std::vector<int> reduceDims,
+    MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
+                        const std::array<index_t, Rank> inStrides,
+                        const std::array<index_t, NumOutDim> outLengths,
+                        const std::array<index_t, NumOutDim> outStrides,
+                        const std::array<int, NumReduceDim> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,

@@ -36,9 +39,12 @@ struct DeviceReduce : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
 
-template <typename InElementwiseOperation, typename AccElementwiseOperation>
-using DeviceReducePtr =
-    std::unique_ptr<DeviceReduce<InElementwiseOperation, AccElementwiseOperation>>;
+template <index_t Rank,
+          index_t NumReduceDim,
+          typename InElementwiseOperation,
+          typename AccElementwiseOperation>
+using DeviceReducePtr = std::unique_ptr<
+    DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>>;
 
 } // namespace device
 } // namespace tensor_operation
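For orientation, a hedged sketch of an instantiation site for the reworked alias (PassThrough is CK's identity element-wise op; the rank-4, two-reduce-dim choice is hypothetical):

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Rank = 4, NumReduceDim = 2: the shape bookkeeping is now part of the type.
using ReducePtr =
    ck::tensor_operation::device::DeviceReducePtr<4, 2, PassThrough, PassThrough>;

// Inside DeviceReduce<4, 2, ...>: NumOutDim = (4 - 2 == 0) ? 1 : 4 - 2, i.e. 2,
// so MakeArgumentPointer expects std::array<index_t, 4> input lengths/strides,
// std::array<index_t, 2> output lengths/strides, and std::array<int, 2> reduceDims.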
@@ -5,9 +5,8 @@
 
 #include <iostream>
 #include <sstream>
+#include <array>
 
 #include "ck/utility/common_header.hpp"
 #include "ck/utility/reduction_operator.hpp"
-#include "ck/tensor_description/tensor_descriptor.hpp"
-#include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce.hpp"

@@ -41,7 +40,8 @@ template <typename InDataType,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
+struct DeviceReduceMultiBlock
+    : public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,

@@ -58,8 +58,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
 
     static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
 
-    static constexpr index_t numSrcDim = Rank;
-    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
+    static constexpr index_t NumSrcDim = Rank;
+    static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
     static constexpr bool reduceAllDim = (NumInvariantDim == 0);
 
     // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added

@@ -81,13 +81,15 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
 
-    static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
-                                    const std::vector<index_t>& inStrides,
+    static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
+                                    const std::array<index_t, Rank>& inStrides,
                                     int blkGroupSize,
                                     int numBlockTileIteration)
     {
-        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
-        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
+        const auto tupleSrcLengths =
+            generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
+        const auto tupleSrcStrides =
+            generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
 
         const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
 
@@ -97,7 +99,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
                 const auto one_dim_inDesc = transform_tensor_descriptor(
                     inDesc,
                     make_tuple(make_merge_transform(tupleSrcLengths)),
-                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
+                    make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
                     make_tuple(Sequence<0>{}));
 
                 return transform_tensor_descriptor(one_dim_inDesc,

@@ -111,10 +113,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
                 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
                 using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
 
-                const auto reduceDimLengths =
-                    make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
+                const auto reduceDimLengths = generate_tuple(
+                    [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
                 const auto invariantDimLengths =
-                    make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
+                    generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
 
                 return transform_tensor_descriptor(
                     inDesc,

@@ -143,18 +145,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
         return (in_grid_desc_m_k_padded);
     };
 
-    static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
-                                    const std::vector<index_t>& outStrides)
+    static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
+                                    const std::array<index_t, NumDstDim>& outStrides)
     {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
+        const auto tupleDstLengths =
+            generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
+        const auto tupleDstStrides =
+            generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
 
         auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
 
         auto out_grid_desc_m = transform_tensor_descriptor(
             outDesc,
             make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
             make_tuple(Sequence<0>{}));
 
         const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});

@@ -170,18 +174,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
         return (out_grid_desc_m_padded);
     };
 
-    static auto MakeDst1dDescriptorForBufferSet(const std::vector<index_t>& outLengths,
-                                                const std::vector<index_t>& outStrides)
+    static auto MakeDst1dDescriptorForBufferSet(const std::array<index_t, NumDstDim>& outLengths,
+                                                const std::array<index_t, NumDstDim>& outStrides)
     {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
+        const auto tupleDstLengths =
+            generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
+        const auto tupleDstStrides =
+            generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
 
         auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
 
         auto out_grid_desc_m = transform_tensor_descriptor(
             outDesc,
             make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
             make_tuple(Sequence<0>{}));
 
         const auto length = out_grid_desc_m.GetLength(Number<0>{});

@@ -198,11 +204,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
 
     struct Argument : public BaseArgument
     {
-        Argument(const std::vector<index_t> inLengths,
-                 const std::vector<index_t> inStrides,
-                 const std::vector<index_t> outLengths,
-                 const std::vector<index_t> outStrides,
-                 const std::vector<int> reduceDims,
+        Argument(const std::array<index_t, Rank> inLengths,
+                 const std::array<index_t, Rank> inStrides,
+                 const std::array<index_t, NumDstDim> outLengths,
+                 const std::array<index_t, NumDstDim> outStrides,
+                 const std::array<int, NumReduceDim> reduceDims,
                  float alpha,
                  float beta,
                  const InDataType* in_dev,

@@ -272,10 +278,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
                 math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
         }
 
-        std::vector<index_t> inLengths_;
-        std::vector<index_t> inStrides_;
-        std::vector<index_t> outLengths_;
-        std::vector<index_t> outStrides_;
+        std::array<index_t, Rank> inLengths_;
+        std::array<index_t, Rank> inStrides_;
+        std::array<index_t, NumDstDim> outLengths_;
+        std::array<index_t, NumDstDim> outStrides_;
 
         AccDataType alpha_;
         AccDataType beta_;

@@ -459,11 +465,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
     };
 
     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<index_t> inLengths,
-                        const std::vector<index_t> inStrides,
-                        const std::vector<index_t> outLengths,
-                        const std::vector<index_t> outStrides,
-                        const std::vector<int> reduceDims,
+    MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
+                        const std::array<index_t, Rank> inStrides,
+                        const std::array<index_t, NumDstDim> outLengths,
+                        const std::array<index_t, NumDstDim> outStrides,
+                        const std::array<int, NumReduceDim> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
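Throughout both device ops, the make_tuple_from_array / make_tuple_from_array_and_index_seq helpers give way to generate_tuple calls, which build a tuple of compile-time arity from a runtime container. As a self-contained illustration of the pattern (standard C++ only; CK's own generate_tuple and Number live in its utility headers):

#include <array>
#include <tuple>
#include <utility>

// Standalone analogue of the generate_tuple pattern: invoke f on each index
// of a compile-time index sequence and collect the results into a tuple.
template <typename F, std::size_t... Is>
constexpr auto generate_tuple_impl(F&& f, std::index_sequence<Is...>)
{
    return std::make_tuple(f(std::integral_constant<std::size_t, Is>{})...);
}

template <std::size_t N, typename F>
constexpr auto generate_tuple(F&& f)
{
    return generate_tuple_impl(std::forward<F>(f), std::make_index_sequence<N>{});
}

int main()
{
    constexpr std::size_t Rank = 3;
    const std::array<int, Rank> lengths{16, 8, 4};

    // Mirrors: generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
    const auto tupleLengths = generate_tuple<Rank>([&](auto I) { return lengths[I]; });

    static_assert(std::tuple_size_v<decltype(tupleLengths)> == Rank);
}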
@@ -5,6 +5,7 @@
 
 #include <iostream>
 #include <sstream>
+#include <array>
 
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"

@@ -34,7 +35,8 @@ template <typename InDataType,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
+struct DeviceReduceThreadWise
+    : public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");
 
@@ -49,18 +51,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
 
     static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
 
-    static constexpr index_t numSrcDim = Rank;
-    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
+    static constexpr index_t NumSrcDim = Rank;
+    static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
     static constexpr bool reduceAllDim = (NumInvariantDim == 0);
 
     static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
 
-    static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
-                                    const std::vector<index_t>& inStrides)
+    static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
+                                    const std::array<index_t, Rank>& inStrides)
     {
-        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
-        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
+        const auto tupleSrcLengths =
+            generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
+        const auto tupleSrcStrides =
+            generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
 
         const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
 
@@ -70,7 +74,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
                 const auto one_dim_inDesc = transform_tensor_descriptor(
                     inDesc,
                     make_tuple(make_merge_transform(tupleSrcLengths)),
-                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
+                    make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
                     make_tuple(Sequence<0>{}));
 
                 return transform_tensor_descriptor(one_dim_inDesc,

@@ -84,10 +88,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
                 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
                 using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
 
-                const auto reduceDimLengths =
-                    make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
+                const auto reduceDimLengths = generate_tuple(
+                    [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
                 const auto invariantDimLengths =
-                    make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
+                    generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
 
                 return transform_tensor_descriptor(
                     inDesc,

@@ -116,18 +120,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
         return (in_grid_desc_m_k_padded);
     };
 
-    static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
-                                    const std::vector<index_t>& outStrides)
+    static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
+                                    const std::array<index_t, NumDstDim>& outStrides)
     {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
+        const auto tupleDstLengths =
+            generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
+        const auto tupleDstStrides =
+            generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
 
         auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
 
         auto out_grid_desc_m = transform_tensor_descriptor(
             outDesc,
             make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
             make_tuple(Sequence<0>{}));
 
         const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});

@@ -145,11 +151,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
 
     struct Argument : public BaseArgument
     {
-        Argument(const std::vector<index_t> inLengths,
-                 const std::vector<index_t> inStrides,
-                 const std::vector<index_t> outLengths,
-                 const std::vector<index_t> outStrides,
-                 const std::vector<int> reduceDims,
+        Argument(const std::array<index_t, Rank> inLengths,
+                 const std::array<index_t, Rank> inStrides,
+                 const std::array<index_t, NumDstDim> outLengths,
+                 const std::array<index_t, NumDstDim> outStrides,
+                 const std::array<int, NumReduceDim> reduceDims,
                  float alpha,
                  float beta,
                  const InDataType* in_dev,

@@ -187,10 +193,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
                        M_BlockTileSize;
         }
 
-        std::vector<index_t> inLengths_;
-        std::vector<index_t> inStrides_;
-        std::vector<index_t> outLengths_;
-        std::vector<index_t> outStrides_;
+        std::array<index_t, Rank> inLengths_;
+        std::array<index_t, Rank> inStrides_;
+        std::array<index_t, NumDstDim> outLengths_;
+        std::array<index_t, NumDstDim> outStrides_;
 
         AccDataType alpha_;
         AccDataType beta_;

@@ -321,11 +327,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
     };
 
    std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<index_t> inLengths,
-                        const std::vector<index_t> inStrides,
-                        const std::vector<index_t> outLengths,
-                        const std::vector<index_t> outStrides,
-                        const std::vector<int> reduceDims,
+    MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
+                        const std::array<index_t, Rank> inStrides,
+                        const std::array<index_t, NumDstDim> outLengths,
+                        const std::array<index_t, NumDstDim> outStrides,
+                        const std::array<int, NumReduceDim> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
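Both DeviceReduceMultiBlock and DeviceReduceThreadWise view the (dimension-shuffled) input as a 2-D [M, K] problem: the leading NumInvariantDim lengths merge into the invariant extent M and the trailing NumReduceDim lengths into the reduce extent K. A hedged standalone sketch of that bookkeeping, mirroring what CK's get_2d_lengths helper computes (not its actual implementation):

#include <array>
#include <cstdint>
#include <utility>

template <std::size_t Rank, std::size_t NumReduceDim>
constexpr std::pair<std::int64_t, std::int64_t>
get_2d_lengths(const std::array<std::int64_t, Rank>& lengths)
{
    std::int64_t m = 1, k = 1;
    for(std::size_t i = 0; i < Rank - NumReduceDim; ++i)
        m *= lengths[i]; // invariant dims -> M
    for(std::size_t i = Rank - NumReduceDim; i < Rank; ++i)
        k *= lengths[i]; // reduce dims -> K
    return {m, k};
}

// E.g. Rank = 4, NumReduceDim = 2, lengths {64, 32, 8, 4}: M = 2048, K = 32,
// and NumDstDim = NumInvariantDim = 2, matching the constants above.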
@@ -8,12 +8,9 @@
 
 #include "ck/utility/reduction_operator.hpp"
-#include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
 #include "ck/tensor_operation/gpu/device/device_softmax.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 
@@ -50,29 +47,80 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
 
     virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
 
-    // Used for freeloading of some handy functions from DeviceReduceMultiBlock
-    using Reduction = DeviceReduceMultiBlock<InDataType,
-                                             AccDataType,
-                                             OutDataType,
-                                             Rank,
-                                             NumReduceDim,
-                                             reduce::Add,
-                                             InElementwiseOp,
-                                             AccElementwiseOp,
-                                             InMemoryDataOperationEnum::Set,
-                                             false, // PropagateNan
-                                             false, // OutputIndex
-                                             false, // HaveIndexInputIfOutputIndex
-                                             BlockSize,
-                                             MThreadClusterSize,
-                                             KThreadClusterSize,
-                                             MThreadSliceSize,
-                                             KThreadSliceSize,
-                                             InSrcVectorDim,
-                                             InSrcVectorSize,
-                                             1>; // OutDstVectorSize
+    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
 
-    using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
+    static constexpr index_t NumSrcDim = Rank;
+    static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
+
+    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
+    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+
+    static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
+                                    const std::vector<index_t>& inStrides,
+                                    int blkGroupSize,
+                                    int numBlockTileIteration)
+    {
+        const auto tupleSrcLengths =
+            generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
+        const auto tupleSrcStrides =
+            generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
+
+        const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
+
+        const auto in_grid_desc_m_k = [&]() {
+            if constexpr(reduceAllDim)
+            {
+                const auto one_dim_inDesc = transform_tensor_descriptor(
+                    inDesc,
+                    make_tuple(make_merge_transform(tupleSrcLengths)),
+                    make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
+                    make_tuple(Sequence<0>{}));
+
+                return transform_tensor_descriptor(one_dim_inDesc,
+                                                   make_tuple(make_unmerge_transform(make_tuple(
+                                                       1, one_dim_inDesc.GetLength(Number<0>{})))),
+                                                   make_tuple(Sequence<0>{}),
+                                                   make_tuple(Sequence<0, 1>{}));
+            }
+            else
+            {
+                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+                using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
+
+                const auto reduceDimLengths = generate_tuple(
+                    [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
+                const auto invariantDimLengths =
+                    generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
+
+                return transform_tensor_descriptor(
+                    inDesc,
+                    make_tuple(make_merge_transform(invariantDimLengths),
+                               make_merge_transform(reduceDimLengths)),
+                    make_tuple(InvariantDims{}, ReduceDims{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));
+            }
+        }();
+
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});
+
+        const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
+
+        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
+            in_grid_desc_m_k,
+            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                       make_right_pad_transform(reduceLength, inPad_K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+        return (in_grid_desc_m_k_padded);
+    };
+
+    using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
 
     using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
                                                             OutDataType,

@@ -102,7 +150,7 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
                                                             OutDstVectorSize,
                                                             true>;
 
-    struct Argument : public Reduction::Argument
+    struct Argument : public BaseArgument
     {
         Argument(const std::vector<index_t> inLengths,
                  const std::vector<index_t> inStrides,

@@ -113,42 +161,60 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
                  OutDataType* out_dev,
                  InElementwiseOp in_elementwise_op,
                  AccElementwiseOp acc_elementwise_op)
-            : Reduction::Argument(inLengths,
-                                  inStrides,
-                                  {},
-                                  {},
-                                  reduceDims,
-                                  0.0f, // alpha
-                                  0.0f, // beta
-                                  in_dev,
-                                  nullptr,
-                                  out_dev,
-                                  nullptr,
-                                  in_elementwise_op,
-                                  acc_elementwise_op),
-              // FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
-              // float32 precision. Make it support any data type so the fields can be removed.
-              alpha_(alpha),
-              beta_(beta)
+            : alpha_{alpha},
+              beta_{beta},
+              in_dev_{in_dev},
+              out_dev_{out_dev},
+              in_elementwise_op_{in_elementwise_op},
+              acc_elementwise_op_{acc_elementwise_op}
         {
-            // std::cout << "blkGroupSize= " << this->blkGroupSize
-            //           << ", numBlockTileIteration= " << this->numBlockTileIteration
-            //           << ", gridSize=" << this->gridSize
-            //           << ", invariant_total_length=" << this->invariant_total_length <<
-            //           std::endl;
+            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
+            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
+
+            long_index_t invariant_total_length;
+            long_index_t reduce_total_length;
+
+            std::tie(invariant_total_length, reduce_total_length) =
+                get_2d_lengths<Rank, NumReduceDim>(inLengths_);
+
+            if constexpr(NumInvariantDim == 0)
+                invariant_lowest_length_ = 1;
+            else
+                invariant_lowest_length_ = inLengths_[NumInvariantDim - 1];
+
+            blkGroupSize = 1;
+            numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
+
+            gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
+                       M_BlockTileSize * blkGroupSize;
         }
 
+        std::vector<index_t> inLengths_;
+        std::vector<index_t> inStrides_;
+
         AccDataType alpha_;
         AccDataType beta_;
 
+        const InDataType* in_dev_;
+        OutDataType* out_dev_;
+
+        InElementwiseOp in_elementwise_op_;
+        AccElementwiseOp acc_elementwise_op_;
+
+        index_t invariant_lowest_length_;
+
+        int blkGroupSize;
+        int numBlockTileIteration;
+        size_t gridSize;
     };
 
     struct Invoker : public BaseInvoker
     {
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
+            const auto in_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
                 arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
-            const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
+            const auto out_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
                 arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
 
             bool sweep_once =

@@ -195,15 +261,32 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
     {
         const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
 
-        if(!Reduction::IsSupportedArgument(p_arg_))
+        if constexpr(InSrcVectorDim == 0)
         {
-            return false;
-        }
+            if constexpr(NumInvariantDim == 0)
+            {
+                return false;
+            }
+            else
+            {
+                if(p_arg_->inStrides_[NumInvariantDim - 1] != 1)
+                    return false;
 
-        if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
-        {
-            return false;
+                if(p_arg_->invariant_lowest_length_ % InSrcVectorSize != 0)
+                    return false;
+            };
         }
+        else
+        {
+            if(p_arg_->inStrides_[Rank - 1] != 1)
+                return false;
+
+            if(p_arg_->inLengths_[Rank - 1] % InSrcVectorSize != 0)
+                return false;
+        };
+
+        if(p_arg_->invariant_lowest_length_ % OutDstVectorSize != 0)
+            return false;
 
         return true;
     };
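With the DeviceReduceMultiBlock dependency gone, the support check that previously deferred to Reduction::IsSupportedArgument is spelled out locally. Restated as a hedged standalone predicate for readability (names follow the diff above; this is a sketch, not the member function itself):

#include <array>

template <int Rank, int NumInvariantDim, int InSrcVectorDim,
          int InSrcVectorSize, int OutDstVectorSize>
bool softmax_argument_supported(const std::array<int, Rank>& inLengths,
                                const std::array<int, Rank>& inStrides,
                                int invariant_lowest_length)
{
    if constexpr(InSrcVectorDim == 0)
    {
        // Vector loads run along the lowest invariant dimension: it must
        // exist, be contiguous, and divide evenly into vectors.
        if constexpr(NumInvariantDim == 0)
        {
            return false;
        }
        else
        {
            if(inStrides[NumInvariantDim - 1] != 1)
                return false;
            if(invariant_lowest_length % InSrcVectorSize != 0)
                return false;
        }
    }
    else
    {
        // Vector loads run along the innermost (reduced) dimension.
        if(inStrides[Rank - 1] != 1)
            return false;
        if(inLengths[Rank - 1] % InSrcVectorSize != 0)
            return false;
    }
    // Vector stores are checked against the lowest invariant dimension.
    return invariant_lowest_length % OutDstVectorSize == 0;
}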