Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-04 21:51:28 +00:00)
Pr82 followup (#115)
* Use thread cluster descriptor and an explicit M_K 2d descriptor to simplify Blockwise Reduction
* Replace ReduceDims with NumReduceDims as the Device Reduce interface template parameter
* Rename the folder names for the pool2d and reduce examples
* Update the reduction test scripts
* Add a README for the pool2d_fwd and reduce_blockwise examples
* Tiny fix in the reduce profiler and tiny update in the reduce testing scripts
* Tiny fix in testing script profile_reduce_no_index.sh
* Tiny change in script/profile_reduce_with_index.sh
* Renaming and refining in the Reduction profiler/device layer/examples
* Renaming and refining in the Reduction profiler/device layer/examples
* Rename all NumReduceDims to NumReduceDim
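The first bullet above is the core of the change: the hand-written, ReorderThreadClusters-dependent offset arithmetic in the blockwise reduction is replaced by (1) decomposing the linear thread id into an (m, k) cluster index through a cluster descriptor, and (2) addressing the LDS tile through an explicit packed (M, K) descriptor. The sketch below is a standalone model of those two mappings in plain C++ (cluster sizes are illustrative; it does not use the CK descriptor machinery):

// Standalone model (no CK headers) of the two pieces the refactor consolidates:
// decomposing a linear thread id into an (m, k) cluster index according to the
// cluster arrange order, and mapping (m, k) to a packed row-major LDS offset.
#include <cassert>

int main()
{
    constexpr int M = 8;  // stands in for MThreadClusterSize (illustrative)
    constexpr int K = 32; // stands in for KThreadClusterSize (illustrative)

    for(int tid = 0; tid < M * K; ++tid)
    {
        // Arrange order (0, 1): K is the fastest-varying cluster dimension.
        int m0 = tid / K, k0 = tid % K;
        // Arrange order (1, 0): M is the fastest-varying cluster dimension
        // (the reorder_thread_cluster case).
        int k1 = tid / M, m1 = tid % M;

        // A packed (M, K) descriptor always yields offset = m * K + k,
        // regardless of which arrange order produced (m, k).
        int offset0 = m0 * K + k0;
        int offset1 = m1 * K + k1;

        assert(offset0 >= 0 && offset0 < M * K);
        assert(offset1 >= 0 && offset1 < M * K);
    }
    return 0;
}

With this split, the diff below can drop the duplicated ReorderThreadClusters branches and always derive offsets from (m, k).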
@@ -32,57 +32,53 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"

#include "cluster_descriptor.hpp"

namespace ck {

template <typename Buffer1dDescType,
typename AccDataType,
template <typename AccDataType,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
bool ReorderThreadClusters,
typename ThreadClusterLengths_M_K,
typename ThreadClusterArrangeOrder,
typename OpReduce,
bool PropagateNan>
struct PartitionedBlockwiseReductionOn1dBuffer
struct PartitionedBlockwiseReduction
{
static constexpr auto buffer_1d_desc = Buffer1dDescType{};

static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
"The product of cluster lengths should be same as BlockSize!");
static_assert(KThreadClusterSize > 1, "Parallel reduction need work on at least two elements");

static_assert(buffer_1d_desc.GetElementSize() == BlockSize,
"The buffer size should be the same as BlockSize!");
static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);

static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements");

static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));

static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;

template <typename BufferType>
__device__ static void Reduce(BufferType& block_buffer,
AccDataType& accuData,
index_t thread_m_cluster_id,
index_t thread_k_cluster_id)
__device__ static void Reduce(BufferType& block_buffer, AccDataType& accuData)
{
constexpr auto cluster_len_shift = get_shift<KThreadClusterSize>();
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());

if(thread_k_cluster_id < indOffset)
{
// consider the thread clusters order, ensure the contiguous locations are accessed
// by contiguous Thread-ID
index_t offset1 =
ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(make_tuple(
thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id));
index_t offset2 = ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
(thread_k_cluster_id + indOffset) * MThreadClusterSize +
thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize +
(thread_k_cluster_id + indOffset)));
index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset));

AccDataType opData1 = type_convert<AccDataType>(block_buffer[offset1]);
AccDataType opData2 = type_convert<AccDataType>(block_buffer[offset2]);
@@ -93,34 +89,34 @@ struct PartitionedBlockwiseReductionOn1dBuffer
__syncthreads();
});

index_t offset = ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize));
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

accuData = type_convert<AccDataType>(block_buffer[offset]);
};
};
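For reference, the Reduce() method above performs a classic tree reduction along the K dimension of the (M, K) LDS tile: at every step the active upper half of each row is folded onto the lower half, so the per-row result ends at column 0, where it is read back through CalculateOffset(make_tuple(thread_m_cluster_id, 0)). A sequential, host-side model of that loop (illustrative sizes, '+' standing in for OpReduce):

// Behavioral model of the halving loop: each row m of an M x K buffer is
// reduced along K; after log2(K) folding steps the row result sits at column 0.
#include <cassert>
#include <vector>

int main()
{
    constexpr int M = 4, K = 8; // illustrative cluster lengths, K a power of two
    std::vector<int> buf(M * K, 1); // every "thread" contributes 1, so each row sums to K

    int shift = 0;
    while((1 << shift) < K)
        ++shift; // shift == log2(K), the role of get_shift<BufferLength_K>()

    for(int step = 0; step < shift; ++step)
    {
        int indOffset = 1 << (shift - 1 - step);
        for(int m = 0; m < M; ++m)
            for(int k = 0; k < indOffset; ++k)
                buf[m * K + k] += buf[m * K + (k + indOffset)]; // '+' stands in for OpReduce
        // a real block would __syncthreads() here
    }

    for(int m = 0; m < M; ++m)
        assert(buf[m * K + 0] == K);
    return 0;
}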
template <typename Buffer1dDescType,
typename AccDataType,
template <typename AccDataType,
typename IndexDataType,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
bool ReorderThreadClusters,
typename ThreadClusterLengths_M_K,
typename ThreadClusterArrangeOrder,
typename OpReduce,
bool PropagateNan>
struct PartitionedBlockwiseReductionWithIndexOn1dBuffer
struct PartitionedBlockwiseReductionWithIndex
{
static constexpr auto buffer_1d_desc = Buffer1dDescType{};

static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
"The product of cluster lengths should be same as BlockSize!");
static_assert(KThreadClusterSize > 1, "Parallel reduction need work on at least two elements");

static_assert(buffer_1d_desc.GetElementSize() == BlockSize,
"The buffer size should be the same as BlockSize!");
static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);

static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements");

static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));

static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

using Accumulation =
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
@@ -130,32 +126,24 @@ struct PartitionedBlockwiseReductionWithIndexOn1dBuffer
__device__ static void Reduce(BufferType& block_val_buffer,
IdxBufferType& block_idx_buffer,
AccDataType& accuData,
IndexDataType& accuIndex,
index_t thread_m_cluster_id,
index_t thread_k_cluster_id)
IndexDataType& accuIndex)
{
constexpr auto cluster_len_shift = get_shift<KThreadClusterSize>();
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << I();

if(thread_k_cluster_id % (indOffset * 2) == 0)
{
// consider the thread clusters order, ensure the contiguous locations are accessed
// by contiguous Thread-ID
index_t offset1 =
ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(make_tuple(
thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id));
index_t offset2 = ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
(thread_k_cluster_id + indOffset) * MThreadClusterSize +
thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize +
(thread_k_cluster_id + indOffset)));
index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset));

AccDataType opData1 = type_convert<AccDataType>(block_val_buffer[offset1]);
AccDataType opData2 = type_convert<AccDataType>(block_val_buffer[offset2]);
@@ -170,10 +158,7 @@ struct PartitionedBlockwiseReductionWithIndexOn1dBuffer
__syncthreads();
});

index_t offset = ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize));
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

accuData = type_convert<AccDataType>(block_val_buffer[offset]);
accuIndex = block_idx_buffer[offset];
@@ -36,14 +36,15 @@ struct DeviceReduce : public BaseOperator
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation& inElementwiseOp,
const AccElementwiseOperation& accElementwiseOp) = 0;
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) = 0;

virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
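A hedged usage outline of this interface after the change: reduceDims is now a runtime std::vector<int> argument of MakeArgumentPointer() rather than a compile-time dimension list. The helper below only exercises the two virtual factory calls shown in this hunk; the invoker's launch call is outside this diff and is deliberately omitted.

#include <memory>
#include <vector>

// 'device_reduce' stands for any DeviceReduce-derived instance; the argument
// list mirrors MakeArgumentPointer() as it appears in this hunk.
template <typename DeviceReduceOp, typename InElemOp, typename AccElemOp>
void run_reduce(DeviceReduceOp& device_reduce,
                const std::vector<int>& inLengths,
                const std::vector<int>& inStrides,
                const std::vector<int>& outLengths,
                const std::vector<int>& outStrides,
                const std::vector<int>& reduceDims,
                float alpha, float beta,
                const void* in_dev, void* out_dev,
                void* out_indices_dev, void* workspace_dev,
                const InElemOp& in_op, const AccElemOp& acc_op)
{
    auto argument_ptr = device_reduce.MakeArgumentPointer(
        inLengths, inStrides, outLengths, outStrides, reduceDims,
        alpha, beta, in_dev, out_dev, out_indices_dev, workspace_dev,
        in_op, acc_op);

    auto invoker_ptr = device_reduce.MakeInvokerPointer();

    // Launching through the invoker is omitted; its Run() signature is not part of this diff.
    (void)argument_ptr;
    (void)invoker_ptr;
}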
@@ -15,8 +15,8 @@ namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
@@ -40,7 +40,12 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl

static constexpr bool BetaIsZero = NeedIndices;

using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
@@ -74,7 +79,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
}
else
{
const auto toReduceDimLengths =
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
@@ -82,7 +87,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
@@ -136,6 +141,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
@@ -144,30 +150,31 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
(void)workspace_dev;

inLengths_ = inLengths;
inStrides_ = inStrides;
outLengths_ = outLengths;
outStrides_ = outStrides;

in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
std::tie(inLengths_, inStrides_) =
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);

alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta);

std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths);
get_2d_lengths<Rank, ReduceDims>(inLengths_);

if constexpr(InvariantDims::Size() == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];

reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];

gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
@@ -305,6 +312,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const void* in_dev,
@@ -318,6 +326,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
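The launch-size arithmetic used in the Argument constructor above reduces to a ceiling division of the merged invariant length by the M block tile size (math::integer_least_multiple(x, t) / t is a ceiling division). A minimal standalone check with illustrative numbers:

#include <cassert>

int main()
{
    constexpr int MThreadClusterSize = 8;
    constexpr int MThreadSliceSize   = 4;
    constexpr int M_BlockTileSize    = MThreadClusterSize * MThreadSliceSize; // 32

    const long invariant_total_length = 1000; // product of the invariant dim lengths

    const long gridSize =
        (invariant_total_length + M_BlockTileSize - 1) / M_BlockTileSize; // ceil(1000 / 32)

    assert(gridSize == 32); // 31 full tiles plus one partial tile
    return 0;
}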
@@ -15,8 +15,8 @@ namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
@@ -45,7 +45,11 @@ struct DeviceReduceBlockWiseSecondCall
std::is_same<InDataType, AccDataType>::value,
"InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!");

using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;

static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();

@@ -117,16 +121,16 @@ struct DeviceReduceBlockWiseSecondCall
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
: inLengths_(inLengths),
inStrides_(inStrides),
outLengths_(outLengths),
outStrides_(outStrides),
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_(in_elementwise_op),
acc_elementwise_op_(acc_elementwise_op)
{
inLengths_ = inLengths;
inStrides_ = inStrides;
outLengths_ = outLengths;
outStrides_ = outStrides;

in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;

alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta);

@@ -268,6 +272,7 @@ struct DeviceReduceBlockWiseSecondCall
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const void* in_dev,
@@ -277,6 +282,8 @@ struct DeviceReduceBlockWiseSecondCall
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op) override
{
(void)reduceDims;

return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
@@ -2,6 +2,7 @@
#define DEVICE_REDUCE_COMMON_HPP

#include <vector>
#include <cassert>

#include "common_header.hpp"
#include "reduction_enums.hpp"
@@ -40,23 +41,6 @@ constexpr bool belong()
return (inside);
};

template <int Rank, typename ReduceDims, int start = 0>
constexpr auto get_invariant_dims()
{
static_assert(Rank <= 6, "bigger Rank size not supported!");

if constexpr(start >= Rank)
return Sequence<>{};
else
{
if constexpr(!belong<start, ReduceDims>())
return merge_sequences(Sequence<start>{},
get_invariant_dims<Rank, ReduceDims, start + 1>());
else
return get_invariant_dims<Rank, ReduceDims, start + 1>();
};
};

// helper functions using variadic template arguments
template <index_t... Ns>
static auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
@@ -74,6 +58,45 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS
return make_tuple_from_array_and_index_seq(lengths, index_seq);
};

template <index_t Rank, index_t NumReduceDim>
static inline std::pair<std::vector<int>, std::vector<int>>
shuffle_tensor_dimensions(const std::vector<int>& dimLengths,
const std::vector<int>& dimStrides,
const std::vector<int>& reduceDims)
{
std::vector<int> newDimLengths;
std::vector<int> newDimStrides;

assert(Rank == dimLengths.size() && Rank == dimStrides.size() &&
NumReduceDim == reduceDims.size());

int reduceFlag = 0;

// flag the bits for the reduceDims
for(int i = 0; i < NumReduceDim; i++)
{
reduceFlag |= 1 << reduceDims[i];
};

// collect invariant dimensions
for(int i = 0; i < Rank; i++)
if((reduceFlag & (1 << i)) == 0)
{
newDimLengths.push_back(dimLengths[i]);
newDimStrides.push_back(dimStrides[i]);
};

// collect reduce dimensions
for(int i = 0; i < Rank; i++)
if((reduceFlag & (1 << i)) > 0)
{
newDimLengths.push_back(dimLengths[i]);
newDimStrides.push_back(dimStrides[i]);
};

return std::make_pair(newDimLengths, newDimStrides);
};

} // namespace device
} // namespace tensor_operation
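A worked example of shuffle_tensor_dimensions<Rank, NumReduceDim>() as defined above: invariant dimensions move to the front and reduced dimensions to the back, with lengths and strides travelling together. The lambda restates the same bit-flag algorithm so the example runs without the CK headers; the tensor shape is illustrative.

#include <cassert>
#include <utility>
#include <vector>

int main()
{
    const std::vector<int> lengths{2, 8, 4, 4};    // e.g. N, C, H, W
    const std::vector<int> strides{128, 16, 4, 1}; // packed NCHW strides
    const std::vector<int> reduceDims{0, 2, 3};    // reduce N, H, W; keep C

    auto shuffle = [](const std::vector<int>& len,
                      const std::vector<int>& str,
                      const std::vector<int>& rdims) {
        int flag = 0;
        for(int d : rdims)
            flag |= 1 << d;
        std::vector<int> newLen, newStr;
        for(size_t i = 0; i < len.size(); ++i)
            if((flag & (1 << i)) == 0) { newLen.push_back(len[i]); newStr.push_back(str[i]); }
        for(size_t i = 0; i < len.size(); ++i)
            if((flag & (1 << i)) != 0) { newLen.push_back(len[i]); newStr.push_back(str[i]); }
        return std::make_pair(newLen, newStr);
    };

    auto result = shuffle(lengths, strides, reduceDims);
    // Invariant dim C leads; the reduced dims N, H, W follow in order.
    assert((result.first  == std::vector<int>{8, 2, 4, 4}));
    assert((result.second == std::vector<int>{16, 128, 4, 1}));
    return 0;
}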
@@ -17,8 +17,8 @@ namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
@@ -41,7 +41,12 @@ struct DeviceReduceMultiBlockAtomicAdd

using IndexDataType = int32_t;

using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
@@ -84,7 +89,7 @@ struct DeviceReduceMultiBlockAtomicAdd
}
else
{
const auto toReduceDimLengths =
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
@@ -92,7 +97,7 @@ struct DeviceReduceMultiBlockAtomicAdd
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
@@ -147,6 +152,7 @@ struct DeviceReduceMultiBlockAtomicAdd
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
@@ -155,31 +161,31 @@ struct DeviceReduceMultiBlockAtomicAdd
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, out_dev_{out_dev}
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
(void)out_indices_dev;
(void)workspace_dev;

inLengths_ = inLengths;
inStrides_ = inStrides;
outLengths_ = outLengths;
outStrides_ = outStrides;

in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
std::tie(inLengths_, inStrides_) =
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);

alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta);

std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths);
get_2d_lengths<Rank, ReduceDims>(inLengths_);

if constexpr(InvariantDims::Size() == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];

reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];

int iterations = 1;
while(true)
@@ -369,6 +375,7 @@ struct DeviceReduceMultiBlockAtomicAdd
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const void* in_dev,
@@ -382,6 +389,7 @@ struct DeviceReduceMultiBlockAtomicAdd
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
@@ -15,8 +15,8 @@ namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
typename ReduceDims,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
@@ -41,7 +41,12 @@ struct DeviceReduceMultiBlockPartialReduce

using IndexDataType = int32_t;

using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
@@ -112,7 +117,7 @@ struct DeviceReduceMultiBlockPartialReduce
}
else
{
const auto toReduceDimLengths =
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
@@ -120,7 +125,7 @@ struct DeviceReduceMultiBlockPartialReduce
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
@@ -161,10 +166,11 @@ struct DeviceReduceMultiBlockPartialReduce

struct Argument : public BaseArgument
{
Argument(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides,
const std::vector<index_t>& outLengths,
const std::vector<index_t>& outStrides,
Argument(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
@@ -173,31 +179,30 @@ struct DeviceReduceMultiBlockPartialReduce
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev},
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
workspace_dev_{workspace_dev}
workspace_dev_{workspace_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
inLengths_ = inLengths;
inStrides_ = inStrides;
outLengths_ = outLengths;
outStrides_ = outStrides;

in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
std::tie(inLengths_, inStrides_) =
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);

alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta);

std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths);
get_2d_lengths<Rank, ReduceDims>(inLengths_);

if constexpr(InvariantDims::Size() == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];

reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];

int iterations = 1;
while(true)
@@ -370,6 +375,7 @@ struct DeviceReduceMultiBlockPartialReduce
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const void* in_dev,
@@ -383,6 +389,7 @@ struct DeviceReduceMultiBlockPartialReduce
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
@@ -16,7 +16,7 @@ template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
typename ReduceDims,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation,
@@ -40,7 +40,12 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE

static constexpr bool BetaIsZero = NeedIndices;

using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using InvariantDims =
typename conditional<NumInvariantDim == 0,
Sequence<>,
typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

static constexpr index_t srcDims = Rank;
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
@@ -74,7 +79,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
}
else
{
const auto toReduceDimLengths =
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
@@ -82,7 +87,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(toReduceDimLengths)),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
@@ -136,6 +141,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
@@ -144,30 +150,32 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op)
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}

{
(void)workspace_dev;

inLengths_ = inLengths;
inStrides_ = inStrides;
outLengths_ = outLengths;
outStrides_ = outStrides;

in_elementwise_op_ = in_elementwise_op;
acc_elementwise_op_ = acc_elementwise_op;
std::tie(inLengths_, inStrides_) =
shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, inStrides, reduceDims);

alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<OutDataType>(beta);

std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, ReduceDims>(inLengths);
get_2d_lengths<Rank, ReduceDims>(inLengths_);

if constexpr(InvariantDims::Size() == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
invariant_lowest_length = inLengths_[InvariantDims::At(InvariantDims::Size() - 1)];

reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
reduce_lowest_length = inLengths_[ReduceDims::At(ReduceDims::Size() - 1)];

gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
@@ -306,6 +314,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
const std::vector<int>& reduceDims,
float alpha,
float beta,
const void* in_dev,
@@ -319,6 +328,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
@@ -31,8 +31,8 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"

#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"

namespace ck {

@@ -158,13 +158,27 @@ struct GridwiseReduction_mk_to_m_blockwise
{
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

static constexpr auto buffer_1d_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
// Dim_K as the fastest one
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));

template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;

static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};

static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
@@ -180,14 +194,12 @@ struct GridwiseReduction_mk_to_m_blockwise
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc),
AccDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ReduceOperation,
PropagateNan>;
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

@@ -221,31 +233,31 @@ struct GridwiseReduction_mk_to_m_blockwise

const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;

const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];

using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));

constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

@@ -283,21 +295,14 @@ struct GridwiseReduction_mk_to_m_blockwise
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(reorder_thread_cluster)
{
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
accu_value_buf[I];
}
else
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
accu_value_buf[I];
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
accu_value_buf[I];

accu_value_buf(I) = zeroVal;

__syncthreads();

BlockwiseReduce::Reduce(
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
});

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -380,15 +385,13 @@ struct GridwiseReduction_mk_to_m_blockwise
IndexDataType* const __restrict__ p_indices_global)
{
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc),
AccDataType,
IndexDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ReduceOperation,
PropagateNan>;
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;

using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
@@ -432,31 +435,31 @@ struct GridwiseReduction_mk_to_m_blockwise

const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;

const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];

using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));

index_t indexOffset = 0;

@@ -503,29 +506,15 @@ struct GridwiseReduction_mk_to_m_blockwise
});

// store thread local value to LDS for parallel reduction
if constexpr(reorder_thread_cluster)
{
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpIndex;
}
else
{
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpIndex;
}
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpValue;
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpIndex;

__syncthreads();

BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
block_reduce_idx_buf,
tmpValue,
tmpIndex,
thread_m_cluster_id,
thread_k_cluster_id);
BlockwiseReduceWithIndex::Reduce(
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);

AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
@@ -648,15 +637,13 @@ struct GridwiseReduction_mk_to_m_blockwise
IndexDataType* const __restrict__ p_indices_global)
{
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc),
AccDataType,
IndexDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ReduceOperation,
PropagateNan>;
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
Sequence<MThreadClusterSize, KThreadClusterSize>,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;

using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
@@ -707,46 +694,48 @@ struct GridwiseReduction_mk_to_m_blockwise

const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;

const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];

using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_val_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));

auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2<
IndexDataType,
IndexDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_idx_load =
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
IndexDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));

// index_t indexOffset = 0;

@@ -787,29 +776,15 @@ struct GridwiseReduction_mk_to_m_blockwise
});

// store thread local value to LDS for parallel reduction
if constexpr(reorder_thread_cluster)
{
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpIndex;
}
else
{
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpIndex;
}
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpValue;
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpIndex;

__syncthreads();

BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
block_reduce_idx_buf,
tmpValue,
tmpIndex,
thread_m_cluster_id,
thread_k_cluster_id);
BlockwiseReduceWithIndex::Reduce(
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);

AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
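The *WithIndex paths above keep a (value, index) pair through every LDS merge so that the final result also reports where the winning value came from (argmax/argmin-style reductions). The snippet below is a behavioral model of that pairwise merge only; it mirrors the role of AccumulateWithIndexAndNanCheck in spirit and is not the CK implementation.

#include <cassert>
#include <vector>

struct ValueIndex
{
    float value;
    int index;
};

// Keep the pair with the larger value; on a tie, keep the current (left) pair.
static ValueIndex merge_max(ValueIndex a, ValueIndex b)
{
    return (b.value > a.value) ? b : a;
}

int main()
{
    std::vector<ValueIndex> buf{{1.0f, 0}, {7.0f, 1}, {3.0f, 2}, {5.0f, 3}};

    // Pairwise tree merge over a length-4 buffer (two halving steps).
    for(int indOffset = 2; indOffset >= 1; indOffset /= 2)
        for(int k = 0; k < indOffset; ++k)
            buf[k] = merge_max(buf[k], buf[k + indOffset]);

    assert(buf[0].value == 7.0f && buf[0].index == 1); // max value and its source position
    return 0;
}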
@@ -86,22 +86,34 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

static constexpr auto buffer_1d_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

using blockwise_reduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc),
AccDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ReduceOperation,
PropagateNan>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
// Dim_K as the fastest one
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));

using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;

template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;

static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};

static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
@@ -145,12 +157,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;

const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];

const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

@@ -158,17 +170,16 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
@@ -212,21 +223,14 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
// consistent reduced result for that invariant dimension. due to the using of vector_load,
// each block/thread is involved into multiple invarirant dimensions.
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(reorder_thread_cluster)
{
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
accu_value_buf[I];
}
else
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
accu_value_buf[I];
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
accu_value_buf[I];

accu_value_buf(I) = zeroVal;

__syncthreads();

blockwise_reduce::Reduce(
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
});

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -30,8 +30,8 @@
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
#include "reduction_functions_blockwise.hpp"
|
||||
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "cluster_descriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -103,13 +103,27 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
|
||||
{
|
||||
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
|
||||
|
||||
static constexpr auto buffer1dDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
|
||||
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
|
||||
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
using ThreadClusterArrangeOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
|
||||
// Dim_K as the fastest one
|
||||
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
|
||||
|
||||
template <typename T>
|
||||
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
@@ -124,14 +138,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType* const __restrict__ p_ws_values_global,
IndexDataType* const __restrict__ p_ws_indices_global)
{
using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer1dDesc),
AccDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ReduceOperation,
PropagateNan>;
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;

using Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
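The NaN handling itself is untouched by the refactoring; as a behavioral sketch only (this is not the ck detail::AccumulateWithNanCheck implementation, and the helper below is hypothetical), a PropagateNan-aware accumulate step looks roughly like this:

#include <cmath>
#include <limits>

// Illustrative only: fold one value into 'acc', optionally letting NaN poison the result.
template <bool PropagateNan, typename ReduceOp, typename T>
void accumulate(T& acc, T v, ReduceOp op)
{
    if constexpr(PropagateNan)
    {
        // Once a NaN appears on either side, the accumulator stays NaN.
        if(std::isnan(acc) || std::isnan(v))
        {
            acc = std::numeric_limits<T>::quiet_NaN();
            return;
        }
    }
    acc = op(acc, v);
}

int main()
{
    auto max_op = [](float a, float b) { return a > b ? a : b; };
    float acc   = -1.0e30f;
    for(float v : {1.0f, 5.0f, 3.0f})
        accumulate<true>(acc, v, max_op);
    return acc == 5.0f ? 0 : 1; // exits 0 when the reduction produced 5.0f
}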
@@ -168,12 +180,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;

const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];

const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

@@ -181,17 +193,16 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
@@ -233,21 +244,14 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
// Each block executes multiple parallel reductions on the LDS, and due to the use of
// vector_load, each block/thread is involved in multiple invariant dimensions.
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(reorder_thread_cluster)
{
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
accu_value_buf[I];
}
else
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
accu_value_buf[I];
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
accu_value_buf[I];

accu_value_buf(I) = zeroVal;

__syncthreads();

BlockwiseReduce::Reduce(
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
});

if(thread_k_cluster_id == 0)
@@ -290,15 +294,13 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
IndexDataType* const __restrict__ p_ws_indices_global)
{
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer1dDesc),
AccDataType,
IndexDataType,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
reorder_thread_cluster,
ReduceOperation,
PropagateNan>;
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;

using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
@@ -346,12 +348,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const index_t thread_m_cluster_id =
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
const index_t thread_k_cluster_id =
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
: thread_local_id % KThreadClusterSize;

const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];

const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

@@ -359,17 +361,16 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
@@ -418,29 +419,15 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
});

// store thread-local value to LDS for parallel reduction
if constexpr(reorder_thread_cluster)
{
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
thread_m_cluster_id) = tmpIndex;
}
else
{
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpValue;
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
thread_k_cluster_id) = tmpIndex;
}
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpValue;
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpIndex;

__syncthreads();

BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
block_reduce_idx_buf,
tmpValue,
tmpIndex,
thread_m_cluster_id,
thread_k_cluster_id);
BlockwiseReduceWithIndex::Reduce(
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);

AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);

@@ -101,6 +101,9 @@ template <typename InDataType,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_threadwise
{
using ThreadBufferDimAccessOrder =
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;

template <typename T>
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;

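The newly added ThreadBufferDimAccessOrder alias is plain compile-time dispatch on InSrcVectorDim; below is a standalone sketch of the same pattern using std::conditional, with illustrative tag types and assuming the usual convention that the last entry of the access order is the fastest-varying dimension:

#include <cstdio>
#include <type_traits>

// Tag types standing in for ck::Sequence<1, 0> and ck::Sequence<0, 1>.
struct Order10 { static constexpr const char* name = "Sequence<1, 0>"; };
struct Order01 { static constexpr const char* name = "Sequence<0, 1>"; };

// When the vectorized source dimension is dim 0, dim 0 is made the fastest-accessed
// dimension of the thread buffer; otherwise dim 1 stays fastest.
template <int InSrcVectorDim>
using ThreadBufferDimAccessOrder =
    typename std::conditional<InSrcVectorDim == 0, Order10, Order01>::type;

int main()
{
    std::printf("%s\n", ThreadBufferDimAccessOrder<0>::name); // Sequence<1, 0>
    std::printf("%s\n", ThreadBufferDimAccessOrder<1>::name); // Sequence<0, 1>
    return 0;
}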
@@ -147,17 +150,17 @@ struct GridwiseReduction_mk_to_m_threadwise

index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

@@ -299,17 +302,17 @@ struct GridwiseReduction_mk_to_m_threadwise

index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);