Mirror of https://github.com/ROCm/composable_kernel.git — synced 2026-05-03 13:11:25 +00:00
Improve Reduction kernel api (#152)

* Add ThreadwiseReduction functor as a per-thread reduction api
* Use the ThreadwiseReduce api, and change how the PartitionedBlockwiseReduction api is used, to simplify the kernels
* Add comments and remove useless declarations in the kernels
* Tiny updates
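For readers skimming the diff: below is a minimal, standalone sketch (not the library's actual implementation) of the kind of per-thread reduction the new ThreadwiseReduction functor encapsulates. The real functor lives in reduction_functions_threadwise.hpp, operates on CK StaticBuffer objects described by ThreadReduceSrcDesc_M_K / ThreadReduceDstDesc_M, and also takes a PropagateNan flag; the names MaxOp and ThreadwiseReductionSketch here are illustrative assumptions only.

#include <array>
#include <cstddef>
#include <cstdio>

// Example reduce operation in the style of CK's ReduceOperation policies:
// it folds a new value into an accumulator in place.
struct MaxOp
{
    template <typename T>
    void operator()(T& acc, T v) const
    {
        if(v > acc)
            acc = v;
    }
};

// Hypothetical analogue of ThreadwiseReduction: fold the K dimension of a
// thread-local M x K slice into an M-sized accumulator, so a kernel makes one
// Reduce() call instead of hand-writing nested static_for loops.
template <typename AccDataType,
          std::size_t MThreadSliceSize,
          std::size_t KThreadSliceSize,
          typename ReduceOperation>
struct ThreadwiseReductionSketch
{
    static void Reduce(const std::array<AccDataType, MThreadSliceSize * KThreadSliceSize>& src,
                       std::array<AccDataType, MThreadSliceSize>& acc)
    {
        for(std::size_t iM = 0; iM < MThreadSliceSize; ++iM)
            for(std::size_t iK = 0; iK < KThreadSliceSize; ++iK)
                ReduceOperation{}(acc[iM], src[iM * KThreadSliceSize + iK]);
    }
};

int main()
{
    constexpr std::size_t M = 2;
    constexpr std::size_t K = 4;

    std::array<float, M * K> slice{{3, 1, 4, 1, 5, 9, 2, 6}};
    std::array<float, M> acc{{-1e30f, -1e30f}}; // "zeroVal" of a max-reduction

    ThreadwiseReductionSketch<float, M, K, MaxOp>::Reduce(slice, acc);
    std::printf("%f %f\n", acc[0], acc[1]); // prints 4.000000 9.000000
}

The effect of the change is visible throughout the hunks below: in each kernel, the hand-written pair of static_for loops over the thread slice collapses into a single ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf) call.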
@@ -31,6 +31,7 @@
 #include "reduction_operator.hpp"
 #include "reduction_functions_accumulate.hpp"
 #include "reduction_functions_blockwise.hpp"
+#include "reduction_functions_threadwise.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "cluster_descriptor.hpp"
 #include "element_wise_operation.hpp"
@@ -179,10 +180,10 @@ struct GridwiseReduction_mk_to_m_blockwise
     static constexpr auto thread_cluster_desc =
         make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
 
-    // For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
-    // Dim_K as the fastest one
-    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
+    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+    using ThreadReduceDstDesc_M =
+        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
 
     using PassThroughOp = tensor_operation::element_wise::PassThrough;
 
@@ -216,14 +217,18 @@ struct GridwiseReduction_mk_to_m_blockwise
                                           ThreadClusterArrangeOrder,
                                           ReduceOperation,
                                           PropagateNan>;
-        using Accumulation =
-            detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
+
+        using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
+                                                     ThreadReduceSrcDesc_M_K,
+                                                     ThreadReduceDstDesc_M,
+                                                     ReduceOperation,
+                                                     PropagateNan>;
 
         (void)p_ws_indices_global;
         (void)p_indices_global;
 
         // LDS
-        __shared__ AccDataType p_block_reduce_buffer[BlockSize];
+        __shared__ AccDataType p_reduce_work_buffer[BlockSize];
 
         const auto zeroVal = ReduceOperation::GetReductionZeroVal();
 
@@ -232,8 +237,8 @@ struct GridwiseReduction_mk_to_m_blockwise
         auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_global, out_grid_desc_m.GetElementSpaceSize());
 
-        auto block_reduce_buf =
-            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize);
+        auto reduce_work_buf =
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
 
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_buf;
@@ -285,38 +290,26 @@ struct GridwiseReduction_mk_to_m_blockwise
                                    make_tuple(I0, I0),
                                    in_thread_buf);
 
-            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                 // do element-wise pre-reduction operation
-                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                    in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
-                });
-
-                // reduce on each thread-local slice
-                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                    Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
+                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                    in_elementwise_op(in_thread_buf(Number<offset>{}),
+                                      in_thread_buf(Number<offset>{}));
                 });
             });
 
+            ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
+
             threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
 
             reducedTiles++;
         } while(reducedTiles < toReduceTiles);
 
-        constexpr auto reduced_data_desc =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
 
-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-                accu_value_buf[I];
-
-            accu_value_buf(I) = zeroVal;
-
-            __syncthreads();
-
-            BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
-        });
+        static_for<0, MThreadSliceSize, 1>{}(
+            [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
 
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
             if(thread_k_cluster_id == 0)
@@ -414,8 +407,8 @@ struct GridwiseReduction_mk_to_m_blockwise
         (void)p_ws_indices_global;
 
         // LDS
-        __shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
-        __shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize];
+        __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
+        __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
 
         const auto zeroVal = ReduceOperation::GetReductionZeroVal();
 
@@ -426,15 +419,18 @@ struct GridwiseReduction_mk_to_m_blockwise
         auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_indices_global, out_grid_desc_m.GetElementSpaceSize());
 
-        auto block_reduce_val_buf =
-            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize);
-        auto block_reduce_idx_buf =
-            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize);
+        auto reduce_work_val_buf =
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
+        auto reduce_work_idx_buf =
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
 
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_val_buf;
 
-        StaticBuffer<AddressSpaceEnum::Vgpr, index_t, MThreadSliceSize * KThreadSliceSize, true>
+        StaticBuffer<AddressSpaceEnum::Vgpr,
+                     IndexDataType,
+                     MThreadSliceSize * KThreadSliceSize,
+                     true>
             in_thread_idx_buf;
 
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
@@ -491,42 +487,36 @@ struct GridwiseReduction_mk_to_m_blockwise
                                    make_tuple(I0, I0),
                                    in_thread_val_buf);
 
-            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
+            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
 
                     // initialize the indices for the per-thread to-reduce values
-                    in_thread_idx_buf(offset) =
-                        indexOffset + thread_k_cluster_id * KThreadSliceSize + J();
+                    in_thread_idx_buf(Number<offset>{}) =
+                        indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
 
                     // do element-wise pre-reduction operation
-                    in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset));
+                    in_elementwise_op(in_thread_val_buf(Number<offset>{}),
+                                      in_thread_val_buf(Number<offset>{}));
                 });
 
                 AccDataType tmpValue = zeroVal;
                 IndexDataType tmpIndex = 0;
 
-                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
+                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
 
                     // reduce on the dim1 thread slice
-                    AccumulationWithIndex::Calculate(
-                        tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
+                    AccumulationWithIndex::Calculate(tmpValue,
+                                                     in_thread_val_buf[Number<offset>{}],
+                                                     tmpIndex,
+                                                     in_thread_idx_buf[Number<offset>{}]);
                 });
 
-                // store thread local value to LDS for parallel reduction
-                block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-                    tmpValue;
-                block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-                    tmpIndex;
-
-                __syncthreads();
-
                 BlockwiseReduceWithIndex::Reduce(
-                    block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
+                    reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
 
                 AccumulationWithIndex::Calculate(
-                    accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
+                    accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
             });
 
             threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
@@ -535,8 +525,7 @@ struct GridwiseReduction_mk_to_m_blockwise
             reducedTiles++;
         } while(reducedTiles < toReduceTiles);
 
-        constexpr auto reduced_data_desc =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
 
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
@@ -665,8 +654,8 @@ struct GridwiseReduction_mk_to_m_blockwise
         (void)in_elementwise_op;
 
         // LDS
-        __shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
-        __shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize];
+        __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
+        __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
 
         const auto zeroVal = ReduceOperation::GetReductionZeroVal();
 
@@ -681,10 +670,10 @@ struct GridwiseReduction_mk_to_m_blockwise
         auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_indices_global, out_grid_desc_m.GetElementSpaceSize());
 
-        auto block_reduce_val_buf =
-            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize);
-        auto block_reduce_idx_buf =
-            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize);
+        auto reduce_work_val_buf =
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
+        auto reduce_work_idx_buf =
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
 
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_val_buf;
@@ -745,8 +734,6 @@ struct GridwiseReduction_mk_to_m_blockwise
                                                     thread_m_cluster_id * MThreadSliceSize,
                                                     thread_k_cluster_id * KThreadSliceSize));
 
-        // index_t indexOffset = 0;
-
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
             accu_value_buf(I) = zeroVal;
             accu_index_buf(I) = 0;
@@ -771,42 +758,33 @@ struct GridwiseReduction_mk_to_m_blockwise
                                        make_tuple(I0, I0),
                                        in_thread_idx_buf);
 
-            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                 AccDataType tmpValue = zeroVal;
                 IndexDataType tmpIndex = 0;
 
-                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
+                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
 
                     // reduce on the dim1 thread slice
-                    AccumulationWithIndex::Calculate(
-                        tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
+                    AccumulationWithIndex::Calculate(tmpValue,
+                                                     in_thread_val_buf[Number<offset>{}],
+                                                     tmpIndex,
+                                                     in_thread_idx_buf[Number<offset>{}]);
                 });
 
-                // store thread local value to LDS for parallel reduction
-                block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-                    tmpValue;
-                block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-                    tmpIndex;
-
-                __syncthreads();
-
                 BlockwiseReduceWithIndex::Reduce(
-                    block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
+                    reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
 
                 AccumulationWithIndex::Calculate(
-                    accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
+                    accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
             });
 
             threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
             threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
 
-            // indexOffset += K_BlockTileSize;
 
             reducedTiles++;
         } while(reducedTiles < toReduceTiles);
 
-        constexpr auto reduced_data_desc =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
 
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
             if(thread_k_cluster_id == 0)
@@ -30,6 +30,7 @@
 #include "reduction_operator.hpp"
 #include "reduction_functions_accumulate.hpp"
 #include "reduction_functions_blockwise.hpp"
+#include "reduction_functions_threadwise.hpp"
 
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "element_wise_operation.hpp"
@@ -103,10 +104,10 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
     static constexpr auto thread_cluster_desc =
         make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
 
-    // For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
-    // Dim_K as the fastest one
-    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
+    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+    using ThreadReduceDstDesc_M =
+        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
 
     using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
                                                           BlockSize,
@@ -115,6 +116,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
                                                           ReduceOperation,
                                                           PropagateNan>;
 
+    using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
+                                                 ThreadReduceSrcDesc_M_K,
+                                                 ThreadReduceDstDesc_M,
+                                                 ReduceOperation,
+                                                 PropagateNan>;
+
     using PassThroughOp = tensor_operation::element_wise::PassThrough;
 
     static constexpr auto I0 = Number<0>{};
@@ -138,15 +145,15 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
         const auto zeroVal = ReduceOperation::GetReductionZeroVal();
 
         // LDS
-        __shared__ AccDataType p_block_reduce_buffer[BlockSize];
+        __shared__ AccDataType p_reduce_work_buffer[BlockSize];
 
         const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
         auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_global, out_grid_desc_m.GetElementSpaceSize());
 
-        auto block_reduce_buf =
-            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize);
+        auto reduce_work_buf =
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
 
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_buf;
@@ -198,42 +205,30 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
                                    make_tuple(I0, I0),
                                    in_thread_buf);
 
-            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                 // do element-wise pre-reduction operation
-                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                    in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
-                });
-
-                // reduce on each thread-local slice
-                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                    Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
+                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                    in_elementwise_op(in_thread_buf(Number<offset>{}),
+                                      in_thread_buf(Number<offset>{}));
                 });
             });
 
+            ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
+
             threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
 
             reducedTiles++;
         } while(reducedTiles < num_k_block_tile_iteration);
 
-        constexpr auto reduced_data_desc =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
 
         // Each block executes multiple parallel reductions on the LDS, and by atomic-adding its
         // reduced output to the global location corresponding to each invariant dimension to get a
         // consistent reduced result for that invariant dimension. due to the using of vector_load,
         // each block/thread is involved into multiple invarirant dimensions.
-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-                accu_value_buf[I];
-
-            accu_value_buf(I) = zeroVal;
-
-            __syncthreads();
-
-            BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
-        });
+        static_for<0, MThreadSliceSize, 1>{}(
+            [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
 
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
             if(thread_k_cluster_id == 0)
@@ -30,6 +30,7 @@
 #include "reduction_operator.hpp"
 #include "reduction_functions_accumulate.hpp"
 #include "reduction_functions_blockwise.hpp"
+#include "reduction_functions_threadwise.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "cluster_descriptor.hpp"
 #include "element_wise_operation.hpp"
@@ -121,10 +122,10 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
     static constexpr auto thread_cluster_desc =
         make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
 
-    // For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
-    // Dim_K as the fastest one
-    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
+    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+    using ThreadReduceDstDesc_M =
+        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
 
     using PassThroughOp = tensor_operation::element_wise::PassThrough;
 
@@ -151,8 +152,11 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                                          ReduceOperation,
                                          PropagateNan>;
 
-        using Accumulation =
-            detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
+        using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
+                                                     ThreadReduceSrcDesc_M_K,
+                                                     ThreadReduceDstDesc_M,
+                                                     ReduceOperation,
+                                                     PropagateNan>;
 
         (void)p_ws_indices_global;
         (void)acc_elementwise_op;
@@ -160,7 +164,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
         const auto zeroVal = ReduceOperation::GetReductionZeroVal();
 
         // LDS
-        __shared__ AccDataType p_block_reduce_buffer[BlockSize];
+        __shared__ AccDataType p_reduce_work_buffer[BlockSize];
 
         const auto in_global_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
@@ -169,8 +173,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
         auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
 
-        auto block_reduce_buf =
-            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize);
+        auto reduce_work_buf =
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
 
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_buf;
@@ -222,20 +226,17 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                                    make_tuple(I0, I0),
                                    in_thread_buf);
 
-            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                 // do element-wise pre-reduction operation
-                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                    in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
-                });
-
-                // reduce on each thread-local slice
-                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                    Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
+                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                    in_elementwise_op(in_thread_buf(Number<offset>{}),
+                                      in_thread_buf(Number<offset>{}));
                 });
             });
 
+            ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
+
             threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
 
             reducedTiles++;
@@ -243,16 +244,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
 
         // Each block executes multiple parallel reductions on the LDS, and due to the using of
        // vector_load, each block/thread is involved into multiple invarirant dimensions.
-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-                accu_value_buf[I];
-
-            accu_value_buf(I) = zeroVal;
-
-            __syncthreads();
-
-            BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
-        });
+        static_for<0, MThreadSliceSize, 1>{}(
+            [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
 
         constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
             make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
@@ -315,8 +308,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
         const auto zeroVal = ReduceOperation::GetReductionZeroVal();
 
         // LDS
-        __shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
-        __shared__ index_t p_block_reduce_idx_buffer[BlockSize];
+        __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
+        __shared__ index_t p_reduce_work_idx_buffer[BlockSize];
 
         const auto in_global_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
@@ -327,10 +320,10 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
         auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize());
 
-        auto block_reduce_val_buf =
-            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize);
-        auto block_reduce_idx_buf =
-            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize);
+        auto reduce_work_val_buf =
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
+        auto reduce_work_idx_buf =
+            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
 
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             in_thread_val_buf;
@@ -394,42 +387,36 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                                    make_tuple(I0, I0),
                                    in_thread_val_buf);
 
-            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
+            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
 
                     // initialize the indices for the per-thread to-reduce values
-                    in_thread_idx_buf(offset) =
-                        indexOffset + thread_k_cluster_id * KThreadSliceSize + J();
+                    in_thread_idx_buf(Number<offset>{}) =
+                        indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
 
                     // do element-wise pre-reduction operation
-                    in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset));
+                    in_elementwise_op(in_thread_val_buf(Number<offset>{}),
+                                      in_thread_val_buf(Number<offset>{}));
                 });
 
                 AccDataType tmpValue = zeroVal;
                 IndexDataType tmpIndex = 0;
 
-                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
+                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
 
                     // reduce on the dim1 thread slice
-                    AccumulationWithIndex::Calculate(
-                        tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
+                    AccumulationWithIndex::Calculate(tmpValue,
+                                                     in_thread_val_buf[Number<offset>{}],
+                                                     tmpIndex,
+                                                     in_thread_idx_buf[Number<offset>{}]);
                 });
 
-                // store thread local value to LDS for parallel reduction
-                block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-                    tmpValue;
-                block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-                    tmpIndex;
-
-                __syncthreads();
-
                 BlockwiseReduceWithIndex::Reduce(
-                    block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
+                    reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
 
                 AccumulationWithIndex::Calculate(
-                    accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
+                    accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
             });
 
             threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
 
@@ -30,6 +30,7 @@
 #include "reduction_common.hpp"
 #include "reduction_operator.hpp"
 #include "reduction_functions_accumulate.hpp"
+#include "reduction_functions_threadwise.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "element_wise_operation.hpp"
 
@@ -110,6 +111,11 @@ struct GridwiseReduction_mk_to_m_threadwise
     using ThreadBufferDimAccessOrder =
         typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
 
+    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+    using ThreadReduceDstDesc_M =
+        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
+
     using PassThroughOp = tensor_operation::element_wise::PassThrough;
 
     static constexpr auto I0 = Number<0>{};
@@ -124,9 +130,11 @@ struct GridwiseReduction_mk_to_m_threadwise
                     OutDataType* const __restrict__ p_out_global,
                     IndexDataType* const __restrict__ p_indices_global)
     {
-
-        using Accumulation =
-            detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
+        using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
+                                                     ThreadReduceSrcDesc_M_K,
+                                                     ThreadReduceDstDesc_M,
+                                                     ReduceOperation,
+                                                     PropagateNan>;
 
         (void)p_indices_global;
 
@@ -175,20 +183,17 @@ struct GridwiseReduction_mk_to_m_threadwise
                                    make_tuple(I0, I0),
                                    in_thread_buf);
 
-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
             // do element-wise pre-reduction operation
-            static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
-            });
-
-            // reduce on each thread-local slice
-            static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
+            static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                in_elementwise_op(in_thread_buf(Number<offset>{}),
+                                  in_thread_buf(Number<offset>{}));
             });
         });
 
+        ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
+
         threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
 
         reducedLength += KThreadSliceSize;
@@ -200,8 +205,7 @@ struct GridwiseReduction_mk_to_m_threadwise
             accu_value_buf(I) *= alpha;
         });
 
-        constexpr auto reduced_data_desc =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
 
         if constexpr(!BetaIsZero)
         {
@@ -266,10 +270,13 @@ struct GridwiseReduction_mk_to_m_threadwise
                     OutDataType* const __restrict__ p_out_global,
                     IndexDataType* const __restrict__ p_indices_global)
     {
-        using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
-                                                                             ReduceOperation,
-                                                                             AccDataType,
-                                                                             IndexDataType>;
+        using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
+                                                                       IndexDataType,
+                                                                       ThreadReduceSrcDesc_M_K,
+                                                                       ThreadReduceDstDesc_M,
+                                                                       ReduceOperation,
+                                                                       PropagateNan>;
+
         (void)acc_elementwise_op;
 
         const auto zeroVal = ReduceOperation::GetReductionZeroVal();
@@ -282,7 +289,13 @@ struct GridwiseReduction_mk_to_m_threadwise
             p_indices_global, out_grid_desc_m.GetElementSpaceSize());
 
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
-            in_thread_buf;
+            in_thread_val_buf;
+
+        StaticBuffer<AddressSpaceEnum::Vgpr,
+                     IndexDataType,
+                     MThreadSliceSize * KThreadSliceSize,
+                     true>
+            in_thread_idx_buf;
 
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
         StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
@@ -322,26 +335,23 @@ struct GridwiseReduction_mk_to_m_threadwise
                                    in_global_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
-                                   in_thread_buf);
+                                   in_thread_val_buf);
 
-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
             // do element-wise pre-reduction operation
-            static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
+            static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
 
-                in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
-            });
+                in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
 
-            // reduce on each thread-local slice
-            static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                AccumulationWithIndex::Calculate(accu_value_buf(I),
-                                                 in_thread_buf[offset],
-                                                 accu_index_buf(I),
-                                                 indexStart + J);
+                in_elementwise_op(in_thread_val_buf(Number<offset>{}),
+                                  in_thread_val_buf(Number<offset>{}));
             });
         });
 
+        ThreadwiseReduceWithIndex::Reduce(
+            in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
+
         threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
 
         indexStart += KThreadSliceSize;
@@ -355,8 +365,7 @@ struct GridwiseReduction_mk_to_m_threadwise
             accu_value_buf(I) *= alpha;
         });
 
-        constexpr auto reduced_data_desc =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
 
         if constexpr(!BetaIsZero)
         {