mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Improve Reduction kernel api (#152)
* Add ThreadwiseReduction functor as per-thread reduction api * Use ThreadwiseReduce api and make some changes in using PartitionedBlockwiseReduction api to simplify the kernels * Add comments and remove useless declarations in the kernels * Tiny updates
This commit is contained in:
@@ -26,16 +26,20 @@
|
||||
#ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
|
||||
#define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
|
||||
#include "cluster_descriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// clang-format off
|
||||
// Assume:
|
||||
// 1) work_buffer is a buffer (typically LDS) allocated outside as workspace; it does not include any in/out data
|
||||
// 2) work_buffer has AccDataType elements, and space size is no less than BlockSize
|
||||
// 3) in_out_value is the input data in vgpr from each thread
|
||||
// 4) in_out_value is the over-written reduced output in vgpr for each thread
|
||||
// clang-format on
|
||||
template <typename AccDataType,
|
||||
index_t BlockSize,
|
||||
typename ThreadClusterLengths_M_K,
|
||||
@@ -61,8 +65,11 @@ struct PartitionedBlockwiseReduction
|
||||
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
|
||||
|
||||
template <typename BufferType>
|
||||
__device__ static void Reduce(BufferType& block_buffer, AccDataType& accuData)
|
||||
__device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
|
||||
{
|
||||
static_assert(is_same<typename BufferType::type, AccDataType>{},
|
||||
"Buffer data type should be consistent as AccDataType!");
|
||||
|
||||
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
@@ -71,6 +78,10 @@ struct PartitionedBlockwiseReduction
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
|
||||
|
||||
work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
|
||||
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
|
||||
|
||||
@@ -80,10 +91,10 @@ struct PartitionedBlockwiseReduction
|
||||
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
|
||||
make_tuple(0, indOffset));
|
||||
|
||||
AccDataType opData1 = type_convert<AccDataType>(block_buffer[offset1]);
|
||||
AccDataType opData2 = type_convert<AccDataType>(block_buffer[offset2]);
|
||||
AccDataType opData1 = work_buffer[offset1];
|
||||
AccDataType opData2 = work_buffer[offset2];
|
||||
Accumulation::Calculate(opData1, opData2);
|
||||
block_buffer(offset1) = type_convert<AccDataType>(opData1);
|
||||
work_buffer(offset1) = opData1;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@@ -91,10 +102,17 @@ struct PartitionedBlockwiseReduction
|
||||
|
||||
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
|
||||
|
||||
accuData = type_convert<AccDataType>(block_buffer[offset]);
|
||||
in_out_value = work_buffer[offset];
|
||||
};
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
// Assume:
|
||||
// 1) work_val_buffer/work_idx_buffer are buffers (typically LDS) allocated outside as workspace; they do not include any in/out data
|
||||
// 2) work_val_buffer/work_idx_buffer have AccDataType/IndexDataType elements, and their space size is no less than BlockSize
|
||||
// 3) in_out_value/in_out_index is the input data in vgpr from each thread
|
||||
// 4) in_out_value/in_out_index is the over-written reduced output in vgpr for each thread
|
||||
// clang-format on
|
||||
template <typename AccDataType,
|
||||
typename IndexDataType,
|
||||
index_t BlockSize,
|
||||
@@ -123,11 +141,16 @@ struct PartitionedBlockwiseReductionWithIndex
|
||||
|
||||
// This interface accumulates on both data values and indices
|
||||
template <typename BufferType, typename IdxBufferType>
|
||||
__device__ static void Reduce(BufferType& block_val_buffer,
|
||||
IdxBufferType& block_idx_buffer,
|
||||
AccDataType& accuData,
|
||||
IndexDataType& accuIndex)
|
||||
__device__ static void Reduce(BufferType& work_val_buffer,
|
||||
IdxBufferType& work_idx_buffer,
|
||||
AccDataType& in_out_value,
|
||||
IndexDataType& in_out_index)
|
||||
{
|
||||
static_assert(is_same<typename BufferType::type, AccDataType>{},
|
||||
"Buffer data type should be consistent as AccDataType!");
|
||||
static_assert(is_same<typename IdxBufferType::type, IndexDataType>{},
|
||||
"Buffer data type should be consistent as IndexDataType!");
|
||||
|
||||
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
@@ -136,6 +159,11 @@ struct PartitionedBlockwiseReductionWithIndex
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
|
||||
|
||||
work_val_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
|
||||
work_idx_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_index;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
|
||||
constexpr index_t indOffset = 1 << I();
|
||||
|
||||
@@ -145,14 +173,14 @@ struct PartitionedBlockwiseReductionWithIndex
|
||||
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
|
||||
make_tuple(0, indOffset));
|
||||
|
||||
AccDataType opData1 = type_convert<AccDataType>(block_val_buffer[offset1]);
|
||||
AccDataType opData2 = type_convert<AccDataType>(block_val_buffer[offset2]);
|
||||
IndexDataType currIndex1 = block_idx_buffer[offset1];
|
||||
IndexDataType currIndex2 = block_idx_buffer[offset2];
|
||||
AccDataType opData1 = work_val_buffer[offset1];
|
||||
AccDataType opData2 = work_val_buffer[offset2];
|
||||
IndexDataType currIndex1 = work_idx_buffer[offset1];
|
||||
IndexDataType currIndex2 = work_idx_buffer[offset2];
|
||||
|
||||
Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2);
|
||||
block_val_buffer(offset1) = type_convert<AccDataType>(opData1);
|
||||
block_idx_buffer(offset1) = currIndex1;
|
||||
work_val_buffer(offset1) = opData1;
|
||||
work_idx_buffer(offset1) = currIndex1;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@@ -160,9 +188,9 @@ struct PartitionedBlockwiseReductionWithIndex
|
||||
|
||||
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
|
||||
|
||||
accuData = type_convert<AccDataType>(block_val_buffer[offset]);
|
||||
accuIndex = block_idx_buffer[offset];
|
||||
}
|
||||
in_out_value = work_val_buffer[offset];
|
||||
in_out_index = work_idx_buffer[offset];
|
||||
};
|
||||
};
|
||||
|
||||
}; // end of namespace ck
|
||||
|
||||
Reference in New Issue
Block a user