Improve Reduction kernel api (#152)

* Add a ThreadwiseReduction functor as the per-thread reduction API (sketched below)

* Use the ThreadwiseReduction API and adjust how the PartitionedBlockwiseReduction API is used to simplify the kernels

* Add comments and remove unused declarations in the kernels

* Tiny updates
Author: Qianfeng
Date: 2022-04-05 09:31:44 +08:00
Committed by: GitHub
Parent: 646878162b
Commit: 82c8b9f8ee
6 changed files with 348 additions and 229 deletions
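
The ThreadwiseReduction functor named in the commit message is defined in a file that is not part of the hunk shown below, so the following is only a minimal host-side sketch of the per-thread reduction pattern such a functor provides: each thread serially folds its own M x K slice of register values into one accumulator per M entry before handing the partial result to the blockwise stage. Every name in the sketch (threadwise_reduce_sketch, MThreadSliceSize, KThreadSliceSize) is an illustrative assumption, not the library's API.

#include <array>
#include <cstdio>

// Illustrative only: a serial per-thread fold along K for each of the thread's M entries.
template <typename AccDataType, int MThreadSliceSize, int KThreadSliceSize, typename OpReduce>
void threadwise_reduce_sketch(
    const std::array<AccDataType, MThreadSliceSize * KThreadSliceSize>& thread_buffer,
    std::array<AccDataType, MThreadSliceSize>& accu_buffer,
    OpReduce op)
{
    for(int m = 0; m < MThreadSliceSize; ++m)
        for(int k = 0; k < KThreadSliceSize; ++k)
            op(accu_buffer[m], thread_buffer[m * KThreadSliceSize + k]); // accumulate in place
}

int main()
{
    std::array<float, 2 * 4> in  = {1, 2, 3, 4, 10, 20, 30, 40}; // 2 M entries x 4 K values
    std::array<float, 2>     acc = {0, 0};

    threadwise_reduce_sketch<float, 2, 4>(in, acc, [](float& a, float b) { a += b; });

    std::printf("acc = {%f, %f}\n", acc[0], acc[1]); // prints {10, 100}
    return 0;
}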


@@ -26,16 +26,20 @@
#ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "cluster_descriptor.hpp"
namespace ck {
// clang-format off
// Assume:
// 1) work_buffer is a workspace buffer (typically in LDS) allocated by the caller; it does not hold any in/out data
// 2) work_buffer contains AccDataType elements, and its size is no less than BlockSize
// 3) on entry, in_out_value holds each thread's input value in VGPR
// 4) on return, in_out_value is overwritten with each thread's reduced output
// clang-format on
template <typename AccDataType,
index_t BlockSize,
typename ThreadClusterLengths_M_K,
@@ -61,8 +65,11 @@ struct PartitionedBlockwiseReduction
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
template <typename BufferType>
__device__ static void Reduce(BufferType& block_buffer, AccDataType& accuData)
__device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
{
static_assert(is_same<typename BufferType::type, AccDataType>{},
"Buffer data type should be consistent as AccDataType!");
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
const auto thread_cluster_idx =
@@ -71,6 +78,10 @@ struct PartitionedBlockwiseReduction
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
__syncthreads();
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
@@ -80,10 +91,10 @@ struct PartitionedBlockwiseReduction
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset));
AccDataType opData1 = type_convert<AccDataType>(block_buffer[offset1]);
AccDataType opData2 = type_convert<AccDataType>(block_buffer[offset2]);
AccDataType opData1 = work_buffer[offset1];
AccDataType opData2 = work_buffer[offset2];
Accumulation::Calculate(opData1, opData2);
block_buffer(offset1) = type_convert<AccDataType>(opData1);
work_buffer(offset1) = opData1;
}
__syncthreads();
@@ -91,10 +102,17 @@ struct PartitionedBlockwiseReduction
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
accuData = type_convert<AccDataType>(block_buffer[offset]);
in_out_value = work_buffer[offset];
};
};
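
For reference, the static_for loop above performs a binary-tree fold over the workspace along K: in each step the active threads combine their own element with the element indOffset slots away, with a __syncthreads() between steps, until the row result sits at K = 0. The following is a minimal host-side sketch of that same fold, using a plain std::vector instead of the CK buffer and descriptor types, and with OpReduce taken to be addition purely for illustration; it is not the library's code.

#include <vector>
#include <cstdio>

int main()
{
    constexpr int KClusterLen       = 8;                          // threads along K (power of two)
    constexpr int cluster_len_shift = 3;                          // log2(KClusterLen)
    std::vector<float> work_buffer  = {1, 2, 3, 4, 5, 6, 7, 8};   // one value per "thread"

    // Each step folds the upper half onto the lower half, mirroring
    // indOffset = 1 << (cluster_len_shift - 1 - I) in the device loop.
    for(int step = 0; step < cluster_len_shift; ++step)
    {
        const int indOffset = 1 << (cluster_len_shift - 1 - step);
        for(int k = 0; k < indOffset; ++k)                        // "threads" with k_cluster_id < indOffset
            work_buffer[k] = work_buffer[k] + work_buffer[k + indOffset];
        // a __syncthreads() separates the steps on the GPU
    }

    std::printf("reduced value = %f\n", work_buffer[0]);          // the in_out_value read back at K = 0
    return 0;
}
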
// clang-format off
// Assume:
// 1) work_val_buffer/work_idx_buffer are workspace buffers (typically in LDS) allocated by the caller; they do not hold any in/out data
// 2) work_val_buffer/work_idx_buffer contain AccDataType/IndexDataType elements, and their sizes are no less than BlockSize
// 3) on entry, in_out_value/in_out_index hold each thread's input value/index in VGPR
// 4) on return, in_out_value/in_out_index are overwritten with each thread's reduced output
// clang-format on
template <typename AccDataType,
typename IndexDataType,
index_t BlockSize,
@@ -123,11 +141,16 @@ struct PartitionedBlockwiseReductionWithIndex
// This interface accumulates on both data values and indices
template <typename BufferType, typename IdxBufferType>
__device__ static void Reduce(BufferType& block_val_buffer,
IdxBufferType& block_idx_buffer,
AccDataType& accuData,
IndexDataType& accuIndex)
__device__ static void Reduce(BufferType& work_val_buffer,
IdxBufferType& work_idx_buffer,
AccDataType& in_out_value,
IndexDataType& in_out_index)
{
static_assert(is_same<typename BufferType::type, AccDataType>{},
"Buffer data type should be consistent as AccDataType!");
static_assert(is_same<typename IdxBufferType::type, IndexDataType>{},
"Buffer data type should be consistent as IndexDataType!");
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
const auto thread_cluster_idx =
@@ -136,6 +159,11 @@ struct PartitionedBlockwiseReductionWithIndex
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
work_val_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
work_idx_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_index;
__syncthreads();
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << I();
@@ -145,14 +173,14 @@ struct PartitionedBlockwiseReductionWithIndex
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset));
AccDataType opData1 = type_convert<AccDataType>(block_val_buffer[offset1]);
AccDataType opData2 = type_convert<AccDataType>(block_val_buffer[offset2]);
IndexDataType currIndex1 = block_idx_buffer[offset1];
IndexDataType currIndex2 = block_idx_buffer[offset2];
AccDataType opData1 = work_val_buffer[offset1];
AccDataType opData2 = work_val_buffer[offset2];
IndexDataType currIndex1 = work_idx_buffer[offset1];
IndexDataType currIndex2 = work_idx_buffer[offset2];
Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2);
block_val_buffer(offset1) = type_convert<AccDataType>(opData1);
block_idx_buffer(offset1) = currIndex1;
work_val_buffer(offset1) = opData1;
work_idx_buffer(offset1) = currIndex1;
}
__syncthreads();
@@ -160,9 +188,9 @@ struct PartitionedBlockwiseReductionWithIndex
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
accuData = type_convert<AccDataType>(block_val_buffer[offset]);
accuIndex = block_idx_buffer[offset];
}
in_out_value = work_val_buffer[offset];
in_out_index = work_idx_buffer[offset];
};
};
}; // end of namespace ck
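
The WithIndex variant applies the same fold, but each combination carries the winning index together with the winning value, as in Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2) above. Below is a host-side sketch of that value-plus-index fold over plain vectors, walking indOffset = 1, 2, 4, ... as the WithIndex loop does, and assuming OpReduce is max purely for illustration; it is not the library's code.

#include <vector>
#include <cstdio>

int main()
{
    constexpr int KClusterLen = 8;
    std::vector<float> work_val_buffer = {3, 9, 1, 7, 5, 2, 8, 6};
    std::vector<int>   work_idx_buffer = {0, 1, 2, 3, 4, 5, 6, 7};

    // Fold with growing stride; at each step the surviving slot keeps both the
    // winning value and the index it came from.
    for(int indOffset = 1; indOffset < KClusterLen; indOffset <<= 1)
    {
        for(int k = 0; k + indOffset < KClusterLen; k += 2 * indOffset)
        {
            if(work_val_buffer[k + indOffset] > work_val_buffer[k]) // OpReduce = max (illustrative)
            {
                work_val_buffer[k] = work_val_buffer[k + indOffset];
                work_idx_buffer[k] = work_idx_buffer[k + indOffset];
            }
        }
        // a __syncthreads() separates the steps on the GPU
    }

    std::printf("max = %f at index %d\n", work_val_buffer[0], work_idx_buffer[0]); // 9 at index 1
    return 0;
}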