mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-03 05:01:25 +00:00)
Reduction for int8 and bfloat16 (#125)
* Use thread cluster descriptor and explicit M_K 2d descriptor to simplify blockwise reduction
* Replace ReduceDims with NumReduceDims as the DeviceReduce interface template parameter
* Rename the folders for the pool2d and reduce examples
* Update the reduction test scripts
* Add README for the pool2d_fwd and reduce_blockwise examples
* Add support for int8_t reduction (ADD/AVG, MIN/MAX/AMAX)
* Tiny fix in the reduce profiler and tiny update in the reduce testing scripts
* Tiny fix in testing script profile_reduce_no_index.sh
* Add support for bf16 reduction (using bhalf_t = ushort)
* Tiny fix in amd_buffer_addressing.hpp
* Tiny change in script/profile_reduce_with_index.sh
* Use AccDataType for the beta value and use element_wise::PassThrough
* Use type_convert for type conversion in the host-layer reduction
* Renaming and refinement in the reduction profiler, device layer, and examples
* Rename all NumReduceDims to NumReduceDim
* Fix the leaked type_convert in ThreadwiseTensorSliceTransfer_v2
* Update testing scripts to add bf16 support
* Add more static_assert checks
* Remove buggy tunable configurations defined in device_reduce_instance_xxx.hpp
* Add static_assert to give a compile-time warning for incorrect thread slice-size/vector-size configurations
* Minor change
* Refine and fix GetWorkspaceSizeInBytes of MultiBlockPartialReduce so that int8 passes completely
* Tiny renaming in gridwise_2d_reduction_multiblock_partial_reduce.hpp
* Tiny fix in script/profile_reduce_no_index.sh
* Refine the DeviceReduce layer with regard to using NumInvariantDim/NumReduceDim or InvariantDims/ReduceDims
* Generic renaming in the host reduction and DeviceReduce layer
* Add support for 4-d all-dimension reduction in the profiler and the add_device_reduce_xxx instances
* Use multi-threading and simplify the host reduction implementation
* Add ctest for reduction
* Clarify the use of the data init method in profile_reduce/example_reduce/test_reduce
* Update the reduce CTest executables to enable default testing behavior when no command argument is given
* Renaming

Co-authored-by: Jianfeng yan <jfyan008@gmail.com>
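The typing change worth calling out: beta now travels through the device layer as AccDataType instead of OutDataType, and the prior output is converted before it is scaled. A minimal stand-alone C++ sketch of the intended arithmetic (the names and static_cast are illustrative stand-ins for the repo's type_convert path, not code from this commit):

#include <cstdint>

using AccDataType = float;  // accumulation type
using OutDataType = int8_t; // narrow output type, as in the int8 reduction

// Blend a freshly reduced value with the prior output:
// convert the narrow prior value first, then scale in the wide type.
AccDataType blend(AccDataType acc, OutDataType prior, AccDataType beta)
{
    return acc + static_cast<AccDataType>(prior) * beta;
}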
@@ -33,6 +33,7 @@
 #include "reduction_functions_blockwise.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "cluster_descriptor.hpp"
+#include "element_wise_operation.hpp"
 
 namespace ck {
 
@@ -52,23 +53,25 @@ __global__ void kernel_reduce_blockwise(const InGridDesc_M_K in_grid_desc_m_k,
                                         const OutElementwiseOperation acc_elementwise_op,
                                         AccDataType alpha,
                                         const InDataType* const __restrict__ p_in_global,
-                                        OutDataType beta,
+                                        AccDataType beta,
                                         OutDataType* const __restrict__ p_out_global,
                                         const IndexDataType* const __restrict__ p_ws_indices_global,
                                         IndexDataType* const __restrict__ p_indices_global)
 {
     if constexpr(!NeedIndices)
     {
-        GridwiseReduction::Run(in_grid_desc_m_k,
-                               out_grid_desc_m,
-                               in_elementwise_op,
-                               acc_elementwise_op,
-                               alpha,
-                               p_in_global,
-                               beta,
-                               p_out_global,
-                               p_ws_indices_global,
-                               p_indices_global);
+        constexpr bool IsSecondCall = false;
+
+        GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
+                                                      out_grid_desc_m,
+                                                      in_elementwise_op,
+                                                      acc_elementwise_op,
+                                                      alpha,
+                                                      p_in_global,
+                                                      beta,
+                                                      p_out_global,
+                                                      p_ws_indices_global,
+                                                      p_indices_global);
     }
     else
     {
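The pattern introduced here, in a minimal stand-alone sketch (GridwiseExample and its parameters are hypothetical, not from the diff): the kernel forwards a compile-time IsSecondCall flag so a single Run() body can enforce second-call-only constraints with if constexpr and static_assert instead of duplicating the function.

template <int InSrcVectorDim>
struct GridwiseExample
{
    template <bool IsSecondCall>
    static void Run()
    {
        if constexpr(IsSecondCall)
        {
            // only the second (workspace-consuming) call must read
            // vector-wise along the K dimension
            static_assert(InSrcVectorDim == 1,
                          "InSrcVectorDim must be 1 for the second call");
        }
        // ... shared reduction body ...
    }
};

int main()
{
    GridwiseExample<1>::Run<true>();  // second call: constraint checked
    GridwiseExample<0>::Run<false>(); // first call: branch discarded, no check
}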
@@ -102,23 +105,25 @@ kernel_reduce_blockwise_second_call(const InGridDesc_M_K in_grid_desc_m_k,
                                     const OutElementwiseOperation acc_elementwise_op,
                                     AccDataType alpha,
                                     const InDataType* const __restrict__ p_in_global,
-                                    OutDataType beta,
+                                    AccDataType beta,
                                     OutDataType* const __restrict__ p_out_global,
                                     const IndexDataType* const __restrict__ p_ws_indices_global,
                                     IndexDataType* const __restrict__ p_indices_global)
 {
     if constexpr(!NeedIndices)
     {
-        GridwiseReduction::Run(in_grid_desc_m_k,
-                               out_grid_desc_m,
-                               in_elementwise_op,
-                               acc_elementwise_op,
-                               alpha,
-                               p_in_global,
-                               beta,
-                               p_out_global,
-                               p_ws_indices_global,
-                               p_indices_global);
+        constexpr bool IsSecondCall = true;
+
+        GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
+                                                      out_grid_desc_m,
+                                                      in_elementwise_op,
+                                                      acc_elementwise_op,
+                                                      alpha,
+                                                      p_in_global,
+                                                      beta,
+                                                      p_out_global,
+                                                      p_ws_indices_global,
+                                                      p_indices_global);
     }
     else
     {
@@ -156,6 +161,11 @@ template <typename InDataType,
           index_t OutDstVectorSize>
 struct GridwiseReduction_mk_to_m_blockwise
 {
+    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
+                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
+                      (MThreadSliceSize % OutDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
     static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
 
     using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
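What the new static_assert guards, in a stand-alone sketch with made-up numbers: each thread moves its slice in InSrcVectorSize-wide chunks, so the slice length along the vectorized dimension must split into whole vector loads.

constexpr int MThreadSliceSize = 8; // elements a thread owns along M
constexpr int InSrcVectorSize  = 4; // elements per vectorized load

// 8 % 4 == 0: the slice is exactly 8 / 4 = 2 vector loads; a size of,
// say, 6 would leave a 2-element remainder and is now rejected at
// compile time instead of producing a broken kernel configuration.
static_assert(MThreadSliceSize % InSrcVectorSize == 0,
              "a thread slice must be a whole number of vector loads");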
@@ -174,8 +184,7 @@ struct GridwiseReduction_mk_to_m_blockwise
     static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
 
-    template <typename T>
-    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
+    using PassThroughOp = tensor_operation::element_wise::PassThrough;
 
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
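The replacement operator, in a hedged sketch of its shape (the real definition lives in element_wise_operation.hpp and carries host/device annotations): a single non-template struct whose call operator deduces the type, so instantiation sites no longer need to spell PassThroughOp<T>.

struct PassThrough
{
    template <typename T>
    void operator()(T& y, const T& x) const
    {
        y = x; // identity: forward the value unchanged
    }
};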
@@ -183,17 +192,24 @@ struct GridwiseReduction_mk_to_m_blockwise
     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
 
+    template <bool IsSecondCall>
     __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                                const OutGridDesc_M& out_grid_desc_m,
                                const InElementwiseOperation& in_elementwise_op,
                                const OutElementwiseOperation& acc_elementwise_op,
                                AccDataType alpha,
                                const InDataType* const __restrict__ p_in_global,
-                               OutDataType beta,
+                               AccDataType beta,
                                OutDataType* const __restrict__ p_out_global,
                                const IndexDataType* const __restrict__ p_ws_indices_global,
                                IndexDataType* const __restrict__ p_indices_global)
     {
+        if constexpr(IsSecondCall)
+        {
+            static_assert(InSrcVectorDim == 1,
+                          "InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
+        };
+
         using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
                                                               BlockSize,
                                                               ThreadClusterLengths_M_K,
@@ -345,7 +361,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                 priorDstValueBuf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-                accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
+                accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
             });
         };
     };
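Why the operand order in this change matters, illustrated stand-alone for bf16 (bf16_to_float is a hypothetical stand-in for the type_convert path; with bhalf_t carried as a ushort, converting first is the only form that scales the decoded value rather than the raw bit pattern):

#include <cstdint>
#include <cstring>
#include <iostream>

using bhalf_t = uint16_t; // bf16 payload carried in a ushort

float bf16_to_float(bhalf_t x)
{
    // widen the bf16 bit pattern into the top half of a float
    uint32_t bits = static_cast<uint32_t>(x) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main()
{
    bhalf_t prior = 0x3F80; // bf16 encoding of 1.0f
    float beta    = 0.5f;

    // old order: scales the integer bit pattern (16256) by beta, then
    // reinterprets the truncated result as bf16 -- nonsense
    float wrong = bf16_to_float(prior * beta);

    // new order: decode the bf16 value first, then scale in float
    float right = bf16_to_float(prior) * beta; // 0.5f

    std::cout << wrong << " vs " << right << "\n";
}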
@@ -355,7 +371,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                                                OutDataType,
                                                decltype(reduced_data_desc),
                                                OutGridDesc_M,
-                                               PassThroughOp<AccDataType>,
+                                               PassThroughOp,
                                                Sequence<MThreadSliceSize>,
                                                Sequence<0>,
                                                0,
@@ -366,7 +382,7 @@ struct GridwiseReduction_mk_to_m_blockwise
             out_grid_desc_m,
             make_multi_index(block_global_1d_id * M_BlockTileSize +
                              thread_m_cluster_id * MThreadSliceSize),
-            PassThroughOp<AccDataType>{});
+            PassThroughOp{});
 
         threadwise_dst_store.Run(
             reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
@@ -379,7 +395,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                                         const OutElementwiseOperation& acc_elementwise_op,
                                         AccDataType alpha,
                                         const InDataType* const __restrict__ p_in_global,
-                                        OutDataType beta,
+                                        AccDataType beta,
                                         OutDataType* const __restrict__ p_out_global,
                                         const IndexDataType* const __restrict__ p_ws_indices_global,
                                         IndexDataType* const __restrict__ p_indices_global)
@@ -570,7 +586,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                 priorDstValueBuf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-                accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
+                accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
             });
         };
     };
@@ -580,7 +596,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                                                OutDataType,
                                                decltype(reduced_data_desc),
                                                OutGridDesc_M,
-                                               PassThroughOp<AccDataType>,
+                                               PassThroughOp,
                                                Sequence<MThreadSliceSize>,
                                                Sequence<0>,
                                                0,
@@ -591,14 +607,14 @@ struct GridwiseReduction_mk_to_m_blockwise
             out_grid_desc_m,
             make_multi_index(block_global_1d_id * M_BlockTileSize +
                              thread_m_cluster_id * MThreadSliceSize),
-            PassThroughOp<AccDataType>{});
+            PassThroughOp{});
 
         auto threadwise_dst_idx_store =
             ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                IndexDataType,
                                                decltype(reduced_data_desc),
                                                OutGridDesc_M,
-                                               PassThroughOp<index_t>,
+                                               PassThroughOp,
                                                Sequence<MThreadSliceSize>,
                                                Sequence<0>,
                                                0,
@@ -609,7 +625,7 @@ struct GridwiseReduction_mk_to_m_blockwise
             out_grid_desc_m,
             make_multi_index(block_global_1d_id * M_BlockTileSize +
                              thread_m_cluster_id * MThreadSliceSize),
-            PassThroughOp<index_t>{});
+            PassThroughOp{});
 
         threadwise_dst_val_store.Run(reduced_data_desc,
                                      make_tuple(I0),
@@ -631,11 +647,14 @@ struct GridwiseReduction_mk_to_m_blockwise
                                              const OutElementwiseOperation acc_elementwise_op,
                                              AccDataType alpha,
                                              const InDataType* const __restrict__ p_ws_values_global,
-                                             OutDataType beta,
+                                             AccDataType beta,
                                              OutDataType* const __restrict__ p_out_global,
                                              const IndexDataType* const __restrict__ p_ws_indices_global,
                                              IndexDataType* const __restrict__ p_indices_global)
     {
+        static_assert(InSrcVectorDim == 1,
+                      "InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
+
         using BlockwiseReduceWithIndex =
             PartitionedBlockwiseReductionWithIndex<AccDataType,
                                                    IndexDataType,
@@ -841,7 +860,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                 priorDstValueBuf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-                accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
+                accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
             });
         };
     };
@@ -851,7 +870,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                                                OutDataType,
                                                decltype(reduced_data_desc),
                                                OutGridDesc_M,
-                                               PassThroughOp<AccDataType>,
+                                               PassThroughOp,
                                                Sequence<MThreadSliceSize>,
                                                Sequence<0>,
                                                0,
@@ -862,14 +881,14 @@ struct GridwiseReduction_mk_to_m_blockwise
             out_grid_desc_m,
             make_multi_index(block_global_1d_id * M_BlockTileSize +
                              thread_m_cluster_id * MThreadSliceSize),
-            PassThroughOp<AccDataType>{});
+            PassThroughOp{});
 
         auto threadwise_dst_idx_store =
             ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                IndexDataType,
                                                decltype(reduced_data_desc),
                                                OutGridDesc_M,
-                                               PassThroughOp<IndexDataType>,
+                                               PassThroughOp,
                                                Sequence<MThreadSliceSize>,
                                                Sequence<0>,
                                                0,
@@ -880,7 +899,7 @@ struct GridwiseReduction_mk_to_m_blockwise
             out_grid_desc_m,
             make_multi_index(block_global_1d_id * M_BlockTileSize +
                              thread_m_cluster_id * MThreadSliceSize),
-            PassThroughOp<index_t>{});
+            PassThroughOp{});
 
         threadwise_dst_val_store.Run(reduced_data_desc,
                                      make_tuple(I0),

@@ -32,6 +32,7 @@
 #include "reduction_functions_blockwise.hpp"
 
 #include "threadwise_tensor_slice_transfer.hpp"
+#include "element_wise_operation.hpp"
 
 namespace ck {
 
@@ -84,6 +85,11 @@ template <typename InDataType,
           index_t OutDstVectorSize>
 struct GridwiseReduction_mk_to_m_multiblock_atomic_add
 {
+    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
+                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
+                      (MThreadSliceSize % OutDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
     static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
 
     using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
@@ -109,8 +115,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
                                                               ReduceOperation,
                                                               PropagateNan>;
 
-    template <typename T>
-    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
+    using PassThroughOp = tensor_operation::element_wise::PassThrough;
 
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -249,7 +254,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
                                                OutDataType,
                                                decltype(reduced_data_desc),
                                                OutGridDesc_M,
-                                               PassThroughOp<AccDataType>,
+                                               PassThroughOp,
                                                Sequence<MThreadSliceSize>,
                                                Sequence<0>,
                                                0,
@@ -260,7 +265,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
             out_grid_desc_m,
             make_multi_index(blkgroup_id * M_BlockTileSize +
                              thread_m_cluster_id * MThreadSliceSize),
-            PassThroughOp<AccDataType>{});
+            PassThroughOp{});
 
         threadwise_dst_store.Run(
             reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);

@@ -23,8 +23,8 @@
  * SOFTWARE.
  *
  *******************************************************************************/
-#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_TWO_CALL_HPP
-#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_TWO_CALL_HPP
+#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP
+#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP
 
 #include "reduction_common.hpp"
 #include "reduction_operator.hpp"
@@ -32,6 +32,7 @@
 #include "reduction_functions_blockwise.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "cluster_descriptor.hpp"
+#include "element_wise_operation.hpp"
 
 namespace ck {
 
@@ -101,6 +102,12 @@ template <typename InDataType,
           index_t OutDstVectorSize>
 struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
 {
+    static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
+                      (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
+    static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
+
     static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
 
     using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
@@ -119,8 +126,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
     static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
 
-    template <typename T>
-    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
+    using PassThroughOp = tensor_operation::element_wise::PassThrough;
 
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -238,9 +244,6 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
             reducedTiles++;
         } while(reducedTiles < num_k_block_tile_iteration);
 
-        constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
-
         // Each block executes multiple parallel reductions on the LDS, and due to the using of
         // vector_load, each block/thread is involved into multiple invarirant dimensions.
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -254,6 +257,9 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
             BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
         });
 
+        constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
+
         if(thread_k_cluster_id == 0)
         {
             auto threadwise_workspace_store =
@@ -261,7 +267,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                                                    AccDataType,
                                                    decltype(reduced_data_desc),
                                                    WorkspaceDesc_M_K,
-                                                   PassThroughOp<AccDataType>,
+                                                   PassThroughOp,
                                                    Sequence<MThreadSliceSize, 1>,
                                                    Sequence<0, 1>,
                                                    1,
@@ -273,7 +279,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
             make_multi_index(blkgroup_id * M_BlockTileSize +
                                  thread_m_cluster_id * MThreadSliceSize,
                              block_local_id),
-            PassThroughOp<AccDataType>{});
+            PassThroughOp{});
 
         threadwise_workspace_store.Run(reduced_data_desc,
                                        make_tuple(I0, I0),
@@ -450,7 +456,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                                                    AccDataType,
                                                    decltype(reduced_data_desc),
                                                    WorkspaceDesc_M_K,
-                                                   PassThroughOp<AccDataType>,
+                                                   PassThroughOp,
                                                    Sequence<MThreadSliceSize, 1>,
                                                    Sequence<0, 1>,
                                                    1,
@@ -462,14 +468,14 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
             make_multi_index(blkgroup_id * M_BlockTileSize +
                                  thread_m_cluster_id * MThreadSliceSize,
                              block_local_id),
-            PassThroughOp<AccDataType>{});
+            PassThroughOp{});
 
         auto threadwise_workspace_idx_store =
             ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                IndexDataType,
                                                decltype(reduced_data_desc),
                                                WorkspaceDesc_M_K,
-                                               PassThroughOp<IndexDataType>,
+                                               PassThroughOp,
                                                Sequence<MThreadSliceSize, 1>,
                                                Sequence<0, 1>,
                                                1,
@@ -481,7 +487,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
             make_multi_index(blkgroup_id * M_BlockTileSize +
                                  thread_m_cluster_id * MThreadSliceSize,
                              block_local_id),
-            PassThroughOp<IndexDataType>{});
+            PassThroughOp{});
 
         threadwise_workspace_val_store.Run(reduced_data_desc,
                                            make_tuple(I0, I0),

@@ -31,6 +31,7 @@
 #include "reduction_operator.hpp"
 #include "reduction_functions_accumulate.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
+#include "element_wise_operation.hpp"
 
 namespace ck {
 
@@ -50,7 +51,7 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
                                          const AccElementwiseOperation acc_elementwise_op,
                                          AccDataType alpha,
                                          const InDataType* const __restrict__ p_in_global,
-                                         OutDataType beta,
+                                         AccDataType beta,
                                          OutDataType* const __restrict__ p_out_global,
                                          IndexDataType* const __restrict__ p_indices_global)
 {
@@ -101,11 +102,15 @@ template <typename InDataType,
           index_t OutDstVectorSize>
 struct GridwiseReduction_mk_to_m_threadwise
 {
+    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
+                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
+                      (MThreadSliceSize % OutDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
     using ThreadBufferDimAccessOrder =
         typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
 
-    template <typename T>
-    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
+    using PassThroughOp = tensor_operation::element_wise::PassThrough;
 
     static constexpr auto I0 = Number<0>{};
 
@@ -115,7 +120,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                                const AccElementwiseOperation& acc_elementwise_op,
                                AccDataType alpha,
                                const InDataType* const __restrict__ p_in_global,
-                               OutDataType beta,
+                               AccDataType beta,
                                OutDataType* const __restrict__ p_out_global,
                                IndexDataType* const __restrict__ p_indices_global)
     {
@@ -228,7 +233,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                 priorDstValue_buf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-                accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I] * beta);
+                accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
             });
         };
     };
@@ -238,7 +243,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                                                OutDataType,
                                                decltype(reduced_data_desc),
                                                OutGridDesc_M,
-                                               PassThroughOp<AccDataType>,
+                                               PassThroughOp,
                                                Sequence<MThreadSliceSize>,
                                                Sequence<0>,
                                                0,
@@ -248,7 +253,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                                                false>(
             out_grid_desc_m,
             make_multi_index(thread_global_1d_id * MThreadSliceSize),
-            PassThroughOp<AccDataType>{});
+            PassThroughOp{});
 
         threadwise_dst_store.Run(
             reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
@@ -260,7 +265,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                                const AccElementwiseOperation& acc_elementwise_op,
                                AccDataType alpha,
                                const InDataType* const __restrict__ p_in_global,
-                               OutDataType beta,
+                               AccDataType beta,
                                OutDataType* const __restrict__ p_out_global,
                                IndexDataType* const __restrict__ p_indices_global)
     {
@@ -387,7 +392,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                 priorDstValue_buf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-                accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I] * beta);
+                accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
             });
         };
     };
@@ -397,7 +402,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                                                OutDataType,
                                                decltype(reduced_data_desc),
                                                OutGridDesc_M,
-                                               PassThroughOp<AccDataType>,
+                                               PassThroughOp,
                                                Sequence<MThreadSliceSize>,
                                                Sequence<0>,
                                                0,
@@ -407,14 +412,14 @@ struct GridwiseReduction_mk_to_m_threadwise
                                                false>(
             out_grid_desc_m,
             make_multi_index(thread_global_1d_id * MThreadSliceSize),
-            PassThroughOp<AccDataType>{});
+            PassThroughOp{});
 
         auto threadwise_dst_idx_store =
             ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                IndexDataType,
                                                decltype(reduced_data_desc),
                                                OutGridDesc_M,
-                                               PassThroughOp<IndexDataType>,
+                                               PassThroughOp,
                                                Sequence<MThreadSliceSize>,
                                                Sequence<0>,
                                                0,
@@ -424,7 +429,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                                                false>(
             out_grid_desc_m,
             make_multi_index(thread_global_1d_id * MThreadSliceSize),
-            PassThroughOp<index_t>{});
+            PassThroughOp{});
 
         threadwise_dst_val_store.Run(
             reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf);