Improve Reduction kernel api (#152)

* Add ThreadwiseReduction functor as per-thread reduction api

* Using ThreadwiseReduce api and some change in using PartitionedBlockwiseReduction api to simply the kernels

* Add comments and remove useless declarations in the kernels

* Tiny updates
This commit is contained in:
Qianfeng
2022-04-05 09:31:44 +08:00
committed by GitHub
parent 646878162b
commit 82c8b9f8ee
6 changed files with 348 additions and 229 deletions

View File

@@ -31,6 +31,7 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
@@ -179,10 +180,10 @@ struct GridwiseReduction_mk_to_m_blockwise
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
// Dim_K as the fastest one
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
@@ -216,14 +217,18 @@ struct GridwiseReduction_mk_to_m_blockwise
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_ws_indices_global;
(void)p_indices_global;
// LDS
__shared__ AccDataType p_block_reduce_buffer[BlockSize];
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
@@ -232,8 +237,8 @@ struct GridwiseReduction_mk_to_m_blockwise
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto block_reduce_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize);
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
@@ -285,38 +290,26 @@ struct GridwiseReduction_mk_to_m_blockwise
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
});
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
accu_value_buf[I];
accu_value_buf(I) = zeroVal;
__syncthreads();
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
});
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
@@ -414,8 +407,8 @@ struct GridwiseReduction_mk_to_m_blockwise
(void)p_ws_indices_global;
// LDS
__shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
__shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize];
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
@@ -426,15 +419,18 @@ struct GridwiseReduction_mk_to_m_blockwise
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
auto block_reduce_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize);
auto block_reduce_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize);
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, index_t, MThreadSliceSize * KThreadSliceSize, true>
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
@@ -491,42 +487,36 @@ struct GridwiseReduction_mk_to_m_blockwise
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread to-reduce values
in_thread_idx_buf(offset) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + J();
in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset));
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// reduce on the dim1 thread slice
AccumulationWithIndex::Calculate(
tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
// store thread local value to LDS for parallel reduction
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpValue;
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpIndex;
__syncthreads();
BlockwiseReduceWithIndex::Reduce(
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
@@ -535,8 +525,7 @@ struct GridwiseReduction_mk_to_m_blockwise
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
@@ -665,8 +654,8 @@ struct GridwiseReduction_mk_to_m_blockwise
(void)in_elementwise_op;
// LDS
__shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
__shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize];
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
@@ -681,10 +670,10 @@ struct GridwiseReduction_mk_to_m_blockwise
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
auto block_reduce_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize);
auto block_reduce_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize);
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
@@ -745,8 +734,6 @@ struct GridwiseReduction_mk_to_m_blockwise
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
// index_t indexOffset = 0;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0;
@@ -771,42 +758,33 @@ struct GridwiseReduction_mk_to_m_blockwise
make_tuple(I0, I0),
in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// reduce on the dim1 thread slice
AccumulationWithIndex::Calculate(
tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
// store thread local value to LDS for parallel reduction
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpValue;
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpIndex;
__syncthreads();
BlockwiseReduceWithIndex::Reduce(
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
// indexOffset += K_BlockTileSize;
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)

View File

@@ -30,6 +30,7 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
@@ -103,10 +104,10 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
// Dim_K as the fastest one
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
@@ -115,6 +116,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
@@ -138,15 +145,15 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS
__shared__ AccDataType p_block_reduce_buffer[BlockSize];
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto block_reduce_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize);
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
@@ -198,42 +205,30 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
});
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
constexpr auto reduced_data_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
// Each block executes multiple parallel reductions on the LDS, and by atomic-adding its
// reduced output to the global location corresponding to each invariant dimension to get a
// consistent reduced result for that invariant dimension. due to the using of vector_load,
// each block/thread is involved into multiple invarirant dimensions.
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
accu_value_buf[I];
accu_value_buf(I) = zeroVal;
__syncthreads();
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
});
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)

View File

@@ -30,6 +30,7 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
@@ -121,10 +122,10 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
// Dim_K as the fastest one
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
@@ -151,8 +152,11 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
ReduceOperation,
PropagateNan>;
using Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_ws_indices_global;
(void)acc_elementwise_op;
@@ -160,7 +164,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS
__shared__ AccDataType p_block_reduce_buffer[BlockSize];
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
@@ -169,8 +173,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
auto block_reduce_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize);
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
@@ -222,20 +226,17 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
});
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
@@ -243,16 +244,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
// Each block executes multiple parallel reductions on the LDS, and due to the using of
// vector_load, each block/thread is involved into multiple invarirant dimensions.
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
accu_value_buf[I];
accu_value_buf(I) = zeroVal;
__syncthreads();
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
});
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
@@ -315,8 +308,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS
__shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
__shared__ index_t p_block_reduce_idx_buffer[BlockSize];
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ index_t p_reduce_work_idx_buffer[BlockSize];
const auto in_global_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
@@ -327,10 +320,10 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize());
auto block_reduce_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize);
auto block_reduce_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize);
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
@@ -394,42 +387,36 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread to-reduce values
in_thread_idx_buf(offset) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + J();
in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset));
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// reduce on the dim1 thread slice
AccumulationWithIndex::Calculate(
tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
// store thread local value to LDS for parallel reduction
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpValue;
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpIndex;
__syncthreads();
BlockwiseReduceWithIndex::Reduce(
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

View File

@@ -30,6 +30,7 @@
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
@@ -110,6 +111,11 @@ struct GridwiseReduction_mk_to_m_threadwise
using ThreadBufferDimAccessOrder =
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
@@ -124,9 +130,11 @@ struct GridwiseReduction_mk_to_m_threadwise
OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global)
{
using Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_indices_global;
@@ -175,20 +183,17 @@ struct GridwiseReduction_mk_to_m_threadwise
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
});
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedLength += KThreadSliceSize;
@@ -200,8 +205,7 @@ struct GridwiseReduction_mk_to_m_threadwise
accu_value_buf(I) *= alpha;
});
constexpr auto reduced_data_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
if constexpr(!BetaIsZero)
{
@@ -266,10 +270,13 @@ struct GridwiseReduction_mk_to_m_threadwise
OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global)
{
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
IndexDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)acc_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
@@ -282,7 +289,13 @@ struct GridwiseReduction_mk_to_m_threadwise
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
@@ -322,26 +335,23 @@ struct GridwiseReduction_mk_to_m_threadwise
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
});
in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
AccumulationWithIndex::Calculate(accu_value_buf(I),
in_thread_buf[offset],
accu_index_buf(I),
indexStart + J);
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
});
ThreadwiseReduceWithIndex::Reduce(
in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
indexStart += KThreadSliceSize;
@@ -355,8 +365,7 @@ struct GridwiseReduction_mk_to_m_threadwise
accu_value_buf(I) *= alpha;
});
constexpr auto reduced_data_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
if constexpr(!BetaIsZero)
{