Mirror of https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Pr82 followup (#115)
* Use thread cluster descriptor and explicit M_K 2d descriptor to simplify Blockwise Reduction
* Replace ReduceDims with NumReduceDims as the Device Reduce interface template parameter
* Rename the folder name for the pool2d and reduce examples
* Update the reduction test scripts
* Add Readme for the pool2d_fwd and reduce_blockwise examples
* Tiny fix in the reduce profiler and tiny update in the reduce testing scripts
* Tiny fix in testing script profile_reduce_no_index.sh
* Tiny change in script/profile_reduce_with_index.sh
* Renaming and refining in the Reduction profiler/device layer/examples
* Rename all NumReduceDims to NumReduceDim
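The core simplification in this diff is that the hand-written div/mod decomposition of the flat thread id (which had to branch on reorder_thread_cluster) is replaced by a thread cluster descriptor built from ThreadClusterLengths_M_K and ThreadClusterArrangeOrder, queried once via CalculateBottomIndex. Below is a minimal stand-alone C++ sketch of the index math that appears to be encapsulated; the decompose helper is purely illustrative and is not the ck:: implementation.

// Hypothetical sketch (not ck:: code): the (m, k) cluster coordinates that the old
// ternary expressions computed and that thread_cluster_desc.CalculateBottomIndex now returns.
#include <cassert>

struct ClusterIdx
{
    int m;
    int k;
};

// K-fastest arrangement (Sequence<0, 1>) vs. reordered, M-fastest arrangement (Sequence<1, 0>).
ClusterIdx decompose(int thread_local_id,
                     int MThreadClusterSize,
                     int KThreadClusterSize,
                     bool reorder_thread_cluster)
{
    if(reorder_thread_cluster)
        return {thread_local_id % MThreadClusterSize,
                (thread_local_id / MThreadClusterSize) % KThreadClusterSize};
    else
        return {(thread_local_id / KThreadClusterSize) % MThreadClusterSize,
                thread_local_id % KThreadClusterSize};
}

int main()
{
    constexpr int M = 4, K = 16; // e.g. a 64-thread block
    for(int tid = 0; tid < M * K; ++tid)
    {
        const auto idx = decompose(tid, M, K, /*reorder_thread_cluster=*/false);
        assert(idx.m == tid / K && idx.k == tid % K);
    }
    return 0;
}

The diff keeps the same mapping but derives it once from thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)) instead of repeating the ternaries in every kernel.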
@@ -31,8 +31,8 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"

#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"

namespace ck {

@@ -158,13 +158,27 @@ struct GridwiseReduction_mk_to_m_blockwise
{
    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    static constexpr auto buffer_1d_desc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    // When laying out the threads for reducing on the LDS buffer, Dim_K is always used as the
    // fastest dimension
    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));

    template <typename T>
    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
@@ -180,14 +194,12 @@ struct GridwiseReduction_mk_to_m_blockwise
        const IndexDataType* const __restrict__ p_ws_indices_global,
        IndexDataType* const __restrict__ p_indices_global)
    {
        using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc),
            AccDataType,
            BlockSize,
            MThreadClusterSize,
            KThreadClusterSize,
            reorder_thread_cluster,
            ReduceOperation,
            PropagateNan>;
        using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
            BlockSize,
            ThreadClusterLengths_M_K,
            ThreadClusterArrangeOrder,
            ReduceOperation,
            PropagateNan>;
        using Accumulation =
            detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
@@ -221,31 +233,31 @@ struct GridwiseReduction_mk_to_m_blockwise

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                : thread_local_id % KThreadClusterSize;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k,
                make_multi_index(block_global_1d_id * M_BlockTileSize +
                        thread_m_cluster_id * MThreadSliceSize,
                    thread_k_cluster_id * KThreadSliceSize));
        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            ThreadBufferDimAccessOrder,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(
            in_grid_desc_m_k,
            make_multi_index(block_global_1d_id * M_BlockTileSize +
                    thread_m_cluster_id * MThreadSliceSize,
                thread_k_cluster_id * KThreadSliceSize));

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
@@ -283,21 +295,14 @@ struct GridwiseReduction_mk_to_m_blockwise
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(reorder_thread_cluster)
            {
                block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
                    accu_value_buf[I];
            }
            else
                block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
                    accu_value_buf[I];
            block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
                accu_value_buf[I];

            accu_value_buf(I) = zeroVal;

            __syncthreads();

            BlockwiseReduce::Reduce(
                block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
            BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
        });

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
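The old code chose between two hand-written LDS offsets depending on reorder_thread_cluster; the new code always stores through block_buf_desc_m_k.CalculateOffset(thread_cluster_idx), i.e. a packed (MThreadClusterSize, KThreadClusterSize) layout with Dim_K fastest. A minimal sketch of that offset follows; the helper is illustrative only and assumes the packed descriptor computes a plain row-major offset.

// Illustrative sketch (assumed semantics of a packed 2-D descriptor, not ck:: code).
#include <cassert>

// Packed (M, K) layout with K as the fastest dimension.
constexpr int lds_offset(int m, int k, int KThreadClusterSize)
{
    return m * KThreadClusterSize + k;
}

int main()
{
    constexpr int M = 4, K = 16;
    static_assert(lds_offset(0, 3, K) == 3, "K is contiguous");
    static_assert(lds_offset(1, 0, K) == K, "stride of M is KThreadClusterSize");
    // Every thread still gets a unique LDS slot regardless of the cluster arrangement order.
    bool used[M * K] = {};
    for(int m = 0; m < M; ++m)
        for(int k = 0; k < K; ++k)
            used[lds_offset(m, k, K)] = true;
    for(bool u : used)
        assert(u);
    return 0;
}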
@@ -380,15 +385,13 @@ struct GridwiseReduction_mk_to_m_blockwise
        IndexDataType* const __restrict__ p_indices_global)
    {
        using BlockwiseReduceWithIndex =
            PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc),
                AccDataType,
                IndexDataType,
                BlockSize,
                MThreadClusterSize,
                KThreadClusterSize,
                reorder_thread_cluster,
                ReduceOperation,
                PropagateNan>;
            PartitionedBlockwiseReductionWithIndex<AccDataType,
                IndexDataType,
                BlockSize,
                ThreadClusterLengths_M_K,
                ThreadClusterArrangeOrder,
                ReduceOperation,
                PropagateNan>;

        using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
            ReduceOperation,
@@ -432,31 +435,31 @@ struct GridwiseReduction_mk_to_m_blockwise

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                : thread_local_id % KThreadClusterSize;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k,
                make_multi_index(block_global_1d_id * M_BlockTileSize +
                        thread_m_cluster_id * MThreadSliceSize,
                    thread_k_cluster_id * KThreadSliceSize));
        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            ThreadBufferDimAccessOrder,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(
            in_grid_desc_m_k,
            make_multi_index(block_global_1d_id * M_BlockTileSize +
                    thread_m_cluster_id * MThreadSliceSize,
                thread_k_cluster_id * KThreadSliceSize));

        index_t indexOffset = 0;
@@ -503,29 +506,15 @@ struct GridwiseReduction_mk_to_m_blockwise
            });

            // store thread local value to LDS for parallel reduction
            if constexpr(reorder_thread_cluster)
            {
                block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
                    thread_m_cluster_id) = tmpValue;
                block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
                    thread_m_cluster_id) = tmpIndex;
            }
            else
            {
                block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
                    thread_k_cluster_id) = tmpValue;
                block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
                    thread_k_cluster_id) = tmpIndex;
            }
            block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
                tmpValue;
            block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
                tmpIndex;

            __syncthreads();

            BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
                block_reduce_idx_buf,
                tmpValue,
                tmpIndex,
                thread_m_cluster_id,
                thread_k_cluster_id);
            BlockwiseReduceWithIndex::Reduce(
                block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);

            AccumulationWithIndex::Calculate(
                accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
@@ -648,15 +637,13 @@ struct GridwiseReduction_mk_to_m_blockwise
        IndexDataType* const __restrict__ p_indices_global)
    {
        using BlockwiseReduceWithIndex =
            PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc),
                AccDataType,
                IndexDataType,
                BlockSize,
                MThreadClusterSize,
                KThreadClusterSize,
                reorder_thread_cluster,
                ReduceOperation,
                PropagateNan>;
            PartitionedBlockwiseReductionWithIndex<AccDataType,
                IndexDataType,
                BlockSize,
                Sequence<MThreadClusterSize, KThreadClusterSize>,
                ThreadClusterArrangeOrder,
                ReduceOperation,
                PropagateNan>;

        using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
            ReduceOperation,
@@ -707,46 +694,48 @@ struct GridwiseReduction_mk_to_m_blockwise

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                : thread_local_id % KThreadClusterSize;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k,
                make_multi_index(block_global_1d_id * M_BlockTileSize +
                        thread_m_cluster_id * MThreadSliceSize,
                    thread_k_cluster_id * KThreadSliceSize));
        auto threadwise_src_val_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                AccDataType,
                InGridDesc_M_K,
                decltype(thread_buffer_desc),
                ThreadBufferLengths,
                ThreadBufferDimAccessOrder,
                InSrcVectorDim,
                InSrcVectorSize,
                1,
                false>(
                in_grid_desc_m_k,
                make_multi_index(block_global_1d_id * M_BlockTileSize +
                        thread_m_cluster_id * MThreadSliceSize,
                    thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2<
            IndexDataType,
            IndexDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k,
                make_multi_index(block_global_1d_id * M_BlockTileSize +
                        thread_m_cluster_id * MThreadSliceSize,
                    thread_k_cluster_id * KThreadSliceSize));
        auto threadwise_src_idx_load =
            ThreadwiseTensorSliceTransfer_v2<IndexDataType,
                IndexDataType,
                InGridDesc_M_K,
                decltype(thread_buffer_desc),
                ThreadBufferLengths,
                ThreadBufferDimAccessOrder,
                InSrcVectorDim,
                InSrcVectorSize,
                1,
                false>(
                in_grid_desc_m_k,
                make_multi_index(block_global_1d_id * M_BlockTileSize +
                        thread_m_cluster_id * MThreadSliceSize,
                    thread_k_cluster_id * KThreadSliceSize));

        // index_t indexOffset = 0;
@@ -787,29 +776,15 @@ struct GridwiseReduction_mk_to_m_blockwise
            });

            // store thread local value to LDS for parallel reduction
            if constexpr(reorder_thread_cluster)
            {
                block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
                    thread_m_cluster_id) = tmpValue;
                block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
                    thread_m_cluster_id) = tmpIndex;
            }
            else
            {
                block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
                    thread_k_cluster_id) = tmpValue;
                block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
                    thread_k_cluster_id) = tmpIndex;
            }
            block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
                tmpValue;
            block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
                tmpIndex;

            __syncthreads();

            BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
                block_reduce_idx_buf,
                tmpValue,
                tmpIndex,
                thread_m_cluster_id,
                thread_k_cluster_id);
            BlockwiseReduceWithIndex::Reduce(
                block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);

            AccumulationWithIndex::Calculate(
                accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);

@@ -86,22 +86,34 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{
    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    static constexpr auto buffer_1d_desc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using blockwise_reduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc),
        AccDataType,
        BlockSize,
        MThreadClusterSize,
        KThreadClusterSize,
        reorder_thread_cluster,
        ReduceOperation,
        PropagateNan>;
    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    // When laying out the threads for reducing on the LDS buffer, Dim_K is always used as the
    // fastest dimension
    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));

    using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
        BlockSize,
        ThreadClusterLengths_M_K,
        ThreadClusterArrangeOrder,
        ReduceOperation,
        PropagateNan>;

    template <typename T>
    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
@@ -145,12 +157,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id = block_global_id / block_group_size;
        const index_t block_local_id = block_global_id % block_group_size;
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                : thread_local_id % KThreadClusterSize;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
@@ -158,17 +170,16 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(
        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            ThreadBufferDimAccessOrder,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(
            in_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                block_local_id * reduceSizePerBlock +
@@ -212,21 +223,14 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
        // consistent reduced result for that invariant dimension. Due to the use of vector_load,
        // each block/thread is involved in multiple invariant dimensions.
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(reorder_thread_cluster)
            {
                block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
                    accu_value_buf[I];
            }
            else
                block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
                    accu_value_buf[I];
            block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
                accu_value_buf[I];

            accu_value_buf(I) = zeroVal;

            __syncthreads();

            blockwise_reduce::Reduce(
                block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
            BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
        });

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {

@@ -30,8 +30,8 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"

#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"

namespace ck {
@@ -103,13 +103,27 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
{
    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    static constexpr auto buffer1dDesc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    // When laying out the threads for reducing on the LDS buffer, Dim_K is always used as the
    // fastest dimension
    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));

    template <typename T>
    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
@@ -124,14 +138,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        AccDataType* const __restrict__ p_ws_values_global,
        IndexDataType* const __restrict__ p_ws_indices_global)
    {
        using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer1dDesc),
            AccDataType,
            BlockSize,
            MThreadClusterSize,
            KThreadClusterSize,
            reorder_thread_cluster,
            ReduceOperation,
            PropagateNan>;
        using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
            BlockSize,
            ThreadClusterLengths_M_K,
            ThreadClusterArrangeOrder,
            ReduceOperation,
            PropagateNan>;

        using Accumulation =
            detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
@@ -168,12 +180,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id = block_global_id / block_group_size;
        const index_t block_local_id = block_global_id % block_group_size;
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                : thread_local_id % KThreadClusterSize;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
@@ -181,17 +193,16 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(
        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            ThreadBufferDimAccessOrder,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(
            in_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                block_local_id * reduceSizePerBlock +
@@ -233,21 +244,14 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        // Each block executes multiple parallel reductions on the LDS, and due to the use of
        // vector_load, each block/thread is involved in multiple invariant dimensions.
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(reorder_thread_cluster)
            {
                block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
                    accu_value_buf[I];
            }
            else
                block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
                    accu_value_buf[I];
            block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
                accu_value_buf[I];

            accu_value_buf(I) = zeroVal;

            __syncthreads();

            BlockwiseReduce::Reduce(
                block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
            BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
        });

        if(thread_k_cluster_id == 0)
@@ -290,15 +294,13 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        IndexDataType* const __restrict__ p_ws_indices_global)
    {
        using BlockwiseReduceWithIndex =
            PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer1dDesc),
                AccDataType,
                IndexDataType,
                BlockSize,
                MThreadClusterSize,
                KThreadClusterSize,
                reorder_thread_cluster,
                ReduceOperation,
                PropagateNan>;
            PartitionedBlockwiseReductionWithIndex<AccDataType,
                IndexDataType,
                BlockSize,
                ThreadClusterLengths_M_K,
                ThreadClusterArrangeOrder,
                ReduceOperation,
                PropagateNan>;

        using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
            ReduceOperation,
@@ -346,12 +348,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id = block_global_id / block_group_size;
        const index_t block_local_id = block_global_id % block_group_size;
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                : thread_local_id % KThreadClusterSize;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
@@ -359,17 +361,16 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(
        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            ThreadBufferDimAccessOrder,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(
            in_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                block_local_id * reduceSizePerBlock +
@@ -418,29 +419,15 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
            });

            // store thread local value to LDS for parallel reduction
            if constexpr(reorder_thread_cluster)
            {
                block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
                    thread_m_cluster_id) = tmpValue;
                block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
                    thread_m_cluster_id) = tmpIndex;
            }
            else
            {
                block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
                    thread_k_cluster_id) = tmpValue;
                block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
                    thread_k_cluster_id) = tmpIndex;
            }
            block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
                tmpValue;
            block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
                tmpIndex;

            __syncthreads();

            BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
                block_reduce_idx_buf,
                tmpValue,
                tmpIndex,
                thread_m_cluster_id,
                thread_k_cluster_id);
            BlockwiseReduceWithIndex::Reduce(
                block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);

            AccumulationWithIndex::Calculate(
                accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);

@@ -101,6 +101,9 @@ template <typename InDataType,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_threadwise
{
    using ThreadBufferDimAccessOrder =
        typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;

    template <typename T>
    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
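The repeated inline typename conditional<InSrcVectorDim == 0, ...>::type is now named once as ThreadBufferDimAccessOrder and reused by every ThreadwiseTensorSliceTransfer_v2 instantiation below. A hedged sketch of the selection rule follows; it assumes, as the pairing with InSrcVectorDim suggests, that the dimension listed last in the access order is the one traversed fastest, so the vectorized source dimension ends up innermost.

// Illustrative sketch, not ck:: code: pick the (M, K) traversal order for the thread
// buffer from the vectorized source dimension.
#include <array>
#include <cassert>

std::array<int, 2> thread_buffer_dim_access_order(int InSrcVectorDim)
{
    // InSrcVectorDim == 0 -> traverse dim 0 (M) fastest: order {1, 0}
    // otherwise           -> traverse dim 1 (K) fastest: order {0, 1}
    return (InSrcVectorDim == 0) ? std::array<int, 2>{1, 0} : std::array<int, 2>{0, 1};
}

int main()
{
    assert(thread_buffer_dim_access_order(0).back() == 0); // M innermost when M is the vector dim
    assert(thread_buffer_dim_access_order(1).back() == 1); // K innermost otherwise
    return 0;
}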
@@ -147,17 +150,17 @@ struct GridwiseReduction_mk_to_m_threadwise

        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            ThreadBufferDimAccessOrder,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(
            in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
@@ -299,17 +302,17 @@ struct GridwiseReduction_mk_to_m_threadwise

        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            ThreadBufferDimAccessOrder,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(
            in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);