Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-14 02:02:46 +00:00
Reduction in Composable Kernel (#82)
* Initial adding of generic reduction
* Initial adding of generic reduction ...
* Updates to make compiling done
* clang-format all files
* clang-format some files again
* Renaming in profiler/include/profile_reduce.hpp
* Updates to make BlockWise cases pass
* Updates to make ThreadWise and MultiBlockTwoCall cases pass
* Remove the support for MUL and NORM1 reduceOp from the profiler and the device instances
* Change to replace the dim0_max_vector_size/dim1_max_vector_size template argument in the device reduce classes
* format
* adding pooling
* added max and average pooling
* comment out cout and kernel timing
* Tiny simplification in profiler/reduce_profiler.cpp
* Add example for reduce_blockwise
* Tiny updates
* Change to pass the ElementWiseOp from device layer to kernel
* Fix the vectorDim and vectorSize in Device layer
* Enable vector load on both dim0 and dim1 for Threadwise method
* Tiny updates
* Change to let the user pass the preUnaryOp and posUnaryOp
* Make pooling example work
* split device_reduce_instance into two libraries
* Tiny update
* Replace nanPropaOpt enum by boolean propagate_nan
* Simplification in DeviceReduce layer codes
* update build
* Change to clarify the difference between ck::half_t and half_float::half
* Renaming in all the reduction codes
* Add VectorSize as template parameter for device layer
* Add BetaIsZero as kernel template and as AccDataType for alpha
* print
* Small updates for pooling
* Updates for host_generic_reduction for reference
* Update to make AVG pooling pass
* Update to make MAX pooling with indices output pass
* fix
* add OutDst vector store to threadwise reduction and pooling
* tweak
* turn off check_indices that caused build issue
* refactor pooling
* clean up
* turn off check_indices for building issue for php-compiler
* add more tile size for odd C
* tweak conv for odd C
* update script
* clean up elementwise op
* add hack in reduction_operator.hpp to avoid compile error. To fix it, need to use element_wise_op in reduction op
* Add OutVectorSize as device and kernel tunable, also update to Elementwise Operations
* Move reduce operator mapping to host layer file reduction_operator_mapping.hpp from reduction_operator.hpp
* Change to the unary operators
* Move the definitions of unary operations to element_wise_operation.hpp
* re-org files
* Refine in device interfaces and multiblock kernels
* Split the reduction configurations into instances for specific methods
* Update in getTypeString() of device pool2d
* Renaming in host and kernel
* Tiny update in profiler/src/profiler.cpp
* Uncomment in device_operation/CMakeLists.txt to enable the building of all operations
* Make check_indices a templated function to remove a linking issue
* Renaming in the profiler reduce module
* Add support for double Reduction (but disable MultiblockAtomicAdd for double)
* Tiny correction of literal string
* Rename DevicePoolFwd to DevicePool2dFwd
* Split device_reduce_instance_xxx.cpp files according to the data types to speed up compiling
* Add comments for lists of configurations, lists of instances and references of add_reduce_instances_xxx
* Remove un-used header file gridwise_generic_reduction_wrapper_common.hpp
* Renaming and refining in the Reduction codes
* Tiny change in the unary operators
* Renaming symbols and files
* Renaming symbols in the kernels
* Move kernel kernel_set_buffer_value to separate file
* Add IndexDataType template parameter for kernels and use int32_t as index data type in device layer
* Tiny update in the kernels
* Remove definition of sqrtf()/isnan()/abs() for half_t due to some ADL issue
* Simplify a helper function in device layer
* Tiny adjustment in testing data initialization
* Renaming in kernel/device/host
* Add two testing scripts for reduction
* Refine the Unary operators in element_wise_operation.hpp
* Update in the reduce profiler module
* Update to the reduction testing scripts
* reduce compile parallelism
* change CI docker to rocm5.0
* remove unused variables
* fix build
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: e17c0d8008]
@@ -1,6 +1,6 @@
 FROM ubuntu:18.04
 
-ARG ROCMVERSION=4.3.1
+ARG ROCMVERSION=5.0
 ARG OSDB_BKC_VERSION
 
 RUN set -xe
@@ -175,6 +175,161 @@ struct RequantReluRequant
    float scaleRelu_;
};

// Unary operators are usually applied element-wise before/after the reduction is executed on the
// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2

template <typename Y, typename X, bool HasDividing = false>
struct UnaryIdentic;

template <>
struct UnaryIdentic<float, float, false>
{
    __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    __host__ __device__ void operator()(float& y, const float& x) const { y = x; };
};

template <>
struct UnaryIdentic<float, float, true>
{
    __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

    __host__ __device__ void operator()(float& y, const float& x) const
    {
        y = x / type_convert<float>(divider_);
    };

    int32_t divider_ = 1;
};

template <>
struct UnaryIdentic<half_t, half_t, false>
{
    __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; };
};

template <>
struct UnaryIdentic<double, double, false>
{
    __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    __host__ __device__ void operator()(double& y, const double& x) const { y = x; };
};

template <>
struct UnaryIdentic<double, double, true>
{
    __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

    __host__ __device__ void operator()(double& y, const double& x) const
    {
        y = x / type_convert<double>(divider_);
    };

    int32_t divider_ = 1;
};

template <>
struct UnaryIdentic<int32_t, int32_t, false>
{
    __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; };
};

template <typename Y, typename X, bool HasDividing = false>
struct UnarySquare;

template <>
struct UnarySquare<float, float, false>
{
    __host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; };

    __host__ __device__ void operator()(float& y, const float& x) const { y = x * x; };
};

template <>
struct UnarySquare<float, float, true>
{
    __host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; };

    __host__ __device__ void operator()(float& y, const float& x) const
    {
        y = x * x / type_convert<float>(divider_);
    };

    int32_t divider_ = 1;
};

template <>
struct UnarySquare<double, double, false>
{
    __host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; };

    __host__ __device__ void operator()(double& y, const double& x) const { y = x * x; };
};

template <>
struct UnarySquare<double, double, true>
{
    __host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; };

    __host__ __device__ void operator()(double& y, const double& x) const
    {
        y = x * x / type_convert<double>(divider_);
    };

    int32_t divider_ = 1;
};

template <typename Y, typename X>
struct UnaryAbs;

template <>
struct UnaryAbs<float, float>
{
    __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

    __host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); };
};

template <>
struct UnaryAbs<half_t, half_t>
{
    __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); };
};

template <>
struct UnaryAbs<double, double>
{
    __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

    __host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); };
};

template <typename Y, typename X>
struct UnarySqrt;

template <>
struct UnarySqrt<float, float>
{
    __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };

    __host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); };
};

template <>
struct UnarySqrt<double, double>
{
    __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };

    __host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); };
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
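The unary operators added above are the hooks the device layer composes to turn a plain summation into AVG, NRM1, or NRM2: a pre-reduction op (e.g. UnarySquare or UnaryAbs) is applied to each element before accumulation, and a post-reduction op (e.g. UnaryIdentic with HasDividing = true, or UnarySqrt) is applied to the accumulated value. The stand-alone host C++ sketch below only illustrates that composition on plain floats; the functor names UnaryIdenticDiv, UnarySquareOp, and UnarySqrtOp are simplified stand-ins for the specializations above and are not part of the commit.

#include <cstdint>
#include <cmath>
#include <cstdio>

// Stand-in for UnaryIdentic<float, float, true>: divide by the reduce length.
struct UnaryIdenticDiv
{
    UnaryIdenticDiv(int32_t divider = 1) : divider_(divider) {}
    void operator()(float& y, const float& x) const { y = x / static_cast<float>(divider_); }
    int32_t divider_ = 1;
};

// Stand-in for UnarySquare<float, float, false>: pre-reduction op for NRM2.
struct UnarySquareOp
{
    void operator()(float& y, const float& x) const { y = x * x; }
};

// Stand-in for UnarySqrt<float, float>: post-reduction op for NRM2.
struct UnarySqrtOp
{
    void operator()(float& y, const float& x) const { y = std::sqrt(x); }
};

int main()
{
    const float x[4]  = {1.0f, 2.0f, 3.0f, 4.0f};
    const int32_t k   = 4;

    float sum = 0.0f, sum_sq = 0.0f;
    for(int i = 0; i < k; ++i)
    {
        float sq;
        UnarySquareOp{}(sq, x[i]); // pre-reduction op (NRM2)
        sum += x[i];               // the reduction itself is a plain ADD
        sum_sq += sq;
    }

    float avg, nrm2;
    UnaryIdenticDiv{k}(avg, sum); // post-reduction op (AVG)
    UnarySqrtOp{}(nrm2, sum_sq);  // post-reduction op (NRM2)

    printf("avg = %f, nrm2 = %f\n", avg, nrm2); // avg = 2.5, nrm2 ~= 5.4772
    return 0;
}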
@@ -0,0 +1,925 @@
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2021 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP
#define CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP

#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"

#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseReduction,
          bool NeedIndices,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename InElementwiseOperation,
          typename OutElementwiseOperation>
__global__ void kernel_reduce_blockwise(const InGridDesc_M_K in_grid_desc_m_k,
                                        const OutGridDesc_M out_grid_desc_m,
                                        const InElementwiseOperation in_elementwise_op,
                                        const OutElementwiseOperation acc_elementwise_op,
                                        AccDataType alpha,
                                        const InDataType* const __restrict__ p_in_global,
                                        OutDataType beta,
                                        OutDataType* const __restrict__ p_out_global,
                                        const IndexDataType* const __restrict__ p_ws_indices_global,
                                        IndexDataType* const __restrict__ p_indices_global)
{
    if constexpr(!NeedIndices)
    {
        GridwiseReduction::Run(in_grid_desc_m_k,
                               out_grid_desc_m,
                               in_elementwise_op,
                               acc_elementwise_op,
                               alpha,
                               p_in_global,
                               beta,
                               p_out_global,
                               p_ws_indices_global,
                               p_indices_global);
    }
    else
    {
        GridwiseReduction::RunWithIndex(in_grid_desc_m_k,
                                        out_grid_desc_m,
                                        in_elementwise_op,
                                        acc_elementwise_op,
                                        alpha,
                                        p_in_global,
                                        beta,
                                        p_out_global,
                                        p_ws_indices_global,
                                        p_indices_global);
    };
};

template <typename GridwiseReduction,
          bool NeedIndices,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename InElementwiseOperation,
          typename OutElementwiseOperation>
__global__ void
kernel_reduce_blockwise_second_call(const InGridDesc_M_K in_grid_desc_m_k,
                                    const OutGridDesc_M out_grid_desc_m,
                                    const InElementwiseOperation in_elementwise_op,
                                    const OutElementwiseOperation acc_elementwise_op,
                                    AccDataType alpha,
                                    const InDataType* const __restrict__ p_in_global,
                                    OutDataType beta,
                                    OutDataType* const __restrict__ p_out_global,
                                    const IndexDataType* const __restrict__ p_ws_indices_global,
                                    IndexDataType* const __restrict__ p_indices_global)
{
    if constexpr(!NeedIndices)
    {
        GridwiseReduction::Run(in_grid_desc_m_k,
                               out_grid_desc_m,
                               in_elementwise_op,
                               acc_elementwise_op,
                               alpha,
                               p_in_global,
                               beta,
                               p_out_global,
                               p_ws_indices_global,
                               p_indices_global);
    }
    else
    {
        GridwiseReduction::RunSecondCallWithIndex(in_grid_desc_m_k,
                                                  out_grid_desc_m,
                                                  in_elementwise_op,
                                                  acc_elementwise_op,
                                                  alpha,
                                                  p_in_global,
                                                  beta,
                                                  p_out_global,
                                                  p_ws_indices_global,
                                                  p_indices_global);
    };
};

template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename OutElementwiseOperation,
          bool PropagateNan,
          bool BetaIsZero,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_blockwise
{
    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    static constexpr auto buffer_1d_desc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));

    template <typename T>
    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;

    static constexpr auto I0 = Number<0>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const OutGridDesc_M& out_grid_desc_m,
                               const InElementwiseOperation& in_elementwise_op,
                               const OutElementwiseOperation& acc_elementwise_op,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_global,
                               OutDataType beta,
                               OutDataType* const __restrict__ p_out_global,
                               const IndexDataType* const __restrict__ p_ws_indices_global,
                               IndexDataType* const __restrict__ p_indices_global)
    {
        using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc),
                                                                        AccDataType,
                                                                        BlockSize,
                                                                        MThreadClusterSize,
                                                                        KThreadClusterSize,
                                                                        reorder_thread_cluster,
                                                                        ReduceOperation,
                                                                        PropagateNan>;
        using Accumulation =
            detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

        (void)p_ws_indices_global;
        (void)p_indices_global;

        // LDS
        __shared__ AccDataType p_block_reduce_buffer[BlockSize];

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());

        auto block_reduce_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                                   : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                                   : thread_local_id % KThreadClusterSize;

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k,
                   make_multi_index(block_global_1d_id * M_BlockTileSize +
                                        thread_m_cluster_id * MThreadSliceSize,
                                    thread_k_cluster_id * KThreadSliceSize));

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

        const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;

        index_t reducedTiles = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                // do element-wise pre-reduction operation
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
                    in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
                });

                // reduce on each thread-local slice
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
                    Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
                });
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedTiles++;
        } while(reducedTiles < toReduceTiles);

        constexpr auto reduced_data_desc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(reorder_thread_cluster)
            {
                block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
                    accu_value_buf[I];
            }
            else
                block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
                    accu_value_buf[I];

            accu_value_buf(I) = zeroVal;

            __syncthreads();

            BlockwiseReduce::Reduce(
                block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
        });

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

                accu_value_buf(I) *= alpha;
            }
        });

        if(thread_k_cluster_id == 0)
        {
            if constexpr(!BetaIsZero)
            {
                if(!float_equal_zero{}(beta))
                {
                    StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
                        priorDstValueBuf;

                    auto threadwise_dst_load =
                        ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                         OutDataType,
                                                         OutGridDesc_M,
                                                         decltype(reduced_data_desc),
                                                         Sequence<MThreadSliceSize>,
                                                         Sequence<0>,
                                                         0,
                                                         OutDstVectorSize,
                                                         1,
                                                         false>(
                            out_grid_desc_m,
                            make_multi_index(block_global_1d_id * M_BlockTileSize +
                                             thread_m_cluster_id * MThreadSliceSize));

                    threadwise_dst_load.Run(out_grid_desc_m,
                                            out_global_buf,
                                            reduced_data_desc,
                                            make_tuple(I0),
                                            priorDstValueBuf);

                    static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                        accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
                    });
                };
            };

            auto threadwise_dst_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp<AccDataType>,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>(
                    out_grid_desc_m,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp<AccDataType>{});

            threadwise_dst_store.Run(
                reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
        }
    };

    __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
                                        const OutGridDesc_M& out_grid_desc_m,
                                        const InElementwiseOperation& in_elementwise_op,
                                        const OutElementwiseOperation& acc_elementwise_op,
                                        AccDataType alpha,
                                        const InDataType* const __restrict__ p_in_global,
                                        OutDataType beta,
                                        OutDataType* const __restrict__ p_out_global,
                                        const IndexDataType* const __restrict__ p_ws_indices_global,
                                        IndexDataType* const __restrict__ p_indices_global)
    {
        using BlockwiseReduceWithIndex =
            PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc),
                                                             AccDataType,
                                                             IndexDataType,
                                                             BlockSize,
                                                             MThreadClusterSize,
                                                             KThreadClusterSize,
                                                             reorder_thread_cluster,
                                                             ReduceOperation,
                                                             PropagateNan>;

        using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
                                                                             ReduceOperation,
                                                                             AccDataType,
                                                                             IndexDataType>;

        (void)p_ws_indices_global;

        // LDS
        __shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
        __shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize];

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());
        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_indices_global, out_grid_desc_m.GetElementSpaceSize());

        auto block_reduce_val_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
        auto block_reduce_idx_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_val_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, index_t, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_idx_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
        StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
            accu_index_buf;

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                                   : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                                   : thread_local_id % KThreadClusterSize;

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k,
                   make_multi_index(block_global_1d_id * M_BlockTileSize +
                                        thread_m_cluster_id * MThreadSliceSize,
                                    thread_k_cluster_id * KThreadSliceSize));

        index_t indexOffset = 0;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = zeroVal;
            accu_index_buf(I) = 0;
        });

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

        const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;

        index_t reducedTiles = 0;
        do
        {
            // load the thread slice
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_val_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;

                    // initialize the indices for the per-thread to-reduce values
                    in_thread_idx_buf(offset) =
                        indexOffset + thread_k_cluster_id * KThreadSliceSize + J();

                    // do element-wise pre-reduction operation
                    in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset));
                });

                AccDataType tmpValue = zeroVal;
                IndexDataType tmpIndex = 0;

                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;

                    // reduce on the dim1 thread slice
                    AccumulationWithIndex::Calculate(
                        tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
                });

                // store thread local value to LDS for parallel reduction
                if constexpr(reorder_thread_cluster)
                {
                    block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
                                         thread_m_cluster_id) = tmpValue;
                    block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
                                         thread_m_cluster_id) = tmpIndex;
                }
                else
                {
                    block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
                                         thread_k_cluster_id) = tmpValue;
                    block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
                                         thread_k_cluster_id) = tmpIndex;
                }

                __syncthreads();

                BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
                                                 block_reduce_idx_buf,
                                                 tmpValue,
                                                 tmpIndex,
                                                 thread_m_cluster_id,
                                                 thread_k_cluster_id);

                AccumulationWithIndex::Calculate(
                    accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            indexOffset += K_BlockTileSize;
            reducedTiles++;
        } while(reducedTiles < toReduceTiles);

        constexpr auto reduced_data_desc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                // for indexed operations, acc_elementwise_op should do nothing
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

                accu_value_buf(I) *= alpha;
            }
        });

        if(thread_k_cluster_id == 0)
        {
            if constexpr(!BetaIsZero)
            {
                if(!float_equal_zero{}(beta))
                {
                    StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
                        priorDstValueBuf;

                    auto threadwise_dst_load =
                        ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                         OutDataType,
                                                         OutGridDesc_M,
                                                         decltype(reduced_data_desc),
                                                         Sequence<MThreadSliceSize>,
                                                         Sequence<0>,
                                                         0,
                                                         OutDstVectorSize,
                                                         1,
                                                         false>(
                            out_grid_desc_m,
                            make_multi_index(block_global_1d_id * M_BlockTileSize +
                                             thread_m_cluster_id * MThreadSliceSize));

                    threadwise_dst_load.Run(out_grid_desc_m,
                                            out_global_val_buf,
                                            reduced_data_desc,
                                            make_tuple(I0),
                                            priorDstValueBuf);

                    static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                        accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
                    });
                };
            };

            auto threadwise_dst_val_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp<AccDataType>,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   false>(
                    out_grid_desc_m,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp<AccDataType>{});

            auto threadwise_dst_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                   IndexDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp<index_t>,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   false>(
                    out_grid_desc_m,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp<index_t>{});

            threadwise_dst_val_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_value_buf,
                                         out_grid_desc_m,
                                         out_global_val_buf);
            threadwise_dst_idx_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_index_buf,
                                         out_grid_desc_m,
                                         out_global_idx_buf);
        }
    };

    __device__ static void
    RunSecondCallWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
                           const OutGridDesc_M& out_grid_desc_m,
                           const InElementwiseOperation in_elementwise_op,
                           const OutElementwiseOperation acc_elementwise_op,
                           AccDataType alpha,
                           const InDataType* const __restrict__ p_ws_values_global,
                           OutDataType beta,
                           OutDataType* const __restrict__ p_out_global,
                           const IndexDataType* const __restrict__ p_ws_indices_global,
                           IndexDataType* const __restrict__ p_indices_global)
    {
        using BlockwiseReduceWithIndex =
            PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer_1d_desc),
                                                             AccDataType,
                                                             IndexDataType,
                                                             BlockSize,
                                                             MThreadClusterSize,
                                                             KThreadClusterSize,
                                                             reorder_thread_cluster,
                                                             ReduceOperation,
                                                             PropagateNan>;

        using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
                                                                             ReduceOperation,
                                                                             AccDataType,
                                                                             IndexDataType>;

        (void)in_elementwise_op;

        // LDS
        __shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
        __shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize];

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        const auto src_global_val_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_ws_values_global,
                                                            in_grid_desc_m_k.GetElementSpaceSize(),
                                                            type_convert<InDataType>(zeroVal));
        const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_ws_indices_global, in_grid_desc_m_k.GetElementSpaceSize());
        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());
        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_indices_global, out_grid_desc_m.GetElementSpaceSize());

        auto block_reduce_val_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
        auto block_reduce_idx_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_val_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     IndexDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_idx_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
        StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
            accu_index_buf;

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                                   : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                                   : thread_local_id % KThreadClusterSize;

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k,
                   make_multi_index(block_global_1d_id * M_BlockTileSize +
                                        thread_m_cluster_id * MThreadSliceSize,
                                    thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2<
            IndexDataType,
            IndexDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(in_grid_desc_m_k,
                   make_multi_index(block_global_1d_id * M_BlockTileSize +
                                        thread_m_cluster_id * MThreadSliceSize,
                                    thread_k_cluster_id * KThreadSliceSize));

        // index_t indexOffset = 0;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = zeroVal;
            accu_index_buf(I) = 0;
        });

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

        const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;

        index_t reducedTiles = 0;
        do
        {
            // load the thread slice
            threadwise_src_val_load.Run(in_grid_desc_m_k,
                                        src_global_val_buf,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        in_thread_val_buf);
            threadwise_src_idx_load.Run(in_grid_desc_m_k,
                                        src_global_idx_buf,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        in_thread_idx_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                AccDataType tmpValue = zeroVal;
                IndexDataType tmpIndex = 0;

                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;

                    // reduce on the dim1 thread slice
                    AccumulationWithIndex::Calculate(
                        tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
                });

                // store thread local value to LDS for parallel reduction
                if constexpr(reorder_thread_cluster)
                {
                    block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
                                         thread_m_cluster_id) = tmpValue;
                    block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
                                         thread_m_cluster_id) = tmpIndex;
                }
                else
                {
                    block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
                                         thread_k_cluster_id) = tmpValue;
                    block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
                                         thread_k_cluster_id) = tmpIndex;
                }

                __syncthreads();

                BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
                                                 block_reduce_idx_buf,
                                                 tmpValue,
                                                 tmpIndex,
                                                 thread_m_cluster_id,
                                                 thread_k_cluster_id);

                AccumulationWithIndex::Calculate(
                    accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
            });

            threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
            threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            // indexOffset += K_BlockTileSize;
            reducedTiles++;
        } while(reducedTiles < toReduceTiles);

        constexpr auto reduced_data_desc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                // for indexed operations, acc_elementwise_op should do nothing
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

                accu_value_buf(I) *= alpha;
            }
        });

        if(thread_k_cluster_id == 0)
        {
            if constexpr(!BetaIsZero)
            {
                if(!float_equal_zero{}(beta))
                {
                    StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
                        priorDstValueBuf;

                    auto threadwise_dst_load =
                        ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                         OutDataType,
                                                         OutGridDesc_M,
                                                         decltype(reduced_data_desc),
                                                         Sequence<MThreadSliceSize>,
                                                         Sequence<0>,
                                                         0,
                                                         OutDstVectorSize,
                                                         1,
                                                         true>(
                            out_grid_desc_m,
                            make_multi_index(block_global_1d_id * M_BlockTileSize +
                                             thread_m_cluster_id * MThreadSliceSize));

                    threadwise_dst_load.Run(out_grid_desc_m,
                                            out_global_val_buf,
                                            reduced_data_desc,
                                            make_tuple(I0),
                                            priorDstValueBuf);

                    static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                        accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
                    });
                };
            };

            auto threadwise_dst_val_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp<AccDataType>,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>(
                    out_grid_desc_m,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp<AccDataType>{});

            auto threadwise_dst_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                   IndexDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp<IndexDataType>,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>(
                    out_grid_desc_m,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp<index_t>{});

            threadwise_dst_val_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_value_buf,
                                         out_grid_desc_m,
                                         out_global_val_buf);
            threadwise_dst_idx_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_index_buf,
                                         out_grid_desc_m,
                                         out_global_idx_buf);
        }
    };
};

} // namespace ck
#endif
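For orientation (not part of the diff): in kernel_reduce_blockwise each workgroup owns one tile of M_BlockTileSize invariant rows and sweeps the K (to-reduce) axis in K_BlockTileSize steps, so a host launcher needs roughly ceil(M / M_BlockTileSize) workgroups of BlockSize threads, and each workgroup iterates toReduceTiles = ceil(K / K_BlockTileSize) times; when the K axis is instead split across several workgroups, kernel_reduce_blockwise_second_call finishes the reduction from a workspace. The sketch below only performs that launch arithmetic; the tile parameters are assumed values for illustration, not numbers taken from the device instances.

#include <cstdint>
#include <cstdio>

// Hypothetical tile parameters mirroring one plausible kernel instance.
constexpr int64_t BlockSize          = 256;
constexpr int64_t MThreadClusterSize = 128;
constexpr int64_t MThreadSliceSize   = 4;
constexpr int64_t KThreadClusterSize = 2; // BlockSize == M cluster * K cluster
constexpr int64_t KThreadSliceSize   = 8;

constexpr int64_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; // 512
constexpr int64_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; // 16

int main()
{
    // Problem shape after the device layer collapses the tensor to 2D (M, K).
    const int64_t M = 10000; // invariant (kept) length
    const int64_t K = 4096;  // to-reduce length

    // One workgroup per M tile; each workgroup sweeps K in K_BlockTileSize steps,
    // matching the kernel's toReduceTiles loop.
    const int64_t grid_size     = (M + M_BlockTileSize - 1) / M_BlockTileSize;
    const int64_t toReduceTiles = (K + K_BlockTileSize - 1) / K_BlockTileSize;

    printf("grid_size = %lld workgroups of %lld threads, %lld K tiles per workgroup\n",
           (long long)grid_size, (long long)BlockSize, (long long)toReduceTiles);
    return 0;
}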
@@ -0,0 +1,268 @@
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2020 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP

#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"

#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseReduction,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename InElementwiseOperation,
          typename AccElementwiseOperation>
__global__ void
kernel_reduce_multiblock_atocmi_add(const InGridDesc_M_K in_grid_desc_m_k,
                                    const OutGridDesc_M out_grid_desc_m,
                                    const InElementwiseOperation in_elementwise_op,
                                    const AccElementwiseOperation acc_elementwise_op,
                                    index_t block_group_size,
                                    index_t num_k_block_tile_iteration,
                                    AccDataType alpha,
                                    const InDataType* const __restrict__ p_in_global,
                                    OutDataType* const __restrict__ p_out_global)
{
    GridwiseReduction::Run(in_grid_desc_m_k,
                           out_grid_desc_m,
                           in_elementwise_op,
                           acc_elementwise_op,
                           block_group_size,
                           num_k_block_tile_iteration,
                           alpha,
                           p_in_global,
                           p_out_global);
};

template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{
    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    static constexpr auto buffer_1d_desc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));

    using blockwise_reduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer_1d_desc),
                                                                     AccDataType,
                                                                     BlockSize,
                                                                     MThreadClusterSize,
                                                                     KThreadClusterSize,
                                                                     reorder_thread_cluster,
                                                                     ReduceOperation,
                                                                     PropagateNan>;

    template <typename T>
    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;

    static constexpr auto I0 = Number<0>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const OutGridDesc_M& out_grid_desc_m,
                               const InElementwiseOperation& in_elementwise_op,
                               const AccElementwiseOperation& acc_elementwise_op,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_global,
                               OutDataType* const __restrict__ p_out_global)
    {
        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        // LDS
        __shared__ AccDataType p_block_reduce_buffer[BlockSize];

        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());

        auto block_reduce_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id = block_global_id / block_group_size;
        const index_t block_local_id = block_global_id % block_group_size;
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                                   : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                                   : thread_local_id % KThreadClusterSize;

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType,
            AccDataType,
            InGridDesc_M_K,
            decltype(thread_buffer_desc),
            ThreadBufferLengths,
            typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
            InSrcVectorDim,
            InSrcVectorSize,
            1,
            false>(
            in_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

        index_t reducedTiles = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                // do element-wise pre-reduction operation
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
                    in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
                });

                // reduce on each thread-local slice
                static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
                    constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
                    Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
                });
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

        constexpr auto reduced_data_desc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

        // Each block executes multiple parallel reductions on the LDS and atomic-adds its
        // reduced output to the global location corresponding to each invariant dimension,
        // so that a consistent reduced result is obtained for that invariant dimension.
        // Due to the use of vector_load, each block/thread is involved in multiple
        // invariant dimensions.
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(reorder_thread_cluster)
            {
                block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
                    accu_value_buf[I];
            }
            else
                block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
                    accu_value_buf[I];

            accu_value_buf(I) = zeroVal;

            __syncthreads();

            blockwise_reduce::Reduce(
                block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
        });

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

                accu_value_buf(I) *= alpha;
            }
        });

        if(thread_k_cluster_id == 0)
        {
            auto threadwise_dst_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp<AccDataType>,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum_t::AtomicAdd,
                                                   1,
                                                   true>(
                    out_grid_desc_m,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp<AccDataType>{});

            threadwise_dst_store.Run(
                reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
        }
    };
};

} // namespace ck
#endif
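A note on how this atomic-add variant is meant to be driven (an assumption based on the kernel arguments, not code from the diff): block_group_size workgroups cooperate on one M tile, each consuming reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration elements of the K axis and atomic-adding its alpha-scaled partial sum into the output. The output therefore has to be zero-initialized before launch, and the scheme only works when partial results combine by addition (e.g. ADD/AVG); the commit list above notes it is disabled for double. Order-sensitive or index-producing reductions go through the two-call path instead: kernel_partial_reduce_multiblock (next hunk) writes per-block partial values and indices to a workspace, and kernel_reduce_blockwise_second_call reduces that workspace. The sketch below only shows the grid-sizing arithmetic under assumed tile parameters.

#include <cstdint>
#include <cstdio>

// Hypothetical host-side helper: split the K axis of an (M, K) reduction
// across block_group_size cooperating workgroups for the atomic-add kernel.
int main()
{
    const int64_t M = 4096;
    const int64_t K = 1 << 20;            // long reduction axis
    const int64_t M_BlockTileSize = 512;  // assumed instance parameters
    const int64_t K_BlockTileSize = 16;
    const int64_t kTilesPerBlock  = 32;   // num_k_block_tile_iteration

    // Each workgroup consumes kTilesPerBlock * K_BlockTileSize elements along K.
    const int64_t reduceSizePerBlock = kTilesPerBlock * K_BlockTileSize;
    const int64_t block_group_size   = (K + reduceSizePerBlock - 1) / reduceSizePerBlock;

    // Total grid: block_group_size workgroups per M tile.
    const int64_t m_tiles   = (M + M_BlockTileSize - 1) / M_BlockTileSize;
    const int64_t grid_size = m_tiles * block_group_size;

    // The output buffer must be set to zero before launch, because every
    // workgroup in a group atomic-adds its partial sum into the same
    // out_grid_desc_m locations.
    printf("block_group_size = %lld, grid_size = %lld\n",
           (long long)block_group_size, (long long)grid_size);
    return 0;
}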
@@ -0,0 +1,514 @@
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2020 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_TWO_CALL_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_TWO_CALL_HPP

#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"

#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseReduction,
          bool NeedIndices,
          typename InDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename WorkspaceDesc_M_K,
          typename InElementwiseOperation,
          typename AccElementwiseOperation>
__global__ void
kernel_partial_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
                                 const WorkspaceDesc_M_K workspace_desc_m_k,
                                 const InElementwiseOperation in_elementwise_op,
                                 const AccElementwiseOperation acc_elementwise_op,
                                 index_t block_group_size,
                                 index_t num_k_block_tile_iteration,
                                 const InDataType* const __restrict__ p_src_global,
                                 AccDataType* const __restrict__ p_ws_values_global,
                                 IndexDataType* const __restrict__ p_ws_indices_global)

{
    if constexpr(!NeedIndices)
    {
        GridwiseReduction::Run(in_grid_desc_m_k,
                               workspace_desc_m_k,
                               in_elementwise_op,
                               acc_elementwise_op,
                               block_group_size,
                               num_k_block_tile_iteration,
                               p_src_global,
                               p_ws_values_global,
                               p_ws_indices_global);
    }
    else
    {
        GridwiseReduction::RunWithIndex(in_grid_desc_m_k,
                                        workspace_desc_m_k,
                                        in_elementwise_op,
                                        acc_elementwise_op,
                                        block_group_size,
                                        num_k_block_tile_iteration,
                                        p_src_global,
                                        p_ws_values_global,
                                        p_ws_indices_global);
    };
};

template <typename InDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename WorkspaceDesc_M_K,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
{
    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    static constexpr auto buffer1dDesc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));

    template <typename T>
    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;

    static constexpr auto I0 = Number<0>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const WorkspaceDesc_M_K& workspace_desc_m_k,
                               const InElementwiseOperation& in_elementwise_op,
                               const AccElementwiseOperation& acc_elementwise_op,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               const InDataType* const __restrict__ p_src_global,
                               AccDataType* const __restrict__ p_ws_values_global,
                               IndexDataType* const __restrict__ p_ws_indices_global)
    {
        using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<decltype(buffer1dDesc),
                                                                        AccDataType,
                                                                        BlockSize,
                                                                        MThreadClusterSize,
                                                                        KThreadClusterSize,
                                                                        reorder_thread_cluster,
                                                                        ReduceOperation,
                                                                        PropagateNan>;

        using Accumulation =
            detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

        (void)p_ws_indices_global;
        (void)acc_elementwise_op;

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        // LDS
        __shared__ AccDataType p_block_reduce_buffer[BlockSize];

        const auto in_global_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_src_global,
                                                            in_grid_desc_m_k.GetElementSpaceSize(),
                                                            type_convert<InDataType>(zeroVal));
        auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());

        auto block_reduce_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_buf;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id = block_global_id / block_group_size;
        const index_t block_local_id = block_global_id % block_group_size;
        const index_t thread_m_cluster_id =
            reorder_thread_cluster ? thread_local_id % MThreadClusterSize
                                   : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
        const index_t thread_k_cluster_id =
            reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
                                   : thread_local_id % KThreadClusterSize;

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
|
||||
|
||||
index_t reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
|
||||
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
|
||||
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
|
||||
});
|
||||
|
||||
// reduce on each thread-local slice
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
|
||||
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
|
||||
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
|
||||
});
|
||||
});
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
|
||||
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
|
||||
|
||||
// Each block executes multiple parallel reductions on the LDS, and due to the using of
|
||||
// vector_load, each block/thread is involved into multiple invarirant dimensions.
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if constexpr(reorder_thread_cluster)
|
||||
{
|
||||
block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
|
||||
accu_value_buf[I];
|
||||
}
|
||||
else
|
||||
block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
|
||||
accu_value_buf[I];
|
||||
|
||||
accu_value_buf(I) = zeroVal;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
BlockwiseReduce::Reduce(
|
||||
block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
|
||||
});
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
auto threadwise_workspace_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
AccDataType,
|
||||
decltype(reduced_data_desc),
|
||||
WorkspaceDesc_M_K,
|
||||
PassThroughOp<AccDataType>,
|
||||
Sequence<MThreadSliceSize, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
true>(
|
||||
workspace_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id),
|
||||
PassThroughOp<AccDataType>{});
|
||||
|
||||
threadwise_workspace_store.Run(reduced_data_desc,
|
||||
make_tuple(I0, I0),
|
||||
accu_value_buf,
|
||||
workspace_desc_m_k,
|
||||
workspace_global_buf);
|
||||
}
|
||||
};
|
||||
|
||||
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
|
||||
const WorkspaceDesc_M_K& workspace_desc_m_k,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op,
|
||||
index_t block_group_size,
|
||||
index_t num_k_block_tile_iteration,
|
||||
const InDataType* const __restrict__ p_src_global,
|
||||
AccDataType* const __restrict__ p_ws_values_global,
|
||||
IndexDataType* const __restrict__ p_ws_indices_global)
|
||||
{
|
||||
using BlockwiseReduceWithIndex =
|
||||
PartitionedBlockwiseReductionWithIndexOn1dBuffer<decltype(buffer1dDesc),
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
reorder_thread_cluster,
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
|
||||
ReduceOperation,
|
||||
AccDataType,
|
||||
IndexDataType>;
|
||||
|
||||
(void)acc_elementwise_op;
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
|
||||
__shared__ index_t p_block_reduce_idx_buffer[BlockSize];
|
||||
|
||||
const auto in_global_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_src_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
|
||||
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
auto block_reduce_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
|
||||
auto block_reduce_idx_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
AccDataType,
|
||||
MThreadSliceSize * KThreadSliceSize,
|
||||
true>
|
||||
in_thread_val_buf;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
IndexDataType,
|
||||
MThreadSliceSize * KThreadSliceSize,
|
||||
true>
|
||||
in_thread_idx_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
|
||||
accu_index_buf;
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
const index_t blkgroup_id = block_global_id / block_group_size;
|
||||
const index_t block_local_id = block_global_id % block_group_size;
|
||||
const index_t thread_m_cluster_id =
|
||||
reorder_thread_cluster ? thread_local_id % MThreadClusterSize
|
||||
: ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
|
||||
const index_t thread_k_cluster_id =
|
||||
reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
|
||||
: thread_local_id % KThreadClusterSize;
|
||||
|
||||
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
|
||||
|
||||
index_t indexOffset = block_local_id * reduceSizePerBlock;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) = zeroVal;
|
||||
accu_index_buf(I) = 0;
|
||||
});
|
||||
|
||||
index_t reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
// load the thread slice
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_val_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
|
||||
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
|
||||
|
||||
// initialize the indices for the per-thread to-reduce values
|
||||
in_thread_idx_buf(offset) =
|
||||
indexOffset + thread_k_cluster_id * KThreadSliceSize + J();
|
||||
|
||||
// do element-wise pre-reduction operation
|
||||
in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset));
|
||||
});
|
||||
|
||||
AccDataType tmpValue = zeroVal;
|
||||
IndexDataType tmpIndex = 0;
|
||||
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
|
||||
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
|
||||
|
||||
// reduce on the dim1 thread slice
|
||||
AccumulationWithIndex::Calculate(
|
||||
tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
|
||||
});
|
||||
|
||||
// store thread local value to LDS for parallel reduction
|
||||
if constexpr(reorder_thread_cluster)
|
||||
{
|
||||
block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize +
|
||||
thread_m_cluster_id) = tmpValue;
|
||||
block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize +
|
||||
thread_m_cluster_id) = tmpIndex;
|
||||
}
|
||||
else
|
||||
{
|
||||
block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize +
|
||||
thread_k_cluster_id) = tmpValue;
|
||||
block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize +
|
||||
thread_k_cluster_id) = tmpIndex;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
|
||||
block_reduce_idx_buf,
|
||||
tmpValue,
|
||||
tmpIndex,
|
||||
thread_m_cluster_id,
|
||||
thread_k_cluster_id);
|
||||
|
||||
AccumulationWithIndex::Calculate(
|
||||
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
|
||||
});
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
indexOffset += K_BlockTileSize;
|
||||
|
||||
reducedTiles++;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
|
||||
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
auto threadwise_workspace_val_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
AccDataType,
|
||||
decltype(reduced_data_desc),
|
||||
WorkspaceDesc_M_K,
|
||||
PassThroughOp<AccDataType>,
|
||||
Sequence<MThreadSliceSize, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
true>(
|
||||
workspace_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id),
|
||||
PassThroughOp<AccDataType>{});
|
||||
|
||||
auto threadwise_workspace_idx_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
|
||||
IndexDataType,
|
||||
decltype(reduced_data_desc),
|
||||
WorkspaceDesc_M_K,
|
||||
PassThroughOp<IndexDataType>,
|
||||
Sequence<MThreadSliceSize, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
true>(
|
||||
workspace_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id),
|
||||
PassThroughOp<IndexDataType>{});
|
||||
|
||||
threadwise_workspace_val_store.Run(reduced_data_desc,
|
||||
make_tuple(I0, I0),
|
||||
accu_value_buf,
|
||||
workspace_desc_m_k,
|
||||
workspace_global_val_buf);
|
||||
threadwise_workspace_idx_store.Run(reduced_data_desc,
|
||||
make_tuple(I0, I0),
|
||||
accu_index_buf,
|
||||
workspace_desc_m_k,
|
||||
workspace_global_idx_buf);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,435 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef CK_GRIDWISE_2D_REDUCTION_THREADWISE_HPP
|
||||
#define CK_GRIDWISE_2D_REDUCTION_THREADWISE_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseReduction,
|
||||
bool NeedIndices,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
typename InGridDesc_M_K,
|
||||
typename OutGridDesc_M,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation>
|
||||
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
|
||||
const OutGridDesc_M out_grid_desc_m,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const AccElementwiseOperation acc_elementwise_op,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_global,
|
||||
OutDataType beta,
|
||||
OutDataType* const __restrict__ p_out_global,
|
||||
IndexDataType* const __restrict__ p_indices_global)
|
||||
{
|
||||
if constexpr(!NeedIndices)
|
||||
{
|
||||
GridwiseReduction::Run(in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
alpha,
|
||||
p_in_global,
|
||||
beta,
|
||||
p_out_global,
|
||||
p_indices_global);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseReduction::RunWithIndices(in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op,
|
||||
alpha,
|
||||
p_in_global,
|
||||
beta,
|
||||
p_out_global,
|
||||
p_indices_global);
|
||||
};
|
||||
};
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
typename InGridDesc_M_K,
|
||||
typename OutGridDesc_M,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
bool BetaIsZero,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct GridwiseReduction_mk_to_m_threadwise
|
||||
{
|
||||
template <typename T>
|
||||
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
|
||||
const OutGridDesc_M& out_grid_desc_m,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_global,
|
||||
OutDataType beta,
|
||||
OutDataType* const __restrict__ p_out_global,
|
||||
IndexDataType* const __restrict__ p_indices_global)
|
||||
{
|
||||
|
||||
using Accumulation =
|
||||
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
|
||||
|
||||
(void)p_indices_global;
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
|
||||
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_out_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
AccDataType,
|
||||
MThreadSliceSize * KThreadSliceSize,
|
||||
true>
|
||||
in_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
|
||||
|
||||
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
|
||||
|
||||
index_t reducedLength = 0;
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
|
||||
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
|
||||
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
|
||||
});
|
||||
|
||||
// reduce on each thread-local slice
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
|
||||
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
|
||||
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
|
||||
});
|
||||
});
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
reducedLength += KThreadSliceSize;
|
||||
} while(reducedLength < toReduceLength);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
|
||||
|
||||
accu_value_buf(I) *= alpha;
|
||||
});
|
||||
|
||||
constexpr auto reduced_data_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
|
||||
|
||||
if constexpr(!BetaIsZero)
|
||||
{
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
auto threadwise_dst_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<OutDataType,
|
||||
OutDataType,
|
||||
OutGridDesc_M,
|
||||
decltype(reduced_data_desc),
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
true>(
|
||||
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
|
||||
priorDstValue_buf;
|
||||
|
||||
threadwise_dst_load.Run(out_grid_desc_m,
|
||||
dst_global_buf,
|
||||
reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
priorDstValue_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I] * beta);
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
auto threadwise_dst_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
OutDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThroughOp<AccDataType>,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
false>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(thread_global_1d_id * MThreadSliceSize),
|
||||
PassThroughOp<AccDataType>{});
|
||||
|
||||
threadwise_dst_store.Run(
|
||||
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
|
||||
};
|
||||
|
||||
__device__ static void RunWithIndices(const InGridDesc_M_K& in_grid_desc_m_k,
|
||||
const OutGridDesc_M& out_grid_desc_m,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op,
|
||||
AccDataType alpha,
|
||||
const InDataType* const __restrict__ p_in_global,
|
||||
OutDataType beta,
|
||||
OutDataType* const __restrict__ p_out_global,
|
||||
IndexDataType* const __restrict__ p_indices_global)
|
||||
{
|
||||
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
|
||||
ReduceOperation,
|
||||
AccDataType,
|
||||
IndexDataType>;
|
||||
(void)acc_elementwise_op;
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
|
||||
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_out_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
AccDataType,
|
||||
MThreadSliceSize * KThreadSliceSize,
|
||||
true>
|
||||
in_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
|
||||
accu_index_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) = zeroVal;
|
||||
accu_index_buf(I) = 0;
|
||||
});
|
||||
|
||||
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
|
||||
InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
|
||||
|
||||
index_t indexStart = 0;
|
||||
index_t reducedLength = 0;
|
||||
do
|
||||
{
|
||||
threadwise_src_load.Run(in_grid_desc_m_k,
|
||||
in_global_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
|
||||
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
|
||||
|
||||
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
|
||||
});
|
||||
|
||||
// reduce on each thread-local slice
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
|
||||
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
|
||||
AccumulationWithIndex::Calculate(accu_value_buf(I),
|
||||
in_thread_buf[offset],
|
||||
accu_index_buf(I),
|
||||
indexStart + J);
|
||||
});
|
||||
});
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
indexStart += KThreadSliceSize;
|
||||
reducedLength += KThreadSliceSize;
|
||||
} while(reducedLength < toReduceLength);
|
||||
|
||||
// for indiced operation, acc_elementwise_op shoud do nothing
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
|
||||
|
||||
accu_value_buf(I) *= alpha;
|
||||
});
|
||||
|
||||
constexpr auto reduced_data_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
|
||||
|
||||
if constexpr(!BetaIsZero)
|
||||
{
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
auto threadwise_dst_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<OutDataType,
|
||||
OutDataType,
|
||||
OutGridDesc_M,
|
||||
decltype(reduced_data_desc),
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
false>(
|
||||
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
|
||||
priorDstValue_buf;
|
||||
|
||||
threadwise_dst_load.Run(out_grid_desc_m,
|
||||
out_global_val_buf,
|
||||
reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
priorDstValue_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I] * beta);
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
auto threadwise_dst_val_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
OutDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThroughOp<AccDataType>,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
false>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(thread_global_1d_id * MThreadSliceSize),
|
||||
PassThroughOp<AccDataType>{});
|
||||
|
||||
auto threadwise_dst_idx_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
|
||||
IndexDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThroughOp<IndexDataType>,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
false>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(thread_global_1d_id * MThreadSliceSize),
|
||||
PassThroughOp<IndexDataType>{});
|
||||
|
||||
threadwise_dst_val_store.Run(
|
||||
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf);
|
||||
|
||||
threadwise_dst_idx_store.Run(
|
||||
reduced_data_desc, make_tuple(I0), accu_index_buf, out_grid_desc_m, out_global_idx_buf);
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,623 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_BLOCKWISE_HPP
|
||||
#define CK_GRIDWISE_GENERIC_2D_REDUCTION_BLOCKWISE_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_blockwise.hpp"
|
||||
|
||||
#include "blockwise_tensor_slice_transfer.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename srcDataType,
|
||||
typename dstDataType,
|
||||
typename compType,
|
||||
typename src2dDescType,
|
||||
typename dst1dDescType,
|
||||
ReduceTensorOp_t op,
|
||||
NanPropagation_t nanPropaOpt,
|
||||
ReduceTensorIndices_t reduceIndicesOpt,
|
||||
bool isFirstCall,
|
||||
bool isLastCall,
|
||||
index_t GredAccessesPerThreadInBlock>
|
||||
struct GridwiseReduction_xy_to_x_blockwise
|
||||
{
|
||||
using opReduce = typename reduce_binary_operator<compType, op>::opType;
|
||||
using preUnaryOpType =
|
||||
typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::preUnaryOp;
|
||||
using posUnaryOpType =
|
||||
typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::posUnaryOp;
|
||||
|
||||
static constexpr auto buffer2dDesc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<GredAccessesPerThreadInBlock>{}, Number<BlockSize>{}));
|
||||
using blockwise_reduce =
|
||||
BlockwiseReduction_2d_block_buffer<decltype(buffer2dDesc), true, opReduce, nanPropaOpt>;
|
||||
|
||||
static constexpr index_t BlockBufferSize = buffer2dDesc.GetElementSize();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
template <int RunId>
|
||||
__device__ static void Run(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ p_src_global,
|
||||
dstDataType beta,
|
||||
dstDataType* const __restrict__ p_dst_global,
|
||||
const int* const __restrict__ ws_indices_global,
|
||||
int* const __restrict__ indices_global);
|
||||
|
||||
template <>
|
||||
__device__ static void Run<1>(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ p_src_global,
|
||||
dstDataType beta,
|
||||
dstDataType* const __restrict__ p_dst_global,
|
||||
const int* const __restrict__ ws_indices_global,
|
||||
int* const __restrict__ indices_global)
|
||||
{
|
||||
(void)ws_indices_global;
|
||||
(void)indices_global;
|
||||
|
||||
// LDS
|
||||
__shared__ compType p_in_block_buffer[BlockBufferSize];
|
||||
|
||||
const auto zeroVal = opReduce::GetReductionZeroVal();
|
||||
|
||||
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
|
||||
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_dst_global, dst1dDesc.GetElementSpaceSize());
|
||||
|
||||
auto in_block_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_buffer, BlockBufferSize);
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
|
||||
|
||||
accuValue_buf(I0) = zeroVal;
|
||||
|
||||
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
|
||||
const int divider = origReduceLen;
|
||||
|
||||
const preUnaryOpType preUnaryOp(divider);
|
||||
const posUnaryOpType posUnaryOp(divider);
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_1d_id = get_block_1d_id();
|
||||
|
||||
constexpr auto in_block_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<BlockBufferSize>{}));
|
||||
|
||||
using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>;
|
||||
using ThreadClusterLengths = Sequence<1, BlockSize>;
|
||||
|
||||
auto blockwise_src_load =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<1, BlockBufferSize>,
|
||||
ThreadSliceLengths,
|
||||
ThreadClusterLengths,
|
||||
Sequence<0, 1>,
|
||||
srcDataType,
|
||||
compType,
|
||||
src2dDescType,
|
||||
decltype(in_block_desc),
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
true>(src2dDesc,
|
||||
make_multi_index(block_global_1d_id, 0),
|
||||
in_block_desc,
|
||||
make_multi_index(0, 0));
|
||||
|
||||
constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);
|
||||
|
||||
const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize;
|
||||
|
||||
for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
|
||||
reducedBlocks += GredAccessesPerThreadInBlock)
|
||||
{
|
||||
blockwise_src_load.RunRead(src2dDesc, src_global_buf);
|
||||
blockwise_src_load.RunWrite(in_block_desc, in_block_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// do element-wise pre-reduction operation
|
||||
blockwise_reduce::operate_on_elements(preUnaryOp, in_block_buf);
|
||||
|
||||
index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
|
||||
? GredAccessesPerThreadInBlock
|
||||
: toReduceBlocks - reducedBlocks;
|
||||
blockwise_reduce::Reduce(in_block_buf, BlocksInOneOp, accuValue_buf(I0));
|
||||
|
||||
blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
|
||||
}
|
||||
|
||||
accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]);
|
||||
|
||||
constexpr auto ReducedDataDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
|
||||
|
||||
// The first thread in the block stores the reduced result to the global location
|
||||
// representing the block
|
||||
if(thread_local_id == 0)
|
||||
{
|
||||
if(!float_equal_one{}(alpha))
|
||||
accuValue_buf(I0) *= type_convert<compType>(alpha);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
|
||||
|
||||
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
|
||||
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
auto threadwise_dst_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<dstDataType,
|
||||
dstDataType,
|
||||
dst1dDescType,
|
||||
decltype(ReducedDataDesc),
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
false>(dst1dDesc,
|
||||
make_multi_index(block_global_1d_id));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
|
||||
|
||||
threadwise_dst_load.Run(
|
||||
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
|
||||
|
||||
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
|
||||
}
|
||||
|
||||
auto threadwise_dst_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
|
||||
dstDataType,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
false>(dst1dDesc,
|
||||
make_multi_index(block_global_1d_id));
|
||||
|
||||
threadwise_dst_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
__device__ static void Run<2>(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ p_src_global,
|
||||
dstDataType beta,
|
||||
dstDataType* const __restrict__ p_dst_global,
|
||||
const int* const __restrict__ ws_indices_global,
|
||||
int* const __restrict__ indices_global)
|
||||
{
|
||||
(void)ws_indices_global;
|
||||
|
||||
// LDS
|
||||
__shared__ compType p_in_block_buffer[BlockBufferSize];
|
||||
__shared__ int block_indices_buffer[BlockBufferSize];
|
||||
|
||||
const auto zeroVal = opReduce::GetReductionZeroVal();
|
||||
|
||||
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
|
||||
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_dst_global, dst1dDesc.GetElementSpaceSize());
|
||||
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
indices_global, dst1dDesc.GetElementSpaceSize());
|
||||
|
||||
auto in_block_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_buffer, BlockBufferSize);
|
||||
auto in_block_idx_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(block_indices_buffer, BlockBufferSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;
|
||||
|
||||
accuValue_buf(I0) = zeroVal;
|
||||
accuIndex_buf(I0) = 0;
|
||||
|
||||
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
|
||||
const int divider = origReduceLen;
|
||||
|
||||
const preUnaryOpType preUnaryOp(divider);
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_1d_id = get_block_1d_id();
|
||||
|
||||
constexpr auto in_block_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<BlockBufferSize>{}));
|
||||
|
||||
using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>;
|
||||
using ThreadClusterLengths = Sequence<1, BlockSize>;
|
||||
|
||||
auto blockwise_src_load =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<1, BlockBufferSize>,
|
||||
ThreadSliceLengths,
|
||||
ThreadClusterLengths,
|
||||
Sequence<0, 1>,
|
||||
srcDataType,
|
||||
compType,
|
||||
src2dDescType,
|
||||
decltype(in_block_desc),
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
true>(src2dDesc,
|
||||
make_multi_index(block_global_1d_id, 0),
|
||||
in_block_desc,
|
||||
make_multi_index(0, 0));
|
||||
|
||||
constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);
|
||||
|
||||
const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize;
|
||||
|
||||
int indexOffset = 0;
|
||||
|
||||
for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
|
||||
reducedBlocks += GredAccessesPerThreadInBlock)
|
||||
{
|
||||
// load block data from global to LDS, no use of double buffers (to be improved)
|
||||
blockwise_src_load.RunRead(src2dDesc, src_global_buf);
|
||||
blockwise_src_load.RunWrite(in_block_desc, in_block_val_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// construct the indices for the current toReduce blocks
|
||||
blockwise_reduce::init_buffer_indices(in_block_idx_buf, indexOffset);
|
||||
|
||||
// unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually
|
||||
// done here
|
||||
blockwise_reduce::operate_on_elements(preUnaryOp, in_block_val_buf);
|
||||
|
||||
index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
|
||||
? GredAccessesPerThreadInBlock
|
||||
: toReduceBlocks - reducedBlocks;
|
||||
|
||||
blockwise_reduce::Reduce2(in_block_val_buf,
|
||||
in_block_idx_buf,
|
||||
BlocksInOneOp,
|
||||
accuValue_buf(I0),
|
||||
accuIndex_buf(I0));
|
||||
|
||||
indexOffset += BlockBufferSize;
|
||||
|
||||
blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
|
||||
}
|
||||
|
||||
constexpr auto ReducedDataDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
|
||||
|
||||
// The first thread in the block stores the reduced result to the global location
|
||||
// representing the block
|
||||
if(thread_local_id == 0)
|
||||
{
|
||||
if(!float_equal_one{}(alpha))
|
||||
accuValue_buf(I0) *= type_convert<compType>(alpha);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
|
||||
|
||||
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
|
||||
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
auto threadwise_dst_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<dstDataType,
|
||||
dstDataType,
|
||||
dst1dDescType,
|
||||
decltype(ReducedDataDesc),
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
false>(dst1dDesc,
|
||||
make_multi_index(block_global_1d_id));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
|
||||
|
||||
threadwise_dst_load.Run(dst1dDesc,
|
||||
dst_global_val_buf,
|
||||
ReducedDataDesc,
|
||||
make_tuple(I0),
|
||||
priorDstValue_buf);
|
||||
|
||||
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
|
||||
}
|
||||
|
||||
auto threadwise_dst_val_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
|
||||
dstDataType,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
false>(dst1dDesc,
|
||||
make_multi_index(block_global_1d_id));
|
||||
|
||||
auto threadwise_dst_idx_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<int,
|
||||
int,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
false>(dst1dDesc,
|
||||
make_multi_index(block_global_1d_id));
|
||||
|
||||
threadwise_dst_val_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
|
||||
threadwise_dst_idx_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
__device__ static void Run<3>(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ ws_values_global,
|
||||
dstDataType beta,
|
||||
dstDataType* const __restrict__ p_dst_global,
|
||||
const int* const __restrict__ ws_indices_global,
|
||||
int* const __restrict__ indices_global)
|
||||
{
|
||||
(void)origReduceLen;
|
||||
|
||||
// LDS
|
||||
__shared__ compType p_in_block_buffer[BlockBufferSize];
|
||||
__shared__ int block_indices_buffer[BlockBufferSize];
|
||||
|
||||
const auto zeroVal = opReduce::GetReductionZeroVal();
|
||||
|
||||
const auto src_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
ws_values_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
|
||||
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
ws_indices_global, src2dDesc.GetElementSpaceSize());
|
||||
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_dst_global, dst1dDesc.GetElementSpaceSize());
|
||||
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
indices_global, dst1dDesc.GetElementSpaceSize());
|
||||
|
||||
auto in_block_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_buffer, BlockBufferSize);
|
||||
auto in_block_idx_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(block_indices_buffer, BlockBufferSize);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;
|
||||
|
||||
accuValue_buf(I0) = zeroVal;
|
||||
accuIndex_buf(I0) = 0;
|
||||
|
||||
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_1d_id = get_block_1d_id();
|
||||
|
||||
constexpr auto in_block_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<BlockBufferSize>{}));
|
||||
|
||||
using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>;
|
||||
using ThreadClusterLengths = Sequence<1, BlockSize>;
|
||||
|
||||
auto blockwise_src_val_load =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<1, BlockBufferSize>,
|
||||
ThreadSliceLengths,
|
||||
ThreadClusterLengths,
|
||||
Sequence<0, 1>,
|
||||
srcDataType,
|
||||
compType,
|
||||
src2dDescType,
|
||||
decltype(in_block_desc),
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
true>(src2dDesc,
|
||||
make_multi_index(block_global_1d_id, 0),
|
||||
in_block_desc,
|
||||
make_multi_index(0, 0));
|
||||
|
||||
auto blockwise_src_idx_load =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<1, BlockBufferSize>,
|
||||
ThreadSliceLengths,
|
||||
ThreadClusterLengths,
|
||||
Sequence<0, 1>,
|
||||
int,
|
||||
int,
|
||||
src2dDescType,
|
||||
decltype(in_block_desc),
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
true>(src2dDesc,
|
||||
make_multi_index(block_global_1d_id, 0),
|
||||
in_block_desc,
|
||||
make_multi_index(0, 0));
|
||||
|
||||
constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);
|
||||
|
||||
const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize;
|
||||
|
||||
for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
|
||||
reducedBlocks += GredAccessesPerThreadInBlock)
|
||||
{
|
||||
// load block data from global to LDS, no use of double buffers (to be improved)
|
||||
blockwise_src_val_load.RunRead(src2dDesc, src_global_val_buf);
|
||||
blockwise_src_idx_load.RunRead(src2dDesc, src_global_idx_buf);
|
||||
blockwise_src_val_load.RunWrite(in_block_desc, in_block_val_buf);
|
||||
blockwise_src_idx_load.RunWrite(in_block_desc, in_block_idx_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
|
||||
? GredAccessesPerThreadInBlock
|
||||
: toReduceBlocks - reducedBlocks;
|
||||
|
||||
blockwise_reduce::Reduce2(in_block_val_buf,
|
||||
in_block_idx_buf,
|
||||
BlocksInOneOp,
|
||||
accuValue_buf(I0),
|
||||
accuIndex_buf(I0));
|
||||
|
||||
blockwise_src_val_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
|
||||
blockwise_src_idx_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
|
||||
}
|
||||
|
||||
constexpr auto ReducedDataDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
|
||||
|
||||
// The first thread in the block stores the reduced result to the global location
|
||||
// representing the block
|
||||
if(thread_local_id == 0)
|
||||
{
|
||||
if(!float_equal_one{}(alpha))
|
||||
accuValue_buf(I0) *= type_convert<compType>(alpha);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
|
||||
|
||||
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
|
||||
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
auto threadwise_dst_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<dstDataType,
|
||||
dstDataType,
|
||||
dst1dDescType,
|
||||
decltype(ReducedDataDesc),
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
true>(dst1dDesc,
|
||||
make_multi_index(block_global_1d_id));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
|
||||
|
||||
threadwise_dst_load.Run(dst1dDesc,
|
||||
dst_global_val_buf,
|
||||
ReducedDataDesc,
|
||||
make_tuple(I0),
|
||||
priorDstValue_buf);
|
||||
|
||||
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
|
||||
}
|
||||
|
||||
auto threadwise_dst_val_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
|
||||
dstDataType,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
true>(dst1dDesc,
|
||||
make_multi_index(block_global_1d_id));
|
||||
|
||||
auto threadwise_dst_idx_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<int,
|
||||
int,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
true>(dst1dDesc,
|
||||
make_multi_index(block_global_1d_id));
|
||||
|
||||
threadwise_dst_val_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
|
||||
threadwise_dst_idx_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,501 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_THREADWISE_HPP
|
||||
#define CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_THREADWISE_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_threadwise.hpp"
|
||||
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename srcDataType,
|
||||
typename dstDataType,
|
||||
typename compType,
|
||||
typename src2dDescType,
|
||||
typename dst1dDescType,
|
||||
ReduceTensorOp_t op,
|
||||
NanPropagation_t nanPropaOpt,
|
||||
ReduceTensorIndices_t reduceIndicesOpt,
|
||||
bool isFirstCall,
|
||||
bool isLastCall,
|
||||
index_t GredThreadBufferLength>
|
||||
struct GridwiseReduction_xy_to_x_direct_threadwise
|
||||
{
|
||||
using opReduce = typename reduce_binary_operator<compType, op>::opType;
|
||||
using preUnaryOpType =
|
||||
typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::preUnaryOp;
|
||||
using posUnaryOpType =
|
||||
typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::posUnaryOp;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
template <int RunId>
|
||||
__device__ static void Run(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ p_src_global,
|
||||
dstDataType beta,
|
||||
dstDataType* const __restrict__ p_dst_global,
|
||||
const int* const __restrict__ ws_indices_global,
|
||||
int* const __restrict__ indices_global);
|
||||
|
||||
template <>
|
||||
__device__ static void Run<1>(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ p_src_global,
|
||||
dstDataType beta,
|
||||
dstDataType* const __restrict__ p_dst_global,
|
||||
const int* const __restrict__ ws_indices_global,
|
||||
int* const __restrict__ indices_global)
|
||||
{
|
||||
(void)ws_indices_global;
|
||||
(void)indices_global;
|
||||
|
||||
const auto zeroVal = opReduce::GetReductionZeroVal();
|
||||
|
||||
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
|
||||
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_dst_global, dst1dDesc.GetElementSpaceSize());
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredThreadBufferLength, true>
|
||||
in_thread_buf;
|
||||
|
||||
using threadwise_reduce = ThreadReduce<decltype(in_thread_buf), opReduce, nanPropaOpt>;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
|
||||
|
||||
accuValue_buf(I0) = zeroVal;
|
||||
|
||||
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
|
||||
const int divider = origReduceLen;
|
||||
|
||||
const preUnaryOpType preUnaryOp(divider);
|
||||
const posUnaryOpType posUnaryOp(divider);
|
||||
|
||||
using ThreadBufferLengths = Sequence<1, GredThreadBufferLength>;
|
||||
constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<1>{}, Number<GredThreadBufferLength>{}));
|
||||
|
||||
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
|
||||
compType,
|
||||
src2dDescType,
|
||||
decltype(ThreadBufferDesc),
|
||||
ThreadBufferLengths,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false>(
|
||||
src2dDesc, make_multi_index(thread_global_1d_id, 0));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, GredThreadBufferLength);
|
||||
|
||||
for(index_t reducedLength = 0; reducedLength < toReduceLength;
|
||||
reducedLength += GredThreadBufferLength)
|
||||
{
|
||||
threadwise_src_load.Run(
|
||||
src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf);
|
||||
|
||||
// do element-wise pre-reduction operation
|
||||
threadwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf);
|
||||
|
||||
// do the reduction on the Thread Buffer
|
||||
threadwise_reduce::Reduce(in_thread_buf, accuValue_buf(I0));
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
|
||||
}
|
||||
|
||||
accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]);
|
||||
|
||||
constexpr auto ReducedDataDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
|
||||
|
||||
if(!float_equal_one{}(alpha))
|
||||
accuValue_buf(I0) *= type_convert<compType>(alpha);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
|
||||
|
||||
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
|
||||
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
|
||||
dstDataType,
|
||||
dst1dDescType,
|
||||
decltype(ReducedDataDesc),
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
true>(
|
||||
dst1dDesc, make_multi_index(thread_global_1d_id));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
|
||||
|
||||
threadwise_dst_load.Run(
|
||||
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
|
||||
|
||||
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
|
||||
}
|
||||
|
||||
auto threadwise_dst_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
|
||||
dstDataType,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
true>(dst1dDesc,
|
||||
make_multi_index(thread_global_1d_id));
|
||||
|
||||
threadwise_dst_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
|
||||
};
|
||||
|
||||
template <>
|
||||
__device__ static void Run<2>(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ p_src_global,
|
||||
dstDataType beta,
|
||||
dstDataType* const __restrict__ p_dst_global,
|
||||
const int* const __restrict__ ws_indices_global,
|
||||
int* const __restrict__ indices_global)
|
||||
{
|
||||
(void)ws_indices_global;
|
||||
|
||||
const auto zeroVal = opReduce::GetReductionZeroVal();
|
||||
|
||||
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
|
||||
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_dst_global, dst1dDesc.GetElementSpaceSize());
|
||||
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
indices_global, dst1dDesc.GetElementSpaceSize());
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredThreadBufferLength, true>
|
||||
in_thread_buf;
|
||||
|
||||
using threadwise_reduce = ThreadReduce<decltype(in_thread_buf), opReduce, nanPropaOpt>;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;
|
||||
|
||||
accuValue_buf(I0) = zeroVal;
|
||||
accuIndex_buf(I0) = 0;
|
||||
|
||||
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
|
||||
const int divider = origReduceLen;
|
||||
|
||||
const preUnaryOpType preUnaryOp(divider);
|
||||
|
||||
using ThreadBufferLengths = Sequence<1, GredThreadBufferLength>;
|
||||
constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<1>{}, Number<GredThreadBufferLength>{}));
|
||||
|
||||
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
|
||||
compType,
|
||||
src2dDescType,
|
||||
decltype(ThreadBufferDesc),
|
||||
ThreadBufferLengths,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false>(
|
||||
src2dDesc, make_multi_index(thread_global_1d_id, 0));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, GredThreadBufferLength);
|
||||
|
||||
index_t indexStart = 0;
|
||||
for(index_t reducedLength = 0; reducedLength < toReduceLength;
|
||||
reducedLength += GredThreadBufferLength)
|
||||
{
|
||||
threadwise_src_load.Run(
|
||||
src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf);
|
||||
|
||||
// unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually
|
||||
// done here
|
||||
threadwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf);
|
||||
|
||||
// do the reduction on the Thread Buffer
|
||||
threadwise_reduce::Reduce2(
|
||||
in_thread_buf, accuValue_buf(I0), accuIndex_buf(I0), indexStart);
|
||||
|
||||
indexStart += GredThreadBufferLength;
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
|
||||
}
|
||||
|
||||
constexpr auto ReducedDataDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
|
||||
|
||||
if(!float_equal_one{}(alpha))
|
||||
accuValue_buf(I0) *= type_convert<compType>(alpha);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
|
||||
|
||||
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
|
||||
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
|
||||
dstDataType,
|
||||
dst1dDescType,
|
||||
decltype(ReducedDataDesc),
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
false>(
|
||||
dst1dDesc, make_multi_index(thread_global_1d_id));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
|
||||
|
||||
threadwise_dst_load.Run(
|
||||
dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
|
||||
|
||||
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
|
||||
}
|
||||
|
||||
auto threadwise_dst_val_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
|
||||
dstDataType,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
false>(dst1dDesc,
|
||||
make_multi_index(thread_global_1d_id));
|
||||
|
||||
auto threadwise_dst_idx_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<int,
|
||||
int,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
false>(dst1dDesc,
|
||||
make_multi_index(thread_global_1d_id));
|
||||
|
||||
threadwise_dst_val_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
|
||||
threadwise_dst_idx_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
|
||||
};
|
||||
|
||||
template <>
|
||||
__device__ static void Run<3>(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ ws_values_global,
|
||||
dstDataType beta,
|
||||
dstDataType* const __restrict__ p_dst_global,
|
||||
const int* const __restrict__ ws_indices_global,
|
||||
int* const __restrict__ indices_global)
|
||||
{
|
||||
(void)origReduceLen;
|
||||
|
||||
const auto zeroVal = opReduce::GetReductionZeroVal();
|
||||
|
||||
const auto src_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
ws_values_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
|
||||
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
ws_indices_global, src2dDesc.GetElementSpaceSize());
|
||||
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_dst_global, dst1dDesc.GetElementSpaceSize());
|
||||
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
indices_global, dst1dDesc.GetElementSpaceSize());
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredThreadBufferLength, true>
|
||||
in_thread_val_buf;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, GredThreadBufferLength, true> in_thread_idx_buf;
|
||||
|
||||
using threadwise_reduce = ThreadReduceWithIndicesInput<decltype(in_thread_val_buf),
|
||||
decltype(in_thread_idx_buf),
|
||||
opReduce,
|
||||
nanPropaOpt>;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;
|
||||
|
||||
accuValue_buf(I0) = zeroVal;
|
||||
accuIndex_buf(I0) = 0;
|
||||
|
||||
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
using ThreadBufferLengths = Sequence<1, GredThreadBufferLength>;
|
||||
constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<1>{}, Number<GredThreadBufferLength>{}));
|
||||
|
||||
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
|
||||
|
||||
auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
|
||||
compType,
|
||||
src2dDescType,
|
||||
decltype(ThreadBufferDesc),
|
||||
ThreadBufferLengths,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false>(
|
||||
src2dDesc, make_multi_index(thread_global_1d_id, 0));
|
||||
|
||||
auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2<int,
|
||||
int,
|
||||
src2dDescType,
|
||||
decltype(ThreadBufferDesc),
|
||||
ThreadBufferLengths,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false>(
|
||||
src2dDesc, make_multi_index(thread_global_1d_id, 0));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, GredThreadBufferLength);
|
||||
|
||||
for(index_t reducedLength = 0; reducedLength < toReduceLength;
|
||||
reducedLength += GredThreadBufferLength)
|
||||
{
|
||||
threadwise_src_val_load.Run(src2dDesc,
|
||||
src_global_val_buf,
|
||||
ThreadBufferDesc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_val_buf);
|
||||
threadwise_src_idx_load.Run(src2dDesc,
|
||||
src_global_idx_buf,
|
||||
ThreadBufferDesc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_idx_buf);
|
||||
|
||||
// do the reduction on the Thread Buffer
|
||||
threadwise_reduce::Reduce(
|
||||
in_thread_val_buf, in_thread_idx_buf, accuValue_buf(I0), accuIndex_buf(I0));
|
||||
|
||||
threadwise_src_val_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
|
||||
threadwise_src_idx_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
|
||||
}
|
||||
|
||||
constexpr auto ReducedDataDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
|
||||
|
||||
if(!float_equal_one{}(alpha))
|
||||
accuValue_buf(I0) *= type_convert<compType>(alpha);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
|
||||
|
||||
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
|
||||
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
|
||||
dstDataType,
|
||||
dst1dDescType,
|
||||
decltype(ReducedDataDesc),
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
false>(
|
||||
dst1dDesc, make_multi_index(thread_global_1d_id));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
|
||||
|
||||
threadwise_dst_load.Run(
|
||||
dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
|
||||
|
||||
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
|
||||
}
|
||||
|
||||
auto threadwise_dst_val_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
|
||||
dstDataType,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
false>(dst1dDesc,
|
||||
make_multi_index(thread_global_1d_id));
|
||||
|
||||
auto threadwise_dst_idx_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<int,
|
||||
int,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
false>(dst1dDesc,
|
||||
make_multi_index(thread_global_1d_id));
|
||||
|
||||
threadwise_dst_val_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
|
||||
threadwise_dst_idx_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck
#endif
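Note: the deleted direct-threadwise kernel above assigns one thread per output value. Each thread repeatedly loads GredThreadBufferLength elements of its row into registers, applies the pre-unary operation, folds them into a single accumulator, applies the post-unary operation, and finally blends the result into the destination as alpha * reduced + beta * prior (the beta term is skipped when beta is exactly zero). Below is a minimal host-side sketch of that per-thread loop; the function name and plain-float types are illustrative only, not the CK device API.

#include <algorithm>
#include <cstddef>
#include <vector>

// Host-side sketch of the per-thread reduction loop: walk the row in
// chunks (the device code uses GredThreadBufferLength elements per step),
// pre-transform each element, fold it into one accumulator, then blend
// with the prior destination value. Names and types are illustrative.
float threadwise_reduce_row(const std::vector<float>& row,
                            std::size_t chunk,
                            float alpha,
                            float beta,
                            float prior_dst)
{
    float acc = 0.0f; // identity value of the reduction (ADD/AVG style)
    for(std::size_t begin = 0; begin < row.size(); begin += chunk)
    {
        const std::size_t end = std::min(begin + chunk, row.size());
        for(std::size_t i = begin; i < end; ++i)
        {
            const float pre = row[i]; // pre-unary op (identity here)
            acc += pre;               // opReduce accumulation
        }
    }
    // a post-unary op (e.g. divide-by-N for AVG) would be applied to acc here
    return alpha * acc + beta * prior_dst;
}

The kernel performs the same steps per output element, but with vectorized loads through ThreadwiseTensorSliceTransfer_v2 and a compile-time chunk size.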
@@ -1,542 +0,0 @@
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_WARPWISE_HPP
|
||||
#define CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_WARPWISE_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_warpwise.hpp"
|
||||
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename srcDataType,
|
||||
typename dstDataType,
|
||||
typename compType,
|
||||
typename src2dDescType,
|
||||
typename dst1dDescType,
|
||||
ReduceTensorOp_t op,
|
||||
NanPropagation_t nanPropaOpt,
|
||||
ReduceTensorIndices_t reduceIndicesOpt,
|
||||
bool isFirstCall,
|
||||
bool isLastCall,
|
||||
index_t GredAccessesPerThreadInWarp>
|
||||
struct GridwiseReduction_xy_to_x_direct_warpwise
|
||||
{
|
||||
using opReduce = typename reduce_binary_operator<compType, op>::opType;
|
||||
using preUnaryOpType =
|
||||
typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::preUnaryOp;
|
||||
using posUnaryOpType =
|
||||
typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::posUnaryOp;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
template <int RunId>
|
||||
__device__ static void Run(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ p_src_global,
|
||||
dstDataType beta,
|
||||
dstDataType* const __restrict__ p_dst_global,
|
||||
const int* const __restrict__ ws_indices_global,
|
||||
int* const __restrict__ indices_global);
|
||||
|
||||
template <>
|
||||
__device__ static void Run<1>(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ p_src_global,
|
||||
dstDataType beta,
|
||||
dstDataType* const __restrict__ p_dst_global,
|
||||
const int* const __restrict__ ws_indices_global,
|
||||
int* const __restrict__ indices_global)
|
||||
{
|
||||
(void)ws_indices_global;
|
||||
(void)indices_global;
|
||||
|
||||
const auto zeroVal = opReduce::GetReductionZeroVal();
|
||||
|
||||
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
|
||||
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_dst_global, dst1dDesc.GetElementSpaceSize());
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredAccessesPerThreadInWarp, true>
|
||||
in_thread_buf;
|
||||
|
||||
using warpwise_reduce =
|
||||
WarpReduce<decltype(in_thread_buf), BlockSize, opReduce, nanPropaOpt>;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
|
||||
|
||||
accuValue_buf(I0) = zeroVal;
|
||||
|
||||
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
|
||||
const int divider = origReduceLen;
|
||||
|
||||
const preUnaryOpType preUnaryOp(divider);
|
||||
const posUnaryOpType posUnaryOp(divider);
|
||||
|
||||
using ThreadBufferLengths = Sequence<1, GredAccessesPerThreadInWarp>;
|
||||
constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<1>{}, Number<GredAccessesPerThreadInWarp>{}));
|
||||
|
||||
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
|
||||
index_t warp_global_1d_id = thread_global_1d_id / warpSize;
|
||||
index_t thread_inwarp_id = thread_global_1d_id % warpSize;
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
|
||||
compType,
|
||||
src2dDescType,
|
||||
decltype(ThreadBufferDesc),
|
||||
ThreadBufferLengths,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false>(
|
||||
src2dDesc,
|
||||
make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp));
|
||||
|
||||
constexpr auto in_thread_copy_step =
|
||||
make_multi_index(0, warpSize * GredAccessesPerThreadInWarp);
|
||||
|
||||
for(index_t reducedLength = 0; reducedLength < toReduceLength;
|
||||
reducedLength += warpSize * GredAccessesPerThreadInWarp)
|
||||
{
|
||||
threadwise_src_load.Run(
|
||||
src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf);
|
||||
|
||||
// do element-wise pre-reduction operation
|
||||
warpwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf);
|
||||
|
||||
// do the warp-wise reduction on data of all thread buffers
|
||||
warpwise_reduce::Reduce(in_thread_buf, accuValue_buf(I0));
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
|
||||
}
|
||||
|
||||
accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]);
|
||||
|
||||
constexpr auto ReducedDataDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
|
||||
|
||||
// The first thread in the warp stores the reduced result to the global location
|
||||
// representing the Warp
|
||||
if(thread_inwarp_id == 0)
|
||||
{
|
||||
if(!float_equal_one{}(alpha))
|
||||
accuValue_buf(I0) *= type_convert<compType>(alpha);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
|
||||
|
||||
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
|
||||
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
auto threadwise_dst_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<dstDataType,
|
||||
dstDataType,
|
||||
dst1dDescType,
|
||||
decltype(ReducedDataDesc),
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
true>(dst1dDesc,
|
||||
make_multi_index(warp_global_1d_id));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
|
||||
|
||||
threadwise_dst_load.Run(
|
||||
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
|
||||
|
||||
dstValue_buf(I0) += priorDstValue_buf(I0) * beta;
|
||||
}
|
||||
|
||||
auto threadwise_dst_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
|
||||
dstDataType,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
true>(dst1dDesc,
|
||||
make_multi_index(warp_global_1d_id));
|
||||
|
||||
threadwise_dst_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
__device__ static void Run<2>(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ p_src_global,
|
||||
dstDataType beta,
|
||||
dstDataType* const __restrict__ p_dst_global,
|
||||
const int* const __restrict__ ws_indices_global,
|
||||
int* const __restrict__ indices_global)
|
||||
{
|
||||
(void)ws_indices_global;
|
||||
|
||||
const auto zeroVal = opReduce::GetReductionZeroVal();
|
||||
|
||||
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
|
||||
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_dst_global, dst1dDesc.GetElementSpaceSize());
|
||||
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
indices_global, dst1dDesc.GetElementSpaceSize());
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredAccessesPerThreadInWarp, true>
|
||||
in_thread_buf;
|
||||
|
||||
using warpwise_reduce =
|
||||
WarpReduce<decltype(in_thread_buf), BlockSize, opReduce, nanPropaOpt>;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;
|
||||
|
||||
accuValue_buf(I0) = zeroVal;
|
||||
accuIndex_buf(I0) = 0;
|
||||
|
||||
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
|
||||
const int divider = origReduceLen;
|
||||
|
||||
const preUnaryOpType preUnaryOp(divider);
|
||||
|
||||
using ThreadBufferLengths = Sequence<1, GredAccessesPerThreadInWarp>;
|
||||
constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<1>{}, Number<GredAccessesPerThreadInWarp>{}));
|
||||
|
||||
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
|
||||
index_t warp_global_1d_id = thread_global_1d_id / warpSize;
|
||||
index_t thread_inwarp_id = thread_global_1d_id % warpSize;
|
||||
|
||||
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
|
||||
compType,
|
||||
src2dDescType,
|
||||
decltype(ThreadBufferDesc),
|
||||
ThreadBufferLengths,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false>(
|
||||
src2dDesc,
|
||||
make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp));
|
||||
|
||||
constexpr auto in_thread_copy_step =
|
||||
make_multi_index(0, warpSize * GredAccessesPerThreadInWarp);
|
||||
|
||||
index_t indexOffset = 0;
|
||||
for(index_t reducedLength = 0; reducedLength < toReduceLength;
|
||||
reducedLength += warpSize * GredAccessesPerThreadInWarp)
|
||||
{
|
||||
threadwise_src_load.Run(
|
||||
src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf);
|
||||
|
||||
// unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually
|
||||
// done here
|
||||
warpwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf);
|
||||
|
||||
// do the warp-wise reduction on data of all thread buffers
|
||||
warpwise_reduce::Reduce2(
|
||||
in_thread_buf, accuValue_buf(I0), accuIndex_buf(I0), indexOffset);
|
||||
|
||||
indexOffset += warpSize * GredAccessesPerThreadInWarp;
|
||||
|
||||
threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
|
||||
}
|
||||
|
||||
constexpr auto ReducedDataDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
|
||||
|
||||
// The first thread in the warp stores the reduced result to the global location
|
||||
// representing the Warp
|
||||
if(thread_inwarp_id == 0)
|
||||
{
|
||||
if(!float_equal_one{}(alpha))
|
||||
accuValue_buf(I0) *= type_convert<compType>(alpha);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
|
||||
|
||||
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
|
||||
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
auto threadwise_dst_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<dstDataType,
|
||||
dstDataType,
|
||||
dst1dDescType,
|
||||
decltype(ReducedDataDesc),
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
true>(dst1dDesc,
|
||||
make_multi_index(warp_global_1d_id));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
|
||||
|
||||
threadwise_dst_load.Run(dst1dDesc,
|
||||
dst_global_val_buf,
|
||||
ReducedDataDesc,
|
||||
make_tuple(I0),
|
||||
priorDstValue_buf);
|
||||
|
||||
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
|
||||
}
|
||||
|
||||
auto threadwise_dst_val_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
|
||||
dstDataType,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
true>(dst1dDesc,
|
||||
make_multi_index(warp_global_1d_id));
|
||||
|
||||
auto threadwise_dst_idx_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<int,
|
||||
int,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
true>(dst1dDesc,
|
||||
make_multi_index(warp_global_1d_id));
|
||||
|
||||
threadwise_dst_val_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
|
||||
threadwise_dst_idx_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
__device__ static void Run<3>(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ ws_values_global,
|
||||
dstDataType beta,
|
||||
dstDataType* const __restrict__ p_dst_global,
|
||||
const int* const __restrict__ ws_indices_global,
|
||||
int* const __restrict__ indices_global)
|
||||
{
|
||||
(void)origReduceLen;
|
||||
|
||||
const auto zeroVal = opReduce::GetReductionZeroVal();
|
||||
|
||||
const auto src_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
ws_values_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
|
||||
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
ws_indices_global, src2dDesc.GetElementSpaceSize());
|
||||
auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_dst_global, dst1dDesc.GetElementSpaceSize());
|
||||
auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
indices_global, dst1dDesc.GetElementSpaceSize());
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredAccessesPerThreadInWarp, true>
|
||||
in_thread_val_buf;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, GredAccessesPerThreadInWarp, true>
|
||||
in_thread_idx_buf;
|
||||
|
||||
using warpwise_reduce = WarpReduceWithIndicesInput<decltype(in_thread_val_buf),
|
||||
decltype(in_thread_idx_buf),
|
||||
BlockSize,
|
||||
opReduce,
|
||||
nanPropaOpt>;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;
|
||||
|
||||
accuValue_buf(I0) = zeroVal;
|
||||
accuIndex_buf(I0) = 0;
|
||||
|
||||
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
using ThreadBufferLengths = Sequence<1, GredAccessesPerThreadInWarp>;
|
||||
constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<1>{}, Number<GredAccessesPerThreadInWarp>{}));
|
||||
|
||||
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
|
||||
index_t warp_global_1d_id = thread_global_1d_id / warpSize;
|
||||
index_t thread_inwarp_id = thread_global_1d_id % warpSize;
|
||||
|
||||
auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
|
||||
compType,
|
||||
src2dDescType,
|
||||
decltype(ThreadBufferDesc),
|
||||
ThreadBufferLengths,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false>(
|
||||
src2dDesc,
|
||||
make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp));
|
||||
|
||||
auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2<int,
|
||||
int,
|
||||
src2dDescType,
|
||||
decltype(ThreadBufferDesc),
|
||||
ThreadBufferLengths,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false>(
|
||||
src2dDesc,
|
||||
make_multi_index(warp_global_1d_id, thread_inwarp_id * GredAccessesPerThreadInWarp));
|
||||
|
||||
constexpr auto in_thread_copy_step =
|
||||
make_multi_index(0, warpSize * GredAccessesPerThreadInWarp);
|
||||
|
||||
for(index_t reducedLength = 0; reducedLength < toReduceLength;
|
||||
reducedLength += warpSize * GredAccessesPerThreadInWarp)
|
||||
{
|
||||
threadwise_src_val_load.Run(src2dDesc,
|
||||
src_global_val_buf,
|
||||
ThreadBufferDesc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_val_buf);
|
||||
threadwise_src_idx_load.Run(src2dDesc,
|
||||
src_global_idx_buf,
|
||||
ThreadBufferDesc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_idx_buf);
|
||||
|
||||
// do the warp-wise reduction on data of all thread buffers
|
||||
warpwise_reduce::Reduce(
|
||||
in_thread_val_buf, in_thread_idx_buf, accuValue_buf(I0), accuIndex_buf(I0));
|
||||
|
||||
threadwise_src_val_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
|
||||
threadwise_src_idx_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
|
||||
}
|
||||
|
||||
constexpr auto ReducedDataDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
|
||||
|
||||
// The first thread in the warp stores the reduced result to the global location
|
||||
// representing the Warp
|
||||
if(thread_inwarp_id == 0)
|
||||
{
|
||||
if(!float_equal_one{}(alpha))
|
||||
accuValue_buf(I0) *= type_convert<compType>(alpha);
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
|
||||
|
||||
dstValue_buf(I0) = type_convert<dstDataType>(accuValue_buf[I0]);
|
||||
|
||||
if(!float_equal_zero{}(beta))
|
||||
{
|
||||
auto threadwise_dst_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<dstDataType,
|
||||
dstDataType,
|
||||
dst1dDescType,
|
||||
decltype(ReducedDataDesc),
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
true>(dst1dDesc,
|
||||
make_multi_index(warp_global_1d_id));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;
|
||||
|
||||
threadwise_dst_load.Run(dst1dDesc,
|
||||
dst_global_val_buf,
|
||||
ReducedDataDesc,
|
||||
make_tuple(I0),
|
||||
priorDstValue_buf);
|
||||
|
||||
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
|
||||
}
|
||||
|
||||
auto threadwise_dst_val_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
|
||||
dstDataType,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
true>(dst1dDesc,
|
||||
make_multi_index(warp_global_1d_id));
|
||||
|
||||
auto threadwise_dst_idx_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<int,
|
||||
int,
|
||||
decltype(ReducedDataDesc),
|
||||
dst1dDescType,
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
true>(dst1dDesc,
|
||||
make_multi_index(warp_global_1d_id));
|
||||
|
||||
threadwise_dst_val_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
|
||||
threadwise_dst_idx_store.Run(
|
||||
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck
#endif
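Note: the deleted direct-warpwise kernel above gives one warp per output value. Every lane first accumulates its own GredAccessesPerThreadInWarp elements, the lanes' partial results are then combined inside the warp, and only lane 0 (thread_inwarp_id == 0) applies alpha/beta and stores. The sketch below shows the cross-lane halving idea on the host; whether the removed WarpReduce helper used shuffles or LDS is not visible in this section, so treat it purely as an illustration.

#include <array>
#include <cstddef>

// Host-side sketch of a warp-style tree reduction over WarpSize partial
// values: each step halves the number of active lanes, with every active
// lane combining its value with the one 'stride' positions above it.
// WarpSize = 64 matches an AMD wavefront; the value is illustrative.
template <std::size_t WarpSize = 64>
float warp_tree_reduce(std::array<float, WarpSize> lane_vals)
{
    for(std::size_t stride = WarpSize / 2; stride > 0; stride /= 2)
    {
        for(std::size_t lane = 0; lane < stride; ++lane)
        {
            lane_vals[lane] += lane_vals[lane + stride]; // opReduce step
        }
    }
    return lane_vals[0]; // lane 0 ends up holding the warp's reduced value
}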
@@ -1,376 +0,0 @@
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_MULTIBLOCK_HPP
|
||||
#define CK_GRIDWISE_GENERIC_2D_REDUCTION_MULTIBLOCK_HPP
|
||||
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_blockwise.hpp"
|
||||
|
||||
#include "blockwise_tensor_slice_transfer.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename srcDataType,
|
||||
typename dstDataType, // not used together with the beta input
|
||||
typename compType,
|
||||
typename src2dDescType,
|
||||
typename dst1dDescType,
|
||||
ReduceTensorOp_t op,
|
||||
NanPropagation_t nanPropaOpt,
|
||||
ReduceTensorIndices_t reduceIndicesOpt,
|
||||
index_t GredAccessesPerThreadInBlock>
|
||||
struct GridwiseReduction_xy_to_x_multiblock
|
||||
{
|
||||
using opReduce = typename reduce_binary_operator<compType, op>::opType;
|
||||
using preUnaryOpType = typename reduce_unary_operator<compType, op, true, false>::preUnaryOp;
|
||||
using posUnaryOpType = typename reduce_unary_operator<compType, op, true, false>::posUnaryOp;
|
||||
|
||||
static constexpr auto buffer2dDesc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<GredAccessesPerThreadInBlock>{}, Number<BlockSize>{}));
|
||||
using blockwise_reduce =
|
||||
BlockwiseReduction_2d_block_buffer<decltype(buffer2dDesc), true, opReduce, nanPropaOpt>;
|
||||
|
||||
static constexpr index_t BlockBufferSize = buffer2dDesc.GetElementSize();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
template <int RunId>
|
||||
__device__ static void Run(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
int BlkGroupSize,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ p_src_global,
|
||||
dstDataType beta,
|
||||
srcDataType* const __restrict__ ws_values_global,
|
||||
int* const __restrict__ ws_indices_global);
|
||||
|
||||
template <>
|
||||
__device__ static void Run<1>(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
int BlkGroupSize,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ p_src_global,
|
||||
dstDataType beta,
|
||||
srcDataType* const __restrict__ ws_values_global,
|
||||
int* const __restrict__ ws_indices_global)
|
||||
{
|
||||
(void)ws_indices_global;
|
||||
|
||||
(void)alpha; // unused
|
||||
(void)beta; // unused
|
||||
|
||||
const auto zeroVal = opReduce::GetReductionZeroVal();
|
||||
|
||||
// LDS
|
||||
__shared__ compType p_in_block_buffer[BlockBufferSize];
|
||||
|
||||
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
|
||||
auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize);
|
||||
|
||||
auto in_block_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_buffer, BlockBufferSize);
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
|
||||
|
||||
accuValue_buf(I0) = zeroVal;
|
||||
|
||||
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
|
||||
const int divider = origReduceLen;
|
||||
|
||||
const preUnaryOpType preUnaryOp(divider);
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
const index_t blkgroup_id = block_global_id / BlkGroupSize;
|
||||
const index_t block_local_id = block_global_id % BlkGroupSize;
|
||||
|
||||
const index_t reduceSizePerBlock =
|
||||
(((toReduceLength + BlkGroupSize - 1) / BlkGroupSize + BlockBufferSize - 1) /
|
||||
BlockBufferSize) *
|
||||
BlockBufferSize;
|
||||
|
||||
constexpr auto in_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<1>{}, Number<BlockSize * GredAccessesPerThreadInBlock>{}));
|
||||
|
||||
using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>;
|
||||
using ThreadClusterLengths = Sequence<1, BlockSize>;
|
||||
|
||||
auto blockwise_src_load = BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<1, BlockBufferSize>,
|
||||
ThreadSliceLengths,
|
||||
ThreadClusterLengths,
|
||||
Sequence<0, 1>,
|
||||
srcDataType,
|
||||
compType,
|
||||
src2dDescType,
|
||||
decltype(in_block_desc),
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
true>(
|
||||
src2dDesc,
|
||||
make_multi_index(blkgroup_id, block_local_id * reduceSizePerBlock),
|
||||
in_block_desc,
|
||||
make_multi_index(0, 0));
|
||||
|
||||
constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);
|
||||
|
||||
const index_t toReduceBlocks = (reduceSizePerBlock + BlockSize - 1) / BlockSize;
|
||||
|
||||
for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
|
||||
reducedBlocks += GredAccessesPerThreadInBlock)
|
||||
{
|
||||
blockwise_src_load.RunRead(src2dDesc, src_global_buf);
|
||||
blockwise_src_load.RunWrite(in_block_desc, in_block_buf);
|
||||
__syncthreads();
|
||||
|
||||
// do element-wise pre-reduction operation
|
||||
blockwise_reduce::operate_on_elements(preUnaryOp, in_block_buf);
|
||||
|
||||
index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
|
||||
? GredAccessesPerThreadInBlock
|
||||
: toReduceBlocks - reducedBlocks;
|
||||
blockwise_reduce::Reduce(in_block_buf, BlocksInOneOp, accuValue_buf(I0));
|
||||
|
||||
blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
|
||||
}
|
||||
|
||||
constexpr auto ReducedDataDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
|
||||
|
||||
const auto workspace_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(dst1dDesc.GetLength(I0) * BlkGroupSize));
|
||||
|
||||
// The first thread in the block stores the reduced result to the global location
|
||||
// representing the block
|
||||
if(thread_local_id == 0)
|
||||
{
|
||||
auto threadwise_workspace_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<compType,
|
||||
srcDataType,
|
||||
decltype(ReducedDataDesc),
|
||||
decltype(workspace_desc),
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
true>(workspace_desc,
|
||||
make_multi_index(block_global_id));
|
||||
|
||||
threadwise_workspace_store.Run(ReducedDataDesc,
|
||||
make_tuple(I0),
|
||||
accuValue_buf,
|
||||
workspace_desc,
|
||||
workspace_global_buf);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
__device__ static void Run<2>(const src2dDescType& src2dDesc,
|
||||
const dst1dDescType& dst1dDesc,
|
||||
int origReduceLen,
|
||||
int BlkGroupSize,
|
||||
srcDataType alpha,
|
||||
const srcDataType* const __restrict__ p_src_global,
|
||||
dstDataType beta,
|
||||
srcDataType* const __restrict__ ws_values_global,
|
||||
int* const __restrict__ ws_indices_global)
|
||||
{
|
||||
(void)alpha; // unused
|
||||
(void)beta; // unused
|
||||
|
||||
const auto zeroVal = opReduce::GetReductionZeroVal();
|
||||
|
||||
// LDS
|
||||
__shared__ compType p_in_block_values_buffer[BlockBufferSize];
|
||||
__shared__ int p_in_block_indices_buffer[BlockBufferSize];
|
||||
|
||||
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>(zeroVal));
|
||||
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize);
|
||||
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
ws_indices_global, dst1dDesc.GetLength(I0) * BlkGroupSize);
|
||||
|
||||
auto in_block_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_values_buffer, BlockBufferSize);
|
||||
auto in_block_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_in_block_indices_buffer, BlockBufferSize);
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;
|
||||
|
||||
accuValue_buf(I0) = zeroVal;
|
||||
accuIndex_buf(I0) = 0;
|
||||
|
||||
const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
|
||||
const int divider = origReduceLen;
|
||||
|
||||
const preUnaryOpType preUnaryOp(divider);
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
const index_t blkgroup_id = block_global_id / BlkGroupSize;
|
||||
const index_t block_local_id = block_global_id % BlkGroupSize;
|
||||
|
||||
const index_t reduceSizePerBlock =
|
||||
(((toReduceLength + BlkGroupSize - 1) / BlkGroupSize + BlockBufferSize - 1) /
|
||||
BlockBufferSize) *
|
||||
BlockBufferSize;
|
||||
|
||||
constexpr auto in_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<1>{}, Number<BlockSize * GredAccessesPerThreadInBlock>{}));
|
||||
|
||||
using ThreadSliceLengths = Sequence<1, GredAccessesPerThreadInBlock>;
|
||||
using ThreadClusterLengths = Sequence<1, BlockSize>;
|
||||
|
||||
auto blockwise_src_load = BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<1, BlockBufferSize>,
|
||||
ThreadSliceLengths,
|
||||
ThreadClusterLengths,
|
||||
Sequence<0, 1>,
|
||||
srcDataType,
|
||||
compType,
|
||||
src2dDescType,
|
||||
decltype(in_block_desc),
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
true>(
|
||||
src2dDesc,
|
||||
make_multi_index(blkgroup_id, block_local_id * reduceSizePerBlock),
|
||||
in_block_desc,
|
||||
make_multi_index(0, 0));
|
||||
|
||||
constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);
|
||||
|
||||
const index_t toReduceBlocks = (reduceSizePerBlock + BlockSize - 1) / BlockSize;
|
||||
|
||||
int indexOffset = block_local_id * reduceSizePerBlock;
|
||||
|
||||
for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
|
||||
reducedBlocks += GredAccessesPerThreadInBlock)
|
||||
{
|
||||
blockwise_reduce::init_buffer_indices(in_block_idx_buf, indexOffset);
|
||||
|
||||
blockwise_src_load.RunRead(src2dDesc, src_global_buf);
|
||||
blockwise_src_load.RunWrite(in_block_desc, in_block_val_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually
|
||||
// done here
|
||||
blockwise_reduce::operate_on_elements(preUnaryOp, in_block_val_buf);
|
||||
|
||||
index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
|
||||
? GredAccessesPerThreadInBlock
|
||||
: toReduceBlocks - reducedBlocks;
|
||||
|
||||
blockwise_reduce::Reduce2(in_block_val_buf,
|
||||
in_block_idx_buf,
|
||||
BlocksInOneOp,
|
||||
accuValue_buf(I0),
|
||||
accuIndex_buf(I0));
|
||||
|
||||
indexOffset += BlockBufferSize;
|
||||
|
||||
blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
|
||||
}
|
||||
|
||||
constexpr auto ReducedDataDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));
|
||||
|
||||
const auto workspace_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(dst1dDesc.GetLength(I0) * BlkGroupSize));
|
||||
|
||||
// The first thread in the block stores the reduced result to the global location
|
||||
// representing the block
|
||||
if(thread_local_id == 0)
|
||||
{
|
||||
auto threadwise_workspace_val_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<compType,
|
||||
srcDataType,
|
||||
decltype(ReducedDataDesc),
|
||||
decltype(workspace_desc),
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
true>(workspace_desc,
|
||||
make_multi_index(block_global_id));
|
||||
|
||||
auto threadwise_workspace_idx_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<int,
|
||||
int,
|
||||
decltype(ReducedDataDesc),
|
||||
decltype(workspace_desc),
|
||||
Sequence<1>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
1,
|
||||
true>(workspace_desc,
|
||||
make_multi_index(block_global_id));
|
||||
|
||||
threadwise_workspace_val_store.Run(ReducedDataDesc,
|
||||
make_tuple(I0),
|
||||
accuValue_buf,
|
||||
workspace_desc,
|
||||
workspace_global_val_buf);
|
||||
threadwise_workspace_idx_store.Run(ReducedDataDesc,
|
||||
make_tuple(I0),
|
||||
accuIndex_buf,
|
||||
workspace_desc,
|
||||
workspace_global_idx_buf);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck
#endif
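Note: the deleted multiblock kernel above splits one reduced row across BlkGroupSize thread blocks. Each block reduces its reduceSizePerBlock slice through LDS and writes one partial value (plus one partial index for indexed ops) into a workspace of dst1dDesc.GetLength(I0) * BlkGroupSize entries, which a follow-up kernel then reduces. reduceSizePerBlock is the per-group share of the row rounded up to whole block buffers; the small program below just reproduces that arithmetic with made-up sizes.

#include <cstdio>

// Sketch of how the multiblock kernel sizes each block's slice of the
// reduced dimension: divide the row across BlkGroupSize blocks, then
// round up to a whole number of block buffers. The values are hypothetical.
int main()
{
    const int toReduceLength  = 10000; // length of the reduced dimension
    const int BlkGroupSize    = 4;     // blocks cooperating on one row
    const int BlockBufferSize = 1024;  // BlockSize * GredAccessesPerThreadInBlock

    const int perGroup = (toReduceLength + BlkGroupSize - 1) / BlkGroupSize; // 2500
    const int reduceSizePerBlock =
        (perGroup + BlockBufferSize - 1) / BlockBufferSize * BlockBufferSize; // 3072

    std::printf("each block reduces %d elements; workspace holds %d partials per row\n",
                reduceSizePerBlock, BlkGroupSize);
    return 0;
}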
@@ -0,0 +1,79 @@
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
#ifndef CK_GRIDWISE_SET_BUFFER_VALUE_HPP
#define CK_GRIDWISE_SET_BUFFER_VALUE_HPP

#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <index_t BlockSize, typename DataType, typename Grid1dBufferDescType>
__global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffer_desc,
                                        DataType* const __restrict__ p_global,
                                        DataType value)

{
    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<DataType, DataType>;

    constexpr auto I0 = Number<0>{};

    const index_t thread_local_id = get_thread_local_1d_id();
    const index_t block_global_id = get_block_1d_id();

    const index_t thread_global_id = block_global_id * BlockSize + thread_local_id;

    StaticBuffer<AddressSpaceEnum_t::Vgpr, DataType, 1, true> value_buf;

    value_buf(I0) = value;

    constexpr auto val_buff_desc = make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

    auto global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
        p_global, grid_1d_buffer_desc.GetElementSpaceSize());

    if(thread_global_id < grid_1d_buffer_desc.GetElementSize())
    {
        auto threadwise_store = ThreadwiseTensorSliceTransfer_v1r3<DataType,
                                                                   DataType,
                                                                   decltype(val_buff_desc),
                                                                   Grid1dBufferDescType,
                                                                   PassThroughOp,
                                                                   Sequence<1>,
                                                                   Sequence<0>,
                                                                   0,
                                                                   1,
                                                                   InMemoryDataOperationEnum_t::Set,
                                                                   1,
                                                                   true>(
            grid_1d_buffer_desc, make_multi_index(thread_global_id), PassThroughOp{});

        threadwise_store.Run(
            val_buff_desc, make_tuple(I0), value_buf, grid_1d_buffer_desc, global_buf);
    }
};

} // namespace ck
#endif
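Note: the new kernel_buffer_set_value kernel above initializes a 1-D global buffer to a constant, e.g. to preset a workspace or output buffer before accumulation. Each thread writes at most one element, so the launch needs one thread per element, rounded up to whole blocks. Below is a host-side sketch of that grid sizing; the launch call itself is omitted and the numbers are hypothetical.

#include <cstdio>

// Sketch of sizing the launch grid for the buffer-set-value kernel:
// one thread per element, rounded up to whole blocks. The numbers are made up.
int main()
{
    const int bufferLength = 100000; // elements to initialize
    const int BlockSize    = 256;    // threads per block

    const int gridSize = (bufferLength + BlockSize - 1) / BlockSize; // 391 blocks

    std::printf("launch %d blocks of %d threads to cover %d elements\n",
                gridSize, BlockSize, bufferLength);
    return 0;
}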
@@ -30,240 +30,154 @@
|
||||
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_binop.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename buffer2dDescType,
|
||||
bool blockIsOneRow,
|
||||
typename opReduce,
|
||||
NanPropagation_t nanPropaOpt>
|
||||
struct BlockwiseReduction_2d_block_buffer
|
||||
template <typename Buffer1dDescType,
|
||||
typename AccDataType,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
bool ReorderThreadClusters,
|
||||
typename OpReduce,
|
||||
bool PropagateNan>
|
||||
struct PartitionedBlockwiseReductionOn1dBuffer
|
||||
{
|
||||
using compType = typename opReduce::dataType;
|
||||
static constexpr auto buffer_1d_desc = Buffer1dDescType{};
|
||||
|
||||
static constexpr auto buffer2dDesc = buffer2dDescType{};
|
||||
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                  "The product of cluster lengths should be the same as BlockSize!");
    static_assert(KThreadClusterSize > 1,
                  "Parallel reduction needs to work on at least two elements");
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
blockIsOneRow ? buffer2dDesc.GetLength(Number<1>{}) : buffer2dDesc.GetLength(Number<0>{});
|
||||
static constexpr index_t NumBlocks =
|
||||
blockIsOneRow ? buffer2dDesc.GetLength(Number<0>{}) : buffer2dDesc.GetLength(Number<1>{});
|
||||
using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;
|
||||
static_assert(buffer_1d_desc.GetElementSize() == BlockSize,
|
||||
"The buffer size should be the same as BlockSize!");
|
||||
|
||||
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
|
||||
|
||||
// This interface does not accumulate on indices
|
||||
template <typename BufferType>
|
||||
__device__ static void
|
||||
Reduce(BufferType& block_buffer, index_t toReduceBlocks, compType& accuData)
|
||||
__device__ static void Reduce(BufferType& block_buffer,
|
||||
AccDataType& accuData,
|
||||
index_t thread_m_cluster_id,
|
||||
index_t thread_k_cluster_id)
|
||||
{
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
compType lAccuData = opReduce::GetReductionZeroVal();
|
||||
constexpr auto cluster_len_shift = get_shift<KThreadClusterSize>();
|
||||
|
||||
index_t offset;
|
||||
for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
|
||||
{
|
||||
offset = blockIsOneRow
|
||||
? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id))
|
||||
: buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd));
|
||||
compType opData = type_convert<compType>(block_buffer[offset]);
|
||||
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
|
||||
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
|
||||
|
||||
binop::calculate(lAccuData, opData);
|
||||
}
|
||||
|
||||
offset = blockIsOneRow ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id))
|
||||
: buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));
|
||||
|
||||
block_buffer(offset) = lAccuData;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(index_t indOffset = BlockSize / 2; indOffset > 0; indOffset /= 2)
|
||||
{
|
||||
if(thread_local_id < indOffset)
|
||||
if(thread_k_cluster_id < indOffset)
|
||||
{
|
||||
// consider the thread clusters order, ensure the contiguous locations are accessed
|
||||
// by contiguous Thread-ID
|
||||
index_t offset1 =
|
||||
blockIsOneRow ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id))
|
||||
: buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));
|
||||
ReorderThreadClusters
|
||||
? buffer_1d_desc.CalculateOffset(make_tuple(
|
||||
thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id))
|
||||
: buffer_1d_desc.CalculateOffset(make_tuple(
|
||||
thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id));
|
||||
index_t offset2 = ReorderThreadClusters
|
||||
? buffer_1d_desc.CalculateOffset(make_tuple(
|
||||
(thread_k_cluster_id + indOffset) * MThreadClusterSize +
|
||||
thread_m_cluster_id))
|
||||
: buffer_1d_desc.CalculateOffset(
|
||||
make_tuple(thread_m_cluster_id * KThreadClusterSize +
|
||||
(thread_k_cluster_id + indOffset)));
|
||||
|
||||
index_t offset2 =
|
||||
blockIsOneRow
|
||||
? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id + indOffset))
|
||||
: buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0));
|
||||
|
||||
compType opData1 = type_convert<compType>(block_buffer[offset1]);
|
||||
compType opData2 = type_convert<compType>(block_buffer[offset2]);
|
||||
binop::calculate(opData1, opData2);
|
||||
block_buffer(offset1) = type_convert<compType>(opData1);
|
||||
AccDataType opData1 = type_convert<AccDataType>(block_buffer[offset1]);
|
||||
AccDataType opData2 = type_convert<AccDataType>(block_buffer[offset2]);
|
||||
Accumulation::Calculate(opData1, opData2);
|
||||
block_buffer(offset1) = type_convert<AccDataType>(opData1);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
});
|
||||
|
||||
if(thread_local_id == 0)
|
||||
{
|
||||
compType tmpVal = type_convert<compType>(block_buffer[0]);
|
||||
index_t offset = ReorderThreadClusters
|
||||
? buffer_1d_desc.CalculateOffset(make_tuple(thread_m_cluster_id))
|
||||
: buffer_1d_desc.CalculateOffset(
|
||||
make_tuple(thread_m_cluster_id * KThreadClusterSize));
|
||||
|
||||
binop::calculate(accuData, tmpVal);
|
||||
}
|
||||
accuData = type_convert<AccDataType>(block_buffer[offset]);
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Buffer1dDescType,
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
bool ReorderThreadClusters,
|
||||
typename OpReduce,
|
||||
bool PropagateNan>
|
||||
struct PartitionedBlockwiseReductionWithIndexOn1dBuffer
|
||||
{
|
||||
static constexpr auto buffer_1d_desc = Buffer1dDescType{};
|
||||
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
"The product of cluster lengths should be same as BlockSize!");
|
||||
static_assert(KThreadClusterSize > 1, "Parallel reduction needs to work on at least two elements");
|
||||
|
||||
static_assert(buffer_1d_desc.GetElementSize() == BlockSize,
|
||||
"The buffer size should be the same as BlockSize!");
|
||||
|
||||
using Accumulation =
|
||||
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
|
||||
|
||||
// This interface accumulates on both data values and indices
|
||||
template <typename BufferType, typename IdxBufferType>
|
||||
__device__ static void Reduce2(BufferType& block_buffer,
|
||||
IdxBufferType& block_indices_buffer,
|
||||
index_t toReduceBlocks,
|
||||
compType& accuData,
|
||||
int& accuIndex)
|
||||
__device__ static void Reduce(BufferType& block_val_buffer,
|
||||
IdxBufferType& block_idx_buffer,
|
||||
AccDataType& accuData,
|
||||
IndexDataType& accuIndex,
|
||||
index_t thread_m_cluster_id,
|
||||
index_t thread_k_cluster_id)
|
||||
{
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
compType lAccuData = opReduce::GetReductionZeroVal();
|
||||
int lAccuIndex = 0;
|
||||
constexpr auto cluster_len_shift = get_shift<KThreadClusterSize>();
|
||||
|
||||
if constexpr(blockIsOneRow)
|
||||
{
|
||||
for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
|
||||
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
|
||||
constexpr index_t indOffset = 1 << I();
|
||||
|
||||
if(thread_k_cluster_id % (indOffset * 2) == 0)
|
||||
{
|
||||
for(index_t indOffset = 1; indOffset < BlockSize; indOffset *= 2)
|
||||
{
|
||||
if(thread_local_id % (indOffset * 2) == 0)
|
||||
{
|
||||
index_t offset1 =
|
||||
buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id));
|
||||
index_t offset2 = buffer2dDesc.CalculateOffset(
|
||||
make_tuple(otherDimInd, thread_local_id + indOffset));
|
||||
// consider the thread cluster order to ensure that contiguous locations are accessed
// by contiguous thread IDs
|
||||
index_t offset1 =
|
||||
ReorderThreadClusters
|
||||
? buffer_1d_desc.CalculateOffset(make_tuple(
|
||||
thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id))
|
||||
: buffer_1d_desc.CalculateOffset(make_tuple(
|
||||
thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id));
|
||||
index_t offset2 = ReorderThreadClusters
|
||||
? buffer_1d_desc.CalculateOffset(make_tuple(
|
||||
(thread_k_cluster_id + indOffset) * MThreadClusterSize +
|
||||
thread_m_cluster_id))
|
||||
: buffer_1d_desc.CalculateOffset(
|
||||
make_tuple(thread_m_cluster_id * KThreadClusterSize +
|
||||
(thread_k_cluster_id + indOffset)));
|
||||
|
||||
compType currVal1 = type_convert<compType>(block_buffer[offset1]);
|
||||
compType currVal2 = type_convert<compType>(block_buffer[offset2]);
|
||||
int currIndex1 = block_indices_buffer[offset1];
|
||||
int currIndex2 = block_indices_buffer[offset2];
|
||||
AccDataType opData1 = type_convert<AccDataType>(block_val_buffer[offset1]);
|
||||
AccDataType opData2 = type_convert<AccDataType>(block_val_buffer[offset2]);
|
||||
IndexDataType currIndex1 = block_idx_buffer[offset1];
|
||||
IndexDataType currIndex2 = block_idx_buffer[offset2];
|
||||
|
||||
binop::calculate(currVal1, currVal2, currIndex1, currIndex2);
|
||||
block_buffer(offset1) = type_convert<compType>(currVal1);
|
||||
block_indices_buffer(offset1) = currIndex1;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2);
|
||||
block_val_buffer(offset1) = type_convert<AccDataType>(opData1);
|
||||
block_idx_buffer(offset1) = currIndex1;
|
||||
}
|
||||
|
||||
if(thread_local_id == 0)
|
||||
{
|
||||
for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
|
||||
{
|
||||
index_t offset = buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, 0));
|
||||
|
||||
compType tmpVal = type_convert<compType>(block_buffer[offset]);
|
||||
int tmpIndex = block_indices_buffer[offset];
|
||||
|
||||
binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex);
|
||||
}
|
||||
|
||||
binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t offset;
|
||||
|
||||
for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
|
||||
{
|
||||
offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd));
|
||||
compType currVal = type_convert<compType>(block_buffer[offset]);
|
||||
int currIndex = block_indices_buffer[offset];
|
||||
|
||||
binop::calculate(lAccuData, currVal, lAccuIndex, currIndex);
|
||||
}
|
||||
|
||||
offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));
|
||||
|
||||
block_buffer(offset) = lAccuData;
|
||||
block_indices_buffer(offset) = lAccuIndex;
|
||||
|
||||
__syncthreads();
|
||||
});
|
||||
|
||||
for(index_t indOffset = 1; indOffset < BlockSize; indOffset *= 2)
|
||||
{
|
||||
if(thread_local_id % (indOffset * 2) == 0)
|
||||
{
|
||||
index_t offset1 = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));
|
||||
index_t offset2 =
|
||||
buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0));
|
||||
index_t offset = ReorderThreadClusters
|
||||
? buffer_1d_desc.CalculateOffset(make_tuple(thread_m_cluster_id))
|
||||
: buffer_1d_desc.CalculateOffset(
|
||||
make_tuple(thread_m_cluster_id * KThreadClusterSize));
|
||||
|
||||
compType currVal1 = type_convert<compType>(block_buffer[offset1]);
|
||||
compType currVal2 = type_convert<compType>(block_buffer[offset2]);
|
||||
int currIndex1 = block_indices_buffer[offset1];
|
||||
int currIndex2 = block_indices_buffer[offset2];
|
||||
|
||||
binop::calculate(currVal1, currVal2, currIndex1, currIndex2);
|
||||
block_buffer(offset1) = type_convert<compType>(currVal1);
|
||||
block_indices_buffer(offset1) = currIndex1;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if(thread_local_id == 0)
|
||||
{
|
||||
compType tmpVal = type_convert<compType>(block_buffer[0]);
|
||||
int tmpIndex = block_indices_buffer[0];
|
||||
|
||||
binop::calculate(accuData, tmpVal, accuIndex, tmpIndex);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BufferType>
|
||||
__device__ static void set_buffer_value(BufferType& block_buffer, compType value)
|
||||
{
|
||||
index_t thread_id = get_thread_local_1d_id();
|
||||
|
||||
for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++)
|
||||
{
|
||||
index_t offset = blockIsOneRow
|
||||
? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id))
|
||||
: buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd));
|
||||
|
||||
block_buffer(offset) = value;
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
};
|
||||
|
||||
// Initialize the block-wise indices buffer, the index for each element in the block-wise
|
||||
// data buffer is calculated according to its position in the buffer and the global starting
|
||||
// index
|
||||
template <typename IdxBufferType>
|
||||
__device__ static void init_buffer_indices(IdxBufferType& block_indices_buffer, int indexStart)
|
||||
{
|
||||
index_t thread_id = get_thread_local_1d_id();
|
||||
|
||||
for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++)
|
||||
{
|
||||
index_t offset = blockIsOneRow
|
||||
? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id))
|
||||
: buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd));
|
||||
|
||||
block_indices_buffer(offset) = offset + indexStart;
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
};
|
||||
|
||||
// Execute unary operation on the block buffer elements
|
||||
template <typename unary_op_type, typename BufferType>
|
||||
__device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& block_buffer)
|
||||
{
|
||||
index_t thread_id = get_thread_local_1d_id();
|
||||
|
||||
for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++)
|
||||
{
|
||||
index_t offset = blockIsOneRow
|
||||
? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id))
|
||||
: buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd));
|
||||
|
||||
block_buffer(offset) = unary_op(block_buffer[offset]);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
};
|
||||
accuData = type_convert<AccDataType>(block_val_buffer[offset]);
|
||||
accuIndex = block_idx_buffer[offset];
|
||||
}
|
||||
};
|
||||
|
||||
}; // end of namespace ck
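For orientation, here is a hypothetical call-site sketch of the partitioned blockwise reduction above; the buffer and variable names (block_reduce_buf, thread_lds_offset, accuValue, BlockwiseReduce) are illustrative assumptions, not taken from this diff. Each thread first reduces its own K-slice into a private accumulator, publishes the partial result to the LDS work buffer, and then the blockwise reduction combines the KThreadClusterSize partials that share the same M cluster id.

// Hypothetical usage sketch (names are illustrative only)
block_reduce_buf(thread_lds_offset) = accuValue;        // publish this thread's partial result
accuValue = ReduceOperation::GetReductionZeroVal();     // reset the private accumulator
__syncthreads();
BlockwiseReduce::Reduce(block_reduce_buf,               // LDS buffer holding BlockSize partials
                        accuValue,                      // receives the per-M-cluster result
                        thread_m_cluster_id,
                        thread_k_cluster_id);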
|
||||
|
||||
@@ -1,141 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef CK_REDUCTION_FUNCTIONS_THREADWISE_HPP
|
||||
#define CK_REDUCTION_FUNCTIONS_THREADWISE_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_binop.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename BufferType, typename opReduce, NanPropagation_t nanPropaOpt>
|
||||
struct ThreadReduce
|
||||
{
|
||||
using compType = typename opReduce::dataType;
|
||||
|
||||
static_assert(BufferType::IsStaticBuffer(), "Thread-wise reduction needs use StaticBuffer!");
|
||||
|
||||
static_assert(
|
||||
std::is_same<typename BufferType::type, compType>::value,
|
||||
"Data type of StaticBuffer for Thread-wise reduction should be same as the compType!");
|
||||
|
||||
static constexpr index_t ThreadBufferLen = BufferType::Size();
|
||||
|
||||
using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;
|
||||
|
||||
// This interface does not accumulate on indices
|
||||
__device__ static void Reduce(const BufferType& thread_buffer, compType& accuData)
|
||||
{
|
||||
static_for<0, ThreadBufferLen, 1>{}(
|
||||
[&](auto I) { binop::calculate(accuData, thread_buffer[I]); });
|
||||
};
|
||||
|
||||
// This interface accumulates on both data values and indices and
|
||||
// is called by Direct_ThreadWise reduction method at first-time reduction
|
||||
__device__ static void
|
||||
Reduce2(const BufferType& thread_buffer, compType& accuData, int& accuIndex, int indexStart)
|
||||
{
|
||||
static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
|
||||
int currIndex = I + indexStart;
|
||||
binop::calculate(accuData, thread_buffer[I], accuIndex, currIndex);
|
||||
});
|
||||
};
|
||||
|
||||
// Set the elements in the per-thread buffer to a specific value
|
||||
// cppcheck-suppress constParameter
|
||||
__device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
|
||||
{
|
||||
static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });
|
||||
};
|
||||
|
||||
// Execute unary operation on the per-thread buffer elements
|
||||
template <typename unary_op_type>
|
||||
__device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
|
||||
{
|
||||
static_for<0, ThreadBufferLen, 1>{}(
|
||||
[&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });
|
||||
};
|
||||
};
|
||||
|
||||
template <typename BufferType,
|
||||
typename IdxBufferType,
|
||||
typename opReduce,
|
||||
NanPropagation_t nanPropaOpt>
|
||||
struct ThreadReduceWithIndicesInput
|
||||
{
|
||||
using compType = typename opReduce::dataType;
|
||||
|
||||
static_assert(BufferType::IsStaticBuffer(), "Thread-wise reduction needs use StaticBuffer!");
|
||||
static_assert(IdxBufferType::IsStaticBuffer(),
|
||||
"Thread-wise reduction needs use StaticBuffer for indices!");
|
||||
|
||||
static_assert(
|
||||
std::is_same<typename BufferType::type, compType>::value,
|
||||
"Data type of StaticBuffer for Thread-wise reduction should be same as the compType!");
|
||||
static_assert(std::is_same<typename IdxBufferType::type, index_t>::value,
|
||||
"Indices type of StaticBuffer for Thread-wise reduction should be index_t!");
|
||||
|
||||
static_assert(BufferType::Size() == IdxBufferType::Size(),
|
||||
"StaticBuffers for data and indices should have the same sizes!");
|
||||
|
||||
static constexpr index_t ThreadBufferLen = BufferType::Size();
|
||||
|
||||
using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;
|
||||
|
||||
// This interface accumulates on both data values and indices and
|
||||
// is called by Direct_ThreadWise reduction method at second-time reduction
|
||||
__device__ static void Reduce(const BufferType& thread_buffer,
|
||||
const IdxBufferType& thread_indices_buffer,
|
||||
compType& accuData,
|
||||
int& accuIndex)
|
||||
{
|
||||
static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
|
||||
binop::calculate(accuData, thread_buffer[I], accuIndex, thread_indices_buffer[I]);
|
||||
});
|
||||
};
|
||||
|
||||
// Set the elements in the per-thread buffer to a specific value
|
||||
// cppcheck-suppress constParameter
|
||||
__device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
|
||||
{
|
||||
static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });
|
||||
};
|
||||
|
||||
// Execute unary operation on the per-thread buffer elements
|
||||
template <typename unary_op_type>
|
||||
__device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
|
||||
{
|
||||
static_for<0, ThreadBufferLen, 1>{}(
|
||||
[&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });
|
||||
};
|
||||
};
|
||||
|
||||
}; // end of namespace ck
|
||||
|
||||
#endif
|
||||
@@ -1,371 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef CK_REDUCTION_FUNCTIONS_WARPWISE_HPP
|
||||
#define CK_REDUCTION_FUNCTIONS_WARPWISE_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reduction_functions_binop.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename BufferType, index_t BlockSize, typename opReduce, NanPropagation_t nanPropaOpt>
|
||||
struct WarpReduce
|
||||
{
|
||||
using compType = typename opReduce::dataType;
|
||||
using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;
|
||||
|
||||
static_assert(BufferType::IsStaticBuffer(),
|
||||
"Per-thread buffer for WarpWise reduction should be StaticBuffer!");
|
||||
static_assert(std::is_same<typename BufferType::type, compType>::value,
|
||||
"Data type of per-thread StaticBuffer for WarpWise reduction should be same as "
|
||||
"the compType!");
|
||||
|
||||
static constexpr index_t ThreadBufferLen = BufferType::Size();
|
||||
static constexpr bool have_builtin_shuffle =
|
||||
std::is_same<compType, float>::value || std::is_same<compType, double>::value;
|
||||
|
||||
// This interface does not accumulate on indices
|
||||
__device__ static void Reduce(const BufferType& thread_buffer, compType& accuData)
|
||||
{
|
||||
if constexpr(have_builtin_shuffle)
|
||||
ReduceImpl1(thread_buffer, accuData);
|
||||
else
|
||||
ReduceImpl2(thread_buffer, accuData);
|
||||
};
|
||||
|
||||
// This interface implementation uses HIP built-in device shuffling functions
|
||||
__device__ static void ReduceImpl1(const BufferType& thread_buffer, compType& accuData)
|
||||
{
|
||||
compType lAccuData = opReduce::GetReductionZeroVal();
|
||||
|
||||
static_for<0, ThreadBufferLen, 1>{}(
|
||||
[&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); });
|
||||
|
||||
// synchronize among all threads in this warp
|
||||
__all(1);
|
||||
|
||||
for(index_t stride = warpSize / 2; stride > 0; stride /= 2)
|
||||
{
|
||||
compType tmpVal = __shfl_down(lAccuData, stride, warpSize);
|
||||
binop::calculate(lAccuData, tmpVal);
|
||||
__all(1);
|
||||
}
|
||||
|
||||
binop::calculate(accuData, lAccuData);
|
||||
};
|
||||
|
||||
// This interface implementation does not use HIP built-in device shuffling functions
|
||||
// since for fp16, built-in shuffling functions is not provided by HIP
|
||||
__device__ static void ReduceImpl2(const BufferType& thread_buffer, compType& accuData)
|
||||
{
|
||||
compType lAccuData = opReduce::GetReductionZeroVal();
|
||||
|
||||
static_for<0, ThreadBufferLen, 1>{}(
|
||||
[&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); });
|
||||
|
||||
__syncthreads();
|
||||
|
||||
index_t thread_id = get_thread_local_1d_id();
|
||||
index_t warpId = thread_id / warpSize;
|
||||
index_t thread_inwarp_id = thread_id % warpSize;
|
||||
|
||||
__shared__ compType shuffle_buffer[BlockSize];
|
||||
|
||||
compType* myBuffer = &shuffle_buffer[warpId * warpSize];
|
||||
|
||||
myBuffer[thread_inwarp_id] = lAccuData;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(index_t stride = warpSize / 2; stride > 0; stride /= 2)
|
||||
{
|
||||
if(thread_inwarp_id < stride)
|
||||
{
|
||||
compType currVal1 = myBuffer[thread_inwarp_id];
|
||||
compType currVal2 = myBuffer[thread_inwarp_id + stride];
|
||||
|
||||
binop::calculate(currVal1, currVal2);
|
||||
|
||||
myBuffer[thread_inwarp_id] = currVal1;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
if(thread_inwarp_id == 0)
|
||||
binop::calculate(accuData, myBuffer[0]);
|
||||
};
|
||||
|
||||
// This interface accumulates on both data values and indices and is called by Direct_WarpWise
|
||||
// reduction method at first-time reduction
|
||||
__device__ static void
|
||||
Reduce2(const BufferType& thread_buffer, compType& accuData, int& accuIndex, int indexStart)
|
||||
{
|
||||
if constexpr(have_builtin_shuffle)
|
||||
Reduce2Impl1(thread_buffer, accuData, accuIndex, indexStart);
|
||||
else
|
||||
Reduce2Impl2(thread_buffer, accuData, accuIndex, indexStart);
|
||||
};
|
||||
|
||||
// This interface implementation uses HIP built-in device shuffling functions
|
||||
__device__ static void Reduce2Impl1(const BufferType& thread_buffer,
|
||||
compType& accuData,
|
||||
int& accuIndex,
|
||||
int indexStart)
|
||||
{
|
||||
compType lAccuData = opReduce::GetReductionZeroVal();
|
||||
int lAccuIndex = 0;
|
||||
index_t thread_inwarp_id = get_thread_local_1d_id() % warpSize;
|
||||
|
||||
static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
|
||||
int currIndex = thread_inwarp_id * ThreadBufferLen + I + indexStart;
|
||||
binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, currIndex);
|
||||
});
|
||||
|
||||
// synchronize among all threads in this warp
|
||||
__all(1);
|
||||
|
||||
for(index_t stride = 1; stride < warpSize; stride *= 2)
|
||||
{
|
||||
compType tmpVal = __shfl_down(lAccuData, stride, warpSize);
|
||||
int tmpIndex = __shfl_down(lAccuIndex, stride, warpSize);
|
||||
|
||||
binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex);
|
||||
__all(1);
|
||||
}
|
||||
|
||||
if(thread_inwarp_id == 0)
|
||||
binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex);
|
||||
};
|
||||
|
||||
// This interface implementation does not use HIP built-in device shuffling functions since for
|
||||
// fp16, built-in shuffling functions is not provided by HIP
|
||||
__device__ static void Reduce2Impl2(const BufferType& thread_buffer,
|
||||
compType& accuData,
|
||||
int& accuIndex,
|
||||
int indexStart)
|
||||
{
|
||||
compType lAccuData = opReduce::GetReductionZeroVal();
|
||||
int lAccuIndex = 0;
|
||||
index_t thread_id = get_thread_local_1d_id();
|
||||
index_t warpId = thread_id / warpSize;
|
||||
index_t thread_inwarp_id = thread_id % warpSize;
|
||||
|
||||
static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
|
||||
int currIndex = thread_inwarp_id * ThreadBufferLen + I + indexStart;
|
||||
binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, currIndex);
|
||||
});
|
||||
|
||||
__shared__ compType shuffle_data_buffer[BlockSize];
|
||||
__shared__ int shuffle_indices_buffer[BlockSize];
|
||||
|
||||
compType* myDataBuffer = &shuffle_data_buffer[warpId * warpSize];
|
||||
int* myIndicesBuffer = &shuffle_indices_buffer[warpId * warpSize];
|
||||
|
||||
myDataBuffer[thread_inwarp_id] = lAccuData;
|
||||
myIndicesBuffer[thread_inwarp_id] = lAccuIndex;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(index_t stride = 1; stride < warpSize; stride *= 2)
|
||||
{
|
||||
compType currVal1 = myDataBuffer[thread_inwarp_id];
|
||||
compType currVal2 = myDataBuffer[thread_inwarp_id + stride];
|
||||
int currIndex1 = myIndicesBuffer[thread_inwarp_id];
|
||||
int currIndex2 = myIndicesBuffer[thread_inwarp_id + stride];
|
||||
|
||||
binop::calculate(currVal1, currVal2, currIndex1, currIndex2);
|
||||
|
||||
myDataBuffer[thread_inwarp_id] = currVal1;
|
||||
myIndicesBuffer[thread_inwarp_id] = currIndex1;
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if(thread_inwarp_id == 0)
|
||||
binop::calculate(accuData, myDataBuffer[0], accuIndex, myIndicesBuffer[0]);
|
||||
};
|
||||
|
||||
// cppcheck-suppress constParameter
|
||||
__device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
|
||||
{
|
||||
static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });
|
||||
|
||||
__all(1);
|
||||
};
|
||||
|
||||
// Execute unary operation on the per-thread buffer elements
|
||||
template <typename unary_op_type>
|
||||
__device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
|
||||
{
|
||||
static_for<0, ThreadBufferLen, 1>{}(
|
||||
[&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });
|
||||
|
||||
__all(1);
|
||||
};
|
||||
};
|
||||
|
||||
template <typename BufferType,
|
||||
typename IdxBufferType,
|
||||
index_t BlockSize,
|
||||
typename opReduce,
|
||||
NanPropagation_t nanPropaOpt>
|
||||
struct WarpReduceWithIndicesInput
|
||||
{
|
||||
using compType = typename opReduce::dataType;
|
||||
using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;
|
||||
|
||||
static_assert(BufferType::IsStaticBuffer(),
|
||||
"Per-thread buffer for WarpWise reduction should be StaticBuffer!");
|
||||
static_assert(IdxBufferType::IsStaticBuffer(),
|
||||
"Per-thread buffer for WarpWise reduction should be StaticBuffer for indices!");
|
||||
|
||||
static_assert(std::is_same<typename BufferType::type, compType>::value,
|
||||
"Data type of per-thread StaticBuffer for WarpWise reduction should be same as "
|
||||
"the compType!");
|
||||
static_assert(
|
||||
std::is_same<typename IdxBufferType::type, index_t>::value,
|
||||
"Indices type per-thread of StaticBuffer for WarpWise reduction should be index_t!");
|
||||
|
||||
static_assert(BufferType::Size() == IdxBufferType::Size(),
|
||||
"StaticBuffers for data and indices should have the same sizes!");
|
||||
|
||||
static constexpr index_t ThreadBufferLen = BufferType::Size();
|
||||
static constexpr bool have_builtin_shuffle =
|
||||
std::is_same<compType, float>::value || std::is_same<compType, double>::value;
|
||||
|
||||
// This interface accumulates on both data values and indices and is called by Direct_WarpWise
|
||||
// reduction method at second-time reduction
|
||||
__device__ static void Reduce(const BufferType& thread_buffer,
|
||||
const IdxBufferType& thread_indices_buffer,
|
||||
compType& accuData,
|
||||
int& accuIndex)
|
||||
{
|
||||
if constexpr(have_builtin_shuffle)
|
||||
ReduceImpl1(thread_buffer, thread_indices_buffer, accuData, accuIndex);
|
||||
else
|
||||
ReduceImpl2(thread_buffer, thread_indices_buffer, accuData, accuIndex);
|
||||
};
|
||||
|
||||
// This interface implementation uses HIP built-in device shuffling functions
|
||||
__device__ static void ReduceImpl1(const BufferType& thread_buffer,
|
||||
const IdxBufferType& thread_indices_buffer,
|
||||
compType& accuData,
|
||||
int& accuIndex)
|
||||
{
|
||||
compType lAccuData = opReduce::GetReductionZeroVal();
|
||||
int lAccuIndex = 0;
|
||||
|
||||
static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
|
||||
binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, thread_indices_buffer[I]);
|
||||
});
|
||||
|
||||
// synchronize among all threads in this warp
|
||||
__all(1);
|
||||
|
||||
for(index_t stride = 1; stride < warpSize; stride *= 2)
|
||||
{
|
||||
compType tmpVal = __shfl_down(lAccuData, stride, warpSize);
|
||||
int tmpIndex = __shfl_down(lAccuIndex, stride, warpSize);
|
||||
|
||||
binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex);
|
||||
__all(1);
|
||||
}
|
||||
|
||||
binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex);
|
||||
};
|
||||
|
||||
// This interface implementation does not use HIP built-in device shuffling functions
|
||||
// since for fp16, built-in shuffling functions is not provided by HIP
|
||||
__device__ static void ReduceImpl2(const BufferType& thread_buffer,
|
||||
const IdxBufferType& thread_indices_buffer,
|
||||
compType& accuData,
|
||||
int& accuIndex)
|
||||
{
|
||||
compType lAccuData = opReduce::GetReductionZeroVal();
|
||||
int lAccuIndex = 0;
|
||||
index_t thread_id = get_thread_local_1d_id();
|
||||
index_t warpId = thread_id / warpSize;
|
||||
index_t thread_inwarp_id = thread_id % warpSize;
|
||||
|
||||
static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
|
||||
binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, thread_indices_buffer[I]);
|
||||
});
|
||||
|
||||
__shared__ compType shuffle_data_buffer[BlockSize];
|
||||
__shared__ int shuffle_indices_buffer[BlockSize];
|
||||
|
||||
compType* myDataBuffer = &shuffle_data_buffer[warpId * warpSize];
|
||||
int* myIndicesBuffer = &shuffle_indices_buffer[warpId * warpSize];
|
||||
|
||||
myDataBuffer[thread_inwarp_id] = lAccuData;
|
||||
myIndicesBuffer[thread_inwarp_id] = lAccuIndex;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(index_t stride = 1; stride < warpSize; stride *= 2)
|
||||
{
|
||||
compType currVal1 = myDataBuffer[thread_inwarp_id];
|
||||
compType currVal2 = myDataBuffer[thread_inwarp_id + stride];
|
||||
int currIndex1 = myIndicesBuffer[thread_inwarp_id];
|
||||
int currIndex2 = myIndicesBuffer[thread_inwarp_id + stride];
|
||||
|
||||
binop::calculate(currVal1, currVal2, currIndex1, currIndex2);
|
||||
|
||||
myDataBuffer[thread_inwarp_id] = currVal1;
|
||||
myIndicesBuffer[thread_inwarp_id] = currIndex1;
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if(thread_inwarp_id == 0)
|
||||
binop::calculate(accuData, myDataBuffer[0], accuIndex, myIndicesBuffer[0]);
|
||||
};
|
||||
|
||||
// cppcheck-suppress constParameter
|
||||
__device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
|
||||
{
|
||||
static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });
|
||||
|
||||
__all(1);
|
||||
};
|
||||
|
||||
// Execute unary operation on the per-thread buffer elements
|
||||
template <typename unary_op_type>
|
||||
__device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
|
||||
{
|
||||
static_for<0, ThreadBufferLen, 1>{}(
|
||||
[&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });
|
||||
|
||||
__all(1);
|
||||
};
|
||||
};
|
||||
|
||||
}; // end of namespace ck
|
||||
|
||||
#endif
|
||||
composable_kernel/include/utility/math_v2.hpp (new file, 16 lines)
@@ -0,0 +1,16 @@
#ifndef CK_MATH_V2_HPP
#define CK_MATH_V2_HPP

#include "data_type.hpp"

namespace ck {
namespace math {

static inline __device__ half_t abs(half_t x) { return __habs(x); };
static inline __device__ half_t sqrtf(half_t x) { return hsqrt(x); };
static inline __device__ bool isnan(half_t x) { return __hisnan(x); };

} // namespace math
} // namespace ck

#endif
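These half_t overloads exist so that templated reduction code can call abs/sqrtf/isnan uniformly for float, double and half_t. A minimal sketch of the intended pattern follows; the helper name nan_encountered is hypothetical and only added here for illustration.

// Hypothetical illustration: one NaN check covering float, double and half_t
// (requires <type_traits> for std::is_same)
template <typename T>
__device__ bool nan_encountered(T x)
{
    if constexpr(std::is_same<T, ck::half_t>::value)
        return ck::math::isnan(x); // half_t overload above
    else
        return isnan(x);           // built-in overloads for float/double
}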
|
||||
@@ -48,6 +48,18 @@ struct float_equal_zero
};
};

template <index_t N>
static constexpr __device__ index_t get_shift()
{
    return (get_shift<N / 2>() + 1);
};

template <>
constexpr __device__ index_t get_shift<1>()
{
    return (0);
}

}; // end of namespace ck

#endif
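As a quick sanity check of what get_shift computes: for power-of-two N it returns log2(N), i.e. the number of halving steps a K-cluster of that size needs in the blockwise tree reduction. The host-side mirror below (get_shift_host) is an illustration only and not part of this change.

// Hypothetical host-side mirror of get_shift, for illustration
template <int N>
constexpr int get_shift_host() { return 1 + get_shift_host<N / 2>(); }
template <>
constexpr int get_shift_host<1>() { return 0; }

static_assert(get_shift_host<64>() == 6, "a 64-thread K-cluster reduces in 6 halving steps");
static_assert(get_shift_host<1>() == 0, "a single element needs no reduction step");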
|
||||
|
||||
@@ -34,50 +34,79 @@
|
||||
namespace ck {
|
||||
namespace detail {
|
||||
|
||||
static inline __device__ bool isnan(half_t x) { return __hisnan(x); };
|
||||
template <typename T>
|
||||
static inline __device__ bool is_nan(T x)
|
||||
{
|
||||
return (isnan(x));
|
||||
};
|
||||
|
||||
template <NanPropagation_t nanPropaOpt, typename opReduce, typename compType>
|
||||
struct binop_with_nan_check;
|
||||
template <>
|
||||
inline __device__ bool is_nan<half_t>(half_t x)
|
||||
{
|
||||
return (__hisnan(x));
|
||||
};
|
||||
|
||||
template <typename opReduce, typename compType>
|
||||
struct binop_with_nan_check<NanPropagation_t::NOT_PROPAGATE_NAN, opReduce, compType>
|
||||
template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
|
||||
struct AccumulateWithNanCheck;
|
||||
|
||||
template <typename ReduceOperation, typename AccDataType>
|
||||
struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
|
||||
{
|
||||
// cppcheck-suppress constParameter
|
||||
__device__ static inline void calculate(compType& accuVal, compType currVal)
|
||||
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
|
||||
{
|
||||
opReduce{}(accuVal, currVal);
|
||||
ReduceOperation{}(accuVal, currVal);
|
||||
};
|
||||
};
|
||||
|
||||
// The method is called when the opReduce is indexable and the user asked for indices
|
||||
template <typename ReduceOperation, typename AccDataType>
|
||||
struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
|
||||
{
|
||||
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
|
||||
{
|
||||
if(is_nan(currVal))
|
||||
{
|
||||
accuVal = currVal;
|
||||
}
|
||||
else
|
||||
{
|
||||
ReduceOperation{}(accuVal, currVal);
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
template <bool PropagateNan, typename ReduceOperation, typename AccDataType, typename IndexDataType>
|
||||
struct AccumulateWithIndexAndNanCheck;
|
||||
|
||||
template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
|
||||
struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType>
|
||||
{
|
||||
__device__ static inline void
|
||||
// cppcheck-suppress constParameter
|
||||
calculate(compType& accuVal, compType currVal, int& accuIndex, int currIndex)
|
||||
Calculate(AccDataType& accuVal,
|
||||
AccDataType currVal,
|
||||
IndexDataType& accuIndex,
|
||||
IndexDataType currIndex)
|
||||
{
|
||||
bool changed = false;
|
||||
|
||||
opReduce{}(accuVal, currVal, changed);
|
||||
ReduceOperation{}(accuVal, currVal, changed);
|
||||
|
||||
if(changed)
|
||||
accuIndex = currIndex;
|
||||
};
|
||||
};
|
||||
|
||||
template <typename opReduce, typename compType>
|
||||
struct binop_with_nan_check<NanPropagation_t::PROPAGATE_NAN, opReduce, compType>
|
||||
template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
|
||||
struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType>
|
||||
{
|
||||
__device__ static inline void calculate(compType& accuVal, compType currVal)
|
||||
// The method is called when the ReduceOperation is indexable and the user asked for indices
|
||||
__device__ static inline void Calculate(AccDataType& accuVal,
|
||||
AccDataType currVal,
|
||||
IndexDataType& accuIndex,
|
||||
IndexDataType currIndex)
|
||||
{
|
||||
if(isnan(currVal))
|
||||
accuVal = currVal;
|
||||
else
|
||||
opReduce{}(accuVal, currVal);
|
||||
};
|
||||
|
||||
// The method is called when the opReduce is indexable and the user asked for indices
|
||||
__device__ static inline void
|
||||
calculate(compType& accuVal, compType currVal, int& accuIndex, int currIndex)
|
||||
{
|
||||
if(isnan(currVal))
|
||||
if(is_nan(currVal))
|
||||
{
|
||||
accuVal = currVal;
|
||||
accuIndex = currIndex;
|
||||
@@ -86,7 +115,7 @@ struct binop_with_nan_check<NanPropagation_t::PROPAGATE_NAN, opReduce, compType>
|
||||
{
|
||||
bool changed = false;
|
||||
|
||||
opReduce{}(accuVal, currVal, changed);
|
||||
ReduceOperation{}(accuVal, currVal, changed);
|
||||
|
||||
if(changed)
|
||||
accuIndex = currIndex;
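To make the two specializations concrete, here is a minimal argmax-style usage sketch; the current value/index variables are hypothetical placeholders for whatever element the caller is visiting.

// Hypothetical usage sketch: NaN-propagating MAX reduction that also tracks the winning index
using Accumulation =
    ck::detail::AccumulateWithIndexAndNanCheck<true, ck::reduce::Max<float>, float, int32_t>;

float   accuVal = ck::reduce::Max<float>::GetReductionZeroVal(); // lowest representable float
int32_t accuIdx = 0;
float   currVal = 0.0f; // candidate value from the current element (placeholder)
int32_t currIdx = 0;    // its flattened index (placeholder)

Accumulation::Calculate(accuVal, currVal, accuIdx, currIdx); // accuIdx follows accuVal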
|
||||
@@ -26,7 +26,7 @@
|
||||
#ifndef CK_REDUCTION_OPERATOR_HPP
|
||||
#define CK_REDUCTION_OPERATOR_HPP
|
||||
|
||||
#include "reduction_common.hpp"
|
||||
#include "common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -60,11 +60,9 @@ struct Add
|
||||
{
|
||||
using dataType = T;
|
||||
|
||||
__device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
|
||||
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
|
||||
|
||||
__device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
|
||||
|
||||
static constexpr bool indexable = false;
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
@@ -72,11 +70,9 @@ struct Mul
|
||||
{
|
||||
using dataType = T;
|
||||
|
||||
__device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
|
||||
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
|
||||
|
||||
__device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
|
||||
|
||||
static constexpr bool indexable = false;
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
@@ -84,15 +80,18 @@ struct Max
|
||||
{
|
||||
using dataType = T;
|
||||
|
||||
__device__ static constexpr T GetReductionZeroVal() { return NumericLimits<T>::Lowest(); };
|
||||
__host__ __device__ static constexpr T GetReductionZeroVal()
|
||||
{
|
||||
return NumericLimits<T>::Lowest();
|
||||
};
|
||||
|
||||
__device__ inline constexpr void operator()(T& a, T b) const
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b) const
|
||||
{
|
||||
if(a < b)
|
||||
a = b;
|
||||
}
|
||||
|
||||
__device__ inline constexpr void operator()(T& a, T b, bool& changed) const
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
|
||||
{
|
||||
if(a < b)
|
||||
{
|
||||
@@ -100,8 +99,6 @@ struct Max
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr bool indexable = true;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
@@ -109,15 +106,18 @@ struct Min
|
||||
{
|
||||
using dataType = T;
|
||||
|
||||
__device__ static constexpr T GetReductionZeroVal() { return NumericLimits<T>::Max(); };
|
||||
__host__ __device__ static constexpr T GetReductionZeroVal()
|
||||
{
|
||||
return NumericLimits<T>::Max();
|
||||
};
|
||||
|
||||
__device__ inline constexpr void operator()(T& a, T b) const
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b) const
|
||||
{
|
||||
if(a > b)
|
||||
a = b;
|
||||
}
|
||||
|
||||
__device__ inline constexpr void operator()(T& a, T b, bool& changed) const
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
|
||||
{
|
||||
if(a > b)
|
||||
{
|
||||
@@ -125,8 +125,6 @@ struct Min
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr bool indexable = true;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
@@ -134,15 +132,15 @@ struct AMax
|
||||
{
|
||||
using dataType = T;
|
||||
|
||||
__device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
|
||||
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
|
||||
|
||||
__device__ inline constexpr void operator()(T& a, T b) const
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b) const
|
||||
{
|
||||
if(a < b)
|
||||
a = b;
|
||||
}
|
||||
|
||||
__device__ inline constexpr void operator()(T& a, T b, bool& changed) const
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
|
||||
{
|
||||
if(a < b)
|
||||
{
|
||||
@@ -150,270 +148,10 @@ struct AMax
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr bool indexable = true;
|
||||
};
|
||||
|
||||
// Unary operators are usually called element-wisely before the reduction is executed on the
|
||||
// elements.
|
||||
// They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
|
||||
template <class T, bool hasDividing>
|
||||
struct unary_identic
|
||||
{
|
||||
__device__ unary_identic(const int divider = 1)
|
||||
{
|
||||
scaler = 1.0f / static_cast<float>(divider);
|
||||
};
|
||||
|
||||
__device__ inline constexpr T operator()(T a) const { return a * type_convert<T>(scaler); };
|
||||
|
||||
float scaler = 1.0f;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct unary_identic<T, false>
|
||||
{
|
||||
__device__ unary_identic(const int divider = 1) { (void)divider; };
|
||||
|
||||
__device__ inline constexpr T operator()(T a) const { return a; };
|
||||
};
|
||||
|
||||
template <class T, bool hasDividing>
|
||||
struct unary_square
|
||||
{
|
||||
__device__ unary_square(const int divider = 1) { scaler = 1.0f / static_cast<float>(divider); };
|
||||
|
||||
__device__ inline constexpr T operator()(T a) const
|
||||
{
|
||||
a = a * a;
|
||||
|
||||
return a * type_convert<T>(scaler);
|
||||
};
|
||||
|
||||
float scaler = 1.0f;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct unary_square<T, false>
|
||||
{
|
||||
__device__ unary_square(const int divider = 1) { (void)divider; };
|
||||
|
||||
__device__ inline constexpr T operator()(T a) const { return a * a; };
|
||||
};
|
||||
|
||||
template <class T, bool hasDividing>
|
||||
struct unary_abs
|
||||
{
|
||||
__device__ unary_abs(const int divider = 1) { scaler = 1.0f / static_cast<float>(divider); };
|
||||
|
||||
__device__ inline constexpr T operator()(T a) const
|
||||
{
|
||||
a = abs(a);
|
||||
|
||||
return a * type_convert<T>(scaler);
|
||||
};
|
||||
|
||||
float scaler = 1.0f;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct unary_abs<T, false>
|
||||
{
|
||||
__device__ unary_abs(const int divider = 1) { (void)divider; };
|
||||
|
||||
__device__ inline constexpr T operator()(T a) const { return abs(a); };
|
||||
};
|
||||
|
||||
// We know for sure that 4.0 has __habs(), but 3.0 does not have it.
|
||||
// Let's assume that __habs() exists since 3.5.
|
||||
#if HIP_PACKAGE_VERSION_FLAT < 3005000000
|
||||
inline __device__ __half __habs(__half x)
|
||||
{
|
||||
union
|
||||
{
|
||||
__half half;
|
||||
unsigned short u16;
|
||||
} val;
|
||||
val.half = x;
|
||||
val.u16 = val.u16 & 0x7fff;
|
||||
return val.half;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <bool hasDividing>
|
||||
struct unary_abs<half_t, hasDividing>
|
||||
{
|
||||
__device__ unary_abs(const int divider = 1) { scaler = 1.0f / static_cast<float>(divider); };
|
||||
|
||||
__device__ inline half_t operator()(half_t a) const
|
||||
{
|
||||
a = static_cast<half_t>(__habs(a));
|
||||
|
||||
return a * type_convert<half_t>(scaler);
|
||||
};
|
||||
|
||||
float scaler = 1.0f;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unary_abs<half_t, false>
|
||||
{
|
||||
__device__ unary_abs(const int divider = 1) { (void)divider; };
|
||||
|
||||
__device__ inline half_t operator()(half_t a) const { return static_cast<half_t>(__habs(a)); };
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct unary_sqrt
|
||||
{
|
||||
__device__ unary_sqrt(const int divider = 1) { (void)divider; };
|
||||
|
||||
__device__ inline T operator()(T a) const { return sqrtf(a); };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unary_sqrt<half_t>
|
||||
{
|
||||
__device__ unary_sqrt(const int divider = 1) { (void)divider; };
|
||||
|
||||
__device__ inline half_t operator()(half_t a) const { return static_cast<half_t>(hsqrt(a)); };
|
||||
};
|
||||
|
||||
}; // end of namespace reduce
|
||||
|
||||
// The templated struct reduce_binary_operator maps the enum Ids of binary operators to their
// respective functor classes.
// The "GetReductionZeroVal()" interface and boolean member "indexable" are also provided in
// reduce_binary_operator for easier checking by the upper-layer codes in the kernels.
|
||||
|
||||
template <typename T, ReduceTensorOp_t op>
|
||||
struct reduce_binary_operator;
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
|
||||
{
|
||||
using opType = reduce::Add<T>;
|
||||
using dataType = T;
|
||||
|
||||
static constexpr bool indexable = reduce::Add<T>::indexable;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
|
||||
{
|
||||
using opType = reduce::Mul<T>;
|
||||
using dataType = T;
|
||||
|
||||
static constexpr bool indexable = reduce::Mul<T>::indexable;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
|
||||
{
|
||||
using opType = reduce::Min<T>;
|
||||
using dataType = T;
|
||||
|
||||
static constexpr bool indexable = reduce::Min<T>::indexable;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
|
||||
{
|
||||
using opType = reduce::Max<T>;
|
||||
using dataType = T;
|
||||
|
||||
static constexpr bool indexable = reduce::Max<T>::indexable;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
|
||||
{
|
||||
using opType = reduce::AMax<T>;
|
||||
using dataType = T;
|
||||
|
||||
static constexpr bool indexable = reduce::Max<T>::indexable;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
|
||||
{
|
||||
using opType = reduce::Add<T>;
|
||||
using dataType = T;
|
||||
|
||||
static constexpr bool indexable = reduce::Add<T>::indexable;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
|
||||
{
|
||||
using opType = reduce::Add<T>;
|
||||
using dataType = T;
|
||||
|
||||
static constexpr bool indexable = reduce::Add<T>::indexable;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
|
||||
{
|
||||
using opType = reduce::Add<T>;
|
||||
using dataType = T;
|
||||
|
||||
static constexpr bool indexable = reduce::Add<T>::indexable;
|
||||
};
|
||||
|
||||
// The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary
// functor classes.
// The two unary functors are called before and after the Reduction is executed, respectively.
|
||||
template <typename T, ReduceTensorOp_t op, bool isFirsReduce, bool isLastReduce>
|
||||
struct reduce_unary_operator
|
||||
{
|
||||
using preUnaryOp = reduce::unary_identic<T, false>;
|
||||
using posUnaryOp = reduce::unary_identic<T, false>;
|
||||
};
|
||||
|
||||
template <typename T, bool isFirstReduce>
|
||||
struct reduce_unary_operator<T, ReduceTensorOp_t::AVG, isFirstReduce, true>
|
||||
{
|
||||
using preUnaryOp = reduce::unary_identic<T, false>;
|
||||
using posUnaryOp = reduce::unary_identic<T, true>;
|
||||
};
|
||||
|
||||
template <typename T, bool isLastReduce>
|
||||
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM1, true, isLastReduce>
|
||||
{
|
||||
using preUnaryOp = reduce::unary_abs<T, false>;
|
||||
using posUnaryOp = reduce::unary_identic<T, false>;
|
||||
};
|
||||
|
||||
template <typename T, bool isLastReduce>
|
||||
struct reduce_unary_operator<T, ReduceTensorOp_t::AMAX, true, isLastReduce>
|
||||
{
|
||||
using preUnaryOp = reduce::unary_abs<T, false>;
|
||||
using posUnaryOp = reduce::unary_identic<T, false>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, false>
|
||||
{
|
||||
using preUnaryOp = reduce::unary_square<T, false>;
|
||||
using posUnaryOp = reduce::unary_identic<T, false>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, true>
|
||||
{
|
||||
using preUnaryOp = reduce::unary_square<T, false>;
|
||||
using posUnaryOp = reduce::unary_sqrt<T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, false, true>
|
||||
{
|
||||
using preUnaryOp = reduce::unary_identic<T, false>;
|
||||
using posUnaryOp = reduce::unary_sqrt<T>;
|
||||
};
|
||||
|
||||
} // end of namespace ck
|
||||
|
||||
#endif
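For reference, the pre/post unary-op pairing above composes like this for NORM2: the pre-op squares each element before accumulation and the post-op takes the square root of the accumulated sum. The plain host function below is a sketch under the assumption that the dividing/scaling path is disabled.

// Illustrative only: preUnaryOp (unary_square) runs per element, reduce::Add accumulates,
// posUnaryOp (unary_sqrt) runs once on the accumulated value.
float norm2_reference(const float* x, int n)
{
    float acc = 0.0f;
    for(int i = 0; i < n; ++i)
        acc += x[i] * x[i]; // preUnaryOp, then Add
    return sqrtf(acc);      // posUnaryOp on the final result
}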
|
||||
|
||||
@@ -1,271 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include "config.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "gridwise_generic_2d_reduction_blockwise.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using srcDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
|
||||
using dstDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
|
||||
using compType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
? NanPropagation_t::NOT_PROPAGATE_NAN
|
||||
: NanPropagation_t::PROPAGATE_NAN;
|
||||
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
? ReduceTensorIndices_t::NO_INDICES
|
||||
: ReduceTensorIndices_t::FLATTENED_INDICES;
|
||||
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable
|
||||
|
||||
// helper functions using variadic template arguments
|
||||
template <index_t... Ns>
|
||||
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(static_cast<index_t>(lengths[Ns])...);
|
||||
};
|
||||
|
||||
template <index_t arraySize>
|
||||
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
|
||||
{
|
||||
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
|
||||
|
||||
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
|
||||
|
||||
return make_tuple_from_array_and_index_seq(lengths, index_seq);
|
||||
};
|
||||
|
||||
template <index_t... Ns>
|
||||
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(Ns...);
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
|
||||
int BlkGroupSize,
|
||||
int inLength0,
|
||||
int inLength1,
|
||||
int inLength2,
|
||||
int inLength3,
|
||||
int inLength4,
|
||||
int inLength5,
|
||||
int inStride0,
|
||||
int inStride1,
|
||||
int inStride2,
|
||||
int inStride3,
|
||||
int inStride4,
|
||||
int inStride5,
|
||||
void* __restrict__ ws_global)
|
||||
{
|
||||
(void)GridSize;
|
||||
(void)BlkGroupSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
|
||||
|
||||
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
|
||||
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
|
||||
|
||||
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
|
||||
const auto tupleDstLengths = make_tuple(1);
|
||||
const auto tupleDstStrides = make_tuple(1);
|
||||
|
||||
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
const auto one_dim_srcDesc = transform_tensor_descriptor(
|
||||
srcDesc,
|
||||
make_tuple(make_merge_transform(tupleSrcLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
auto src2dDesc = transform_tensor_descriptor(
|
||||
one_dim_srcDesc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
|
||||
constexpr int invariantLen = 1;
|
||||
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
|
||||
const auto srcPad =
|
||||
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
|
||||
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pass_through_transform(invariantLen),
|
||||
make_pad_transform(toReduceLen, 0, srcPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
|
||||
};
|
||||
|
||||
template <index_t srcDims>
|
||||
struct get_ref_desc_types
|
||||
{
|
||||
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
|
||||
|
||||
// don't have to use accurate strides to get an expected referrence type
|
||||
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
|
||||
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
|
||||
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));
|
||||
|
||||
static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
|
||||
ref_srcDesc,
|
||||
make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
static constexpr auto ref_src2dDesc =
|
||||
transform_tensor_descriptor(ref_one_dim_srcDesc,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
|
||||
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
|
||||
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
// used by the BlockWise and MultiBlock method
|
||||
using refType_src2dDesc_padded_34 = decltype(
|
||||
transform_tensor_descriptor(ref_src2dDesc,
|
||||
make_tuple(make_pass_through_transform(ref_invariantLen),
|
||||
make_pad_transform(ref_toReduceLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{})));
|
||||
|
||||
using refType_dst1dDesc_padded =
|
||||
decltype(transform_tensor_descriptor(ref_dstDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{})));
|
||||
|
||||
using refType_src2dDesc = decltype(ref_src2dDesc);
|
||||
using refType_dst1dDesc = decltype(ref_dstDesc);
|
||||
};
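// Hedged note on the struct above: the all-8 reference lengths never reach the kernel's data;
// they only let decltype name the descriptor types that the *_prepare kernel serialized into
// the workspace. A self-contained analogy using std::tuple (plain C++, not a CK type): the
// type of a tuple of run-time integers does not depend on the values stored.
#include <tuple>
#include <type_traits>
static_assert(
    std::is_same<decltype(std::make_tuple(8, 8)), decltype(std::make_tuple(37, 5))>::value,
    "same tuple type regardless of the run-time values");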
using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
|
||||
using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
|
||||
using refType_src2dDesc_padded_34 =
|
||||
typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_34;
|
||||
using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_src2dDesc_padded_34*>(p_src2dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
|
||||
};
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
|
||||
};
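// Hedged note: the two helpers above simply reinterpret the bytes that
// gridwise_generic_reduce_1_prepare stored at the same workspace offsets, so the prepare and
// compute kernels must be compiled with matching CK_PARAM_SRC2D_PADDING /
// CK_PARAM_DST1D_PADDING settings or the reinterpreted types will not match.
// Plain-C++ analogy of the store/reload pattern (hypothetical struct, illustrative only):
struct desc_bytes_sketch
{
    int len0;
    int len1;
};
static inline const desc_bytes_sketch& reload_desc_sketch(const void* p_workspace)
{
    // same trivially-copyable layout the "prepare" side wrote at this offset
    return *reinterpret_cast<const desc_bytes_sketch*>(p_workspace);
}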
extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
|
||||
int BlkGroupSize,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
(void)ws_buf2_bytes_offset;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
true,
|
||||
true,
|
||||
GredAccessesPerThreadInBlock>;
|
||||
|
||||
constexpr int RunId = need_indices ? 2 : 1;
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(p_src_global),
|
||||
beta,
|
||||
static_cast<dstDataType* const __restrict__>(p_dst_global),
|
||||
static_cast<const int* const __restrict__>(nullptr),
|
||||
static_cast<int* const __restrict__>(indices_global));
|
||||
};
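// Hedged sketch of the workspace layout shared by the *_prepare and compute kernels above
// (offsets mirror the literals in the code; the struct itself is illustrative, not a CK type):
//   bytes    0 .. 2047 : serialized src2dDesc
//   bytes 2048 .. 4095 : serialized dst1dDesc
//   bytes 4096 ..      : ws_buf1 / ws_buf2 partial-result buffers (MultiBlock variants only)
struct reduce_workspace_sketch
{
    unsigned char src2d_desc_bytes[2048];
    unsigned char dst1d_desc_bytes[2048];
    // partial values and (optionally) partial indices follow at byte offset 4096
};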
@@ -1,305 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include "config.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "gridwise_generic_2d_reduction_blockwise.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using srcDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
|
||||
using dstDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
|
||||
using compType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
|
||||
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
|
||||
|
||||
constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS;
|
||||
constexpr index_t num_invariantDims = srcDims - num_toReduceDims;
|
||||
|
||||
using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
|
||||
using toReduceDims = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
? NanPropagation_t::NOT_PROPAGATE_NAN
|
||||
: NanPropagation_t::PROPAGATE_NAN;
|
||||
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
? ReduceTensorIndices_t::NO_INDICES
|
||||
: ReduceTensorIndices_t::FLATTENED_INDICES;
|
||||
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
static_assert(num_invariantDims > 0, "Not all dimensions are reduced for this kernel!");
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable
|
||||
|
||||
// helper functions using variadic template arguments
|
||||
template <index_t... Ns>
|
||||
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(static_cast<index_t>(lengths[Ns])...);
|
||||
};
|
||||
|
||||
template <index_t arraySize>
|
||||
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
|
||||
{
|
||||
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
|
||||
|
||||
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
|
||||
|
||||
return make_tuple_from_array_and_index_seq(lengths, index_seq);
|
||||
};
|
||||
|
||||
template <index_t... Ns>
|
||||
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(Ns...);
|
||||
};
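// Self-contained sketch (plain C++14, no CK types) of the same unpacking idiom as the helpers
// above: an index sequence selects the leading N entries of a raw array and turns them into a
// tuple from which a naive tensor descriptor could be built.
#include <cstddef>
#include <tuple>
#include <utility>
template <std::size_t... Ns>
static auto tuple_from_array_sketch_impl(const int* lengths, std::index_sequence<Ns...>)
{
    return std::make_tuple(lengths[Ns]...);
}
template <std::size_t N>
static auto tuple_from_array_sketch(const int* lengths)
{
    return tuple_from_array_sketch_impl(lengths, std::make_index_sequence<N>{});
}
// e.g. int len[6] = {4, 8, 16, 1, 1, 1}; auto t = tuple_from_array_sketch<3>(len); // (4, 8, 16)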
extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
|
||||
int BlkGroupSize,
|
||||
int inLength0,
|
||||
int inLength1,
|
||||
int inLength2,
|
||||
int inLength3,
|
||||
int inLength4,
|
||||
int inLength5,
|
||||
int inStride0,
|
||||
int inStride1,
|
||||
int inStride2,
|
||||
int inStride3,
|
||||
int inStride4,
|
||||
int inStride5,
|
||||
int outStride0,
|
||||
int outStride1,
|
||||
int outStride2,
|
||||
int outStride3,
|
||||
int outStride4,
|
||||
int outStride5,
|
||||
void* __restrict__ ws_global)
|
||||
{
|
||||
(void)GridSize;
|
||||
(void)BlkGroupSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
|
||||
|
||||
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
|
||||
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
|
||||
const int dstStrides[6] = {
|
||||
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
|
||||
|
||||
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
|
||||
const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
|
||||
|
||||
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{});
|
||||
const auto invariantDimLengths =
|
||||
make_tuple_from_array_and_index_seq(srcLengths, invariantDims{});
|
||||
|
||||
auto src2dDesc =
|
||||
transform_tensor_descriptor(srcDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(toReduceDimLengths)),
|
||||
make_tuple(invariantDims{}, toReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
auto dst1dDesc = transform_tensor_descriptor(
|
||||
dstDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto invariantLen = src2dDesc.GetLength(Number<0>{});
|
||||
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
|
||||
const auto srcPad =
|
||||
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
|
||||
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pass_through_transform(invariantLen),
|
||||
make_pad_transform(toReduceLen, 0, srcPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
|
||||
};
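// Hedged arithmetic sketch of the 2-D view built above: the leading num_invariantDims
// dimensions are merged into dimension 0 and the trailing to-reduce dimensions into
// dimension 1. Illustrative helper, not a CK API:
static inline void merged_2d_lengths_sketch(const int* lengths,
                                            int numDims,
                                            int numInvariantDims,
                                            long& invariantLen,
                                            long& toReduceLen)
{
    invariantLen = 1;
    toReduceLen  = 1;
    for(int i = 0; i < numDims; ++i)
        (i < numInvariantDims ? invariantLen : toReduceLen) *= lengths[i];
    // e.g. NCHW lengths (2, 3, 4, 5) with the last two dims reduced -> a 6 x 20 matrix,
    // reduced along dimension 1
}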
template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
|
||||
struct get_ref_desc_types
|
||||
{
|
||||
static constexpr auto ref_toReduceDimLengths =
|
||||
typename uniform_sequence_gen<toReduceDims::Size(), 8>::type{};
|
||||
static constexpr auto ref_invariantDimLengths =
|
||||
typename uniform_sequence_gen<invariantDims::Size(), 8>::type{};
|
||||
|
||||
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
|
||||
static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 8>::type{};
|
||||
|
||||
    // don't have to use accurate strides to get the expected reference type
|
||||
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
|
||||
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
|
||||
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(
|
||||
make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
|
||||
|
||||
static constexpr auto ref_src2dDesc = transform_tensor_descriptor(
|
||||
ref_srcDesc,
|
||||
make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)),
|
||||
make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))),
|
||||
make_tuple(invariantDims{}, toReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
|
||||
ref_dstDesc,
|
||||
make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
|
||||
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
// used by the BlockWise and MultiBlock method
|
||||
using refType_src2dDesc_padded_34 = decltype(
|
||||
transform_tensor_descriptor(ref_src2dDesc,
|
||||
make_tuple(make_pass_through_transform(ref_invariantLen),
|
||||
make_pad_transform(ref_toReduceLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{})));
|
||||
|
||||
using refType_dst1dDesc_padded =
|
||||
decltype(transform_tensor_descriptor(ref_dst1dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{})));
|
||||
|
||||
using refType_src2dDesc = decltype(ref_src2dDesc);
|
||||
using refType_dst1dDesc = decltype(ref_dst1dDesc);
|
||||
};
|
||||
|
||||
using refType_src2dDesc =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc;
|
||||
using refType_dst1dDesc =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
|
||||
using refType_src2dDesc_padded_34 =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
|
||||
refType_src2dDesc_padded_34;
|
||||
using refType_dst1dDesc_padded =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
|
||||
refType_dst1dDesc_padded;
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_src2dDesc_padded_34*>(p_src2dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
|
||||
};
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
|
||||
int BlkGroupSize,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
(void)ws_buf2_bytes_offset;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
true,
|
||||
true,
|
||||
GredAccessesPerThreadInBlock>;
|
||||
|
||||
constexpr int RunId = need_indices ? 2 : 1;
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(p_src_global),
|
||||
beta,
|
||||
static_cast<dstDataType* const __restrict__>(p_dst_global),
|
||||
static_cast<const int* const __restrict__>(nullptr),
|
||||
static_cast<int* const __restrict__>(indices_global));
|
||||
};
|
||||
@@ -1,276 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include "config.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "gridwise_generic_2d_reduction_multiblock.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using srcDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
|
||||
using dstDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
|
||||
using compType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
? NanPropagation_t::NOT_PROPAGATE_NAN
|
||||
: NanPropagation_t::PROPAGATE_NAN;
|
||||
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
? ReduceTensorIndices_t::NO_INDICES
|
||||
: ReduceTensorIndices_t::FLATTENED_INDICES;
|
||||
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable
|
||||
|
||||
// helper functions using variadic template arguments
|
||||
template <index_t... Ns>
|
||||
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(static_cast<index_t>(lengths[Ns])...);
|
||||
};
|
||||
|
||||
template <index_t arraySize>
|
||||
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
|
||||
{
|
||||
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
|
||||
|
||||
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
|
||||
|
||||
return make_tuple_from_array_and_index_seq(lengths, index_seq);
|
||||
};
|
||||
|
||||
template <index_t... Ns>
|
||||
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(Ns...);
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
|
||||
int BlkGroupSize,
|
||||
int inLength0,
|
||||
int inLength1,
|
||||
int inLength2,
|
||||
int inLength3,
|
||||
int inLength4,
|
||||
int inLength5,
|
||||
int inStride0,
|
||||
int inStride1,
|
||||
int inStride2,
|
||||
int inStride3,
|
||||
int inStride4,
|
||||
int inStride5,
|
||||
void* __restrict__ ws_global)
|
||||
{
|
||||
(void)GridSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
|
||||
|
||||
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
|
||||
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
|
||||
|
||||
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
|
||||
const auto tupleDstLengths = make_tuple(1);
|
||||
const auto tupleDstStrides = make_tuple(1);
|
||||
|
||||
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
const auto one_dim_srcDesc = transform_tensor_descriptor(
|
||||
srcDesc,
|
||||
make_tuple(make_merge_transform(tupleSrcLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
auto src2dDesc = transform_tensor_descriptor(
|
||||
one_dim_srcDesc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
|
||||
constexpr int invariantLen = 1;
|
||||
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
|
||||
const index_t reduceSizePerBlock =
|
||||
(((toReduceLen + BlkGroupSize - 1) / BlkGroupSize + copySliceLen - 1) / copySliceLen) *
|
||||
copySliceLen;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
|
||||
const auto srcPad = reduceSizePerBlock * BlkGroupSize - toReduceLen;
|
||||
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pass_through_transform(invariantLen),
|
||||
make_pad_transform(toReduceLen, 0, srcPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
|
||||
};
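// Hedged numeric sketch of the MultiBlock padding above: the reduce span handled by each of
// the BlkGroupSize cooperating blocks is itself rounded up to a multiple of
// copySliceLen = BlockSize * GredAccessesPerThreadInBlock, and the total is padded to
// reduceSizePerBlock * BlkGroupSize. Illustrative helper, not a CK API:
static inline int multiblock_src_pad_sketch(int toReduceLen, int blkGroupSize, int copySliceLen)
{
    const int reduceSizePerBlock =
        ((toReduceLen + blkGroupSize - 1) / blkGroupSize + copySliceLen - 1) / copySliceLen *
        copySliceLen;
    return reduceSizePerBlock * blkGroupSize - toReduceLen;
    // e.g. toReduceLen = 10000, blkGroupSize = 4, copySliceLen = 512
    //      -> reduceSizePerBlock = 2560, pad = 10240 - 10000 = 240
}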
template <index_t srcDims>
|
||||
struct get_ref_desc_types
|
||||
{
|
||||
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
|
||||
|
||||
    // don't have to use accurate strides to get the expected reference type
|
||||
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
|
||||
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
|
||||
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));
|
||||
|
||||
static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
|
||||
ref_srcDesc,
|
||||
make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
static constexpr auto ref_src2dDesc =
|
||||
transform_tensor_descriptor(ref_one_dim_srcDesc,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
|
||||
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
|
||||
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
// used by the BlockWise and MultiBlock method
|
||||
using refType_src2dDesc_padded_34 = decltype(
|
||||
transform_tensor_descriptor(ref_src2dDesc,
|
||||
make_tuple(make_pass_through_transform(ref_invariantLen),
|
||||
make_pad_transform(ref_toReduceLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{})));
|
||||
|
||||
using refType_dst1dDesc_padded =
|
||||
decltype(transform_tensor_descriptor(ref_dstDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{})));
|
||||
|
||||
using refType_src2dDesc = decltype(ref_src2dDesc);
|
||||
using refType_dst1dDesc = decltype(ref_dstDesc);
|
||||
};
|
||||
|
||||
using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
|
||||
using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
|
||||
using refType_src2dDesc_padded_34 =
|
||||
typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_34;
|
||||
using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_src2dDesc_padded_34*>(p_src2dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
|
||||
};
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
|
||||
int BlkGroupSize,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)p_dst_global;
|
||||
(void)indices_global;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_multiblock<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
GredAccessesPerThreadInBlock>;
|
||||
|
||||
void* const ws_buf2_global =
|
||||
ws_buf2_bytes_offset > 0
|
||||
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
|
||||
: nullptr;
|
||||
|
||||
constexpr int RunId = need_indices ? 2 : 1;
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
BlkGroupSize,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(p_src_global),
|
||||
beta,
|
||||
static_cast<srcDataType* const __restrict__>(ws_buf1_global),
|
||||
static_cast<int* const __restrict__>(ws_buf2_global));
|
||||
};
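// Hedged note on the first-call kernel above: the grid writes per-block partial values into
// ws_buf1 (byte 4096 of the workspace) and, when indices are requested, per-block partial
// indices into ws_buf2; a second call is expected to finish the reduction from those buffers.
// The ws_buf2 pointer arithmetic mirrors the code directly:
static inline void* ws_buf2_from_buf1_sketch(void* ws_buf1, long ws_buf2_bytes_offset)
{
    return ws_buf2_bytes_offset > 0
               ? static_cast<void*>(static_cast<char*>(ws_buf1) + ws_buf2_bytes_offset)
               : nullptr;
}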
@@ -1,310 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include "config.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "gridwise_generic_2d_reduction_multiblock.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using srcDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
|
||||
using dstDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
|
||||
using compType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
|
||||
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
|
||||
|
||||
constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS;
|
||||
constexpr index_t num_invariantDims = srcDims - num_toReduceDims;
|
||||
|
||||
using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
|
||||
using toReduceDims = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
? NanPropagation_t::NOT_PROPAGATE_NAN
|
||||
: NanPropagation_t::PROPAGATE_NAN;
|
||||
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
? ReduceTensorIndices_t::NO_INDICES
|
||||
: ReduceTensorIndices_t::FLATTENED_INDICES;
|
||||
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
static_assert(num_invariantDims > 0, "Not all dimensions are reduced for this kernel!");
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable
|
||||
|
||||
// helper functions using variadic template arguments
|
||||
template <index_t... Ns>
|
||||
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(static_cast<index_t>(lengths[Ns])...);
|
||||
};
|
||||
|
||||
template <index_t arraySize>
|
||||
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
|
||||
{
|
||||
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
|
||||
|
||||
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
|
||||
|
||||
return make_tuple_from_array_and_index_seq(lengths, index_seq);
|
||||
};
|
||||
|
||||
template <index_t... Ns>
|
||||
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(Ns...);
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
|
||||
int BlkGroupSize,
|
||||
int inLength0,
|
||||
int inLength1,
|
||||
int inLength2,
|
||||
int inLength3,
|
||||
int inLength4,
|
||||
int inLength5,
|
||||
int inStride0,
|
||||
int inStride1,
|
||||
int inStride2,
|
||||
int inStride3,
|
||||
int inStride4,
|
||||
int inStride5,
|
||||
int outStride0,
|
||||
int outStride1,
|
||||
int outStride2,
|
||||
int outStride3,
|
||||
int outStride4,
|
||||
int outStride5,
|
||||
void* __restrict__ ws_global)
|
||||
{
|
||||
(void)GridSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
|
||||
|
||||
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
|
||||
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
|
||||
const int dstStrides[6] = {
|
||||
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
|
||||
|
||||
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
|
||||
const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
|
||||
|
||||
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{});
|
||||
const auto invariantDimLengths =
|
||||
make_tuple_from_array_and_index_seq(srcLengths, invariantDims{});
|
||||
|
||||
auto src2dDesc =
|
||||
transform_tensor_descriptor(srcDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(toReduceDimLengths)),
|
||||
make_tuple(invariantDims{}, toReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
auto dst1dDesc = transform_tensor_descriptor(
|
||||
dstDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto invariantLen = src2dDesc.GetLength(Number<0>{});
|
||||
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
|
||||
const index_t reduceSizePerBlock =
|
||||
(((toReduceLen + BlkGroupSize - 1) / BlkGroupSize + copySliceLen - 1) / copySliceLen) *
|
||||
copySliceLen;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
|
||||
const auto srcPad = reduceSizePerBlock * BlkGroupSize - toReduceLen;
|
||||
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pass_through_transform(invariantLen),
|
||||
make_pad_transform(toReduceLen, 0, srcPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
|
||||
};
|
||||
|
||||
template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
|
||||
struct get_ref_desc_types
|
||||
{
|
||||
static constexpr auto ref_toReduceDimLengths =
|
||||
typename uniform_sequence_gen<toReduceDims::Size(), 8>::type{};
|
||||
static constexpr auto ref_invariantDimLengths =
|
||||
typename uniform_sequence_gen<invariantDims::Size(), 8>::type{};
|
||||
|
||||
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
|
||||
static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 8>::type{};
|
||||
|
||||
    // don't have to use accurate strides to get the expected reference type
|
||||
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
|
||||
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
|
||||
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(
|
||||
make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
|
||||
|
||||
static constexpr auto ref_src2dDesc = transform_tensor_descriptor(
|
||||
ref_srcDesc,
|
||||
make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)),
|
||||
make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))),
|
||||
make_tuple(invariantDims{}, toReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
|
||||
ref_dstDesc,
|
||||
make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
|
||||
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
// used by the BlockWise and MultiBlock method
|
||||
using refType_src2dDesc_padded_34 = decltype(
|
||||
transform_tensor_descriptor(ref_src2dDesc,
|
||||
make_tuple(make_pass_through_transform(ref_invariantLen),
|
||||
make_pad_transform(ref_toReduceLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{})));
|
||||
|
||||
using refType_dst1dDesc_padded =
|
||||
decltype(transform_tensor_descriptor(ref_dst1dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{})));
|
||||
|
||||
using refType_src2dDesc = decltype(ref_src2dDesc);
|
||||
using refType_dst1dDesc = decltype(ref_dst1dDesc);
|
||||
};
|
||||
|
||||
using refType_src2dDesc =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc;
|
||||
using refType_dst1dDesc =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
|
||||
using refType_src2dDesc_padded_34 =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
|
||||
refType_src2dDesc_padded_34;
|
||||
using refType_dst1dDesc_padded =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
|
||||
refType_dst1dDesc_padded;
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_src2dDesc_padded_34*>(p_src2dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
|
||||
};
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
|
||||
int BlkGroupSize,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)p_dst_global;
|
||||
(void)indices_global;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_multiblock<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
GredAccessesPerThreadInBlock>;
|
||||
|
||||
void* const ws_buf2_global =
|
||||
ws_buf2_bytes_offset > 0
|
||||
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
|
||||
: nullptr;
|
||||
|
||||
constexpr int RunId = need_indices ? 2 : 1;
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
BlkGroupSize,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(p_src_global),
|
||||
beta,
|
||||
static_cast<srcDataType* const __restrict__>(ws_buf1_global),
|
||||
static_cast<int* const __restrict__>(ws_buf2_global));
|
||||
};
|
||||
@@ -1,284 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include "config.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "gridwise_generic_2d_reduction_direct_threadwise.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using srcDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
|
||||
using dstDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
|
||||
using compType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
? NanPropagation_t::NOT_PROPAGATE_NAN
|
||||
: NanPropagation_t::PROPAGATE_NAN;
|
||||
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
? ReduceTensorIndices_t::NO_INDICES
|
||||
: ReduceTensorIndices_t::FLATTENED_INDICES;
|
||||
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable
|
||||
|
||||
// helper functions using variadic template arguments
|
||||
template <index_t... Ns>
|
||||
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(static_cast<index_t>(lengths[Ns])...);
|
||||
};
|
||||
|
||||
template <index_t arraySize>
|
||||
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
|
||||
{
|
||||
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
|
||||
|
||||
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
|
||||
|
||||
return make_tuple_from_array_and_index_seq(lengths, index_seq);
|
||||
};
|
||||
|
||||
template <index_t... Ns>
|
||||
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(Ns...);
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
|
||||
int BlkGroupSize,
|
||||
int inLength0,
|
||||
int inLength1,
|
||||
int inLength2,
|
||||
int inLength3,
|
||||
int inLength4,
|
||||
int inLength5,
|
||||
int inStride0,
|
||||
int inStride1,
|
||||
int inStride2,
|
||||
int inStride3,
|
||||
int inStride4,
|
||||
int inStride5,
|
||||
void* __restrict__ ws_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
|
||||
|
||||
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
|
||||
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
|
||||
|
||||
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
|
||||
const auto tupleDstLengths = make_tuple(1);
|
||||
const auto tupleDstStrides = make_tuple(1);
|
||||
|
||||
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
const auto one_dim_srcDesc = transform_tensor_descriptor(
|
||||
srcDesc,
|
||||
make_tuple(make_merge_transform(tupleSrcLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
auto src2dDesc = transform_tensor_descriptor(
|
||||
one_dim_srcDesc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
|
||||
constexpr int invariantLen = 1;
|
||||
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
constexpr auto copySliceLen = GredThreadBufferLength;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
|
||||
const auto srcPad1 = GridSize * BlockSize - invariantLen;
|
||||
const auto srcPad2 =
|
||||
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
|
||||
make_pad_transform(toReduceLen, 0, srcPad2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if constexpr(dst1d_need_padding)
|
||||
{
|
||||
const auto dstPad = GridSize * BlockSize - invariantLen;
|
||||
auto dst1dDesc_2 =
|
||||
            transform_tensor_descriptor(dstDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
|
||||
}
|
||||
};
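// Hedged worked example of the two-sided padding above (values are illustrative): in the
// DirectThreadWise method every thread owns one output element, so dimension 0 is padded up
// to GridSize * BlockSize (here invariantLen is 1) and dimension 1 up to a multiple of
// GredThreadBufferLength.
// e.g. GridSize = 4, BlockSize = 256                 -> srcPad1 = 1024 - 1 = 1023
//      toReduceLen = 70, GredThreadBufferLength = 8  -> srcPad2 = 72 - 70 = 2
constexpr int example_srcPad2 = (70 + 8 - 1) / 8 * 8 - 70; // = 2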
template <index_t srcDims>
|
||||
struct get_ref_desc_types
|
||||
{
|
||||
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
|
||||
|
||||
    // don't have to use accurate strides to get the expected reference type
|
||||
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
|
||||
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
|
||||
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));
|
||||
|
||||
static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
|
||||
ref_srcDesc,
|
||||
make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
static constexpr auto ref_src2dDesc =
|
||||
transform_tensor_descriptor(ref_one_dim_srcDesc,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
|
||||
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
|
||||
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
// used by the DirectThreadWise and DirectWarpWise method
|
||||
using refType_src2dDesc_padded_12 =
|
||||
decltype(transform_tensor_descriptor(ref_src2dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
|
||||
make_pad_transform(ref_toReduceLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{})));
|
||||
|
||||
using refType_dst1dDesc_padded =
|
||||
decltype(transform_tensor_descriptor(ref_dstDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{})));
|
||||
|
||||
using refType_src2dDesc = decltype(ref_src2dDesc);
|
||||
using refType_dst1dDesc = decltype(ref_dstDesc);
|
||||
};
|
||||
|
||||
using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
|
||||
using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
|
||||
using refType_src2dDesc_padded_12 =
|
||||
typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_12;
|
||||
using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
|
||||
};
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
|
||||
int BlkGroupSize,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
(void)ws_buf2_bytes_offset;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
true,
|
||||
true,
|
||||
GredThreadBufferLength>;
|
||||
|
||||
constexpr int RunId = need_indices ? 2 : 1;
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(p_src_global),
|
||||
beta,
|
||||
static_cast<dstDataType* const __restrict__>(p_dst_global),
|
||||
static_cast<const int* const __restrict__>(nullptr),
|
||||
static_cast<int* const __restrict__>(indices_global));
|
||||
};
|
||||
@@ -1,318 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include "config.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "gridwise_generic_2d_reduction_direct_threadwise.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using srcDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
|
||||
using dstDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
|
||||
using compType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
|
||||
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
|
||||
|
||||
constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS;
|
||||
constexpr index_t num_invariantDims = srcDims - num_toReduceDims;
|
||||
|
||||
using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
|
||||
using toReduceDims = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
? NanPropagation_t::NOT_PROPAGATE_NAN
|
||||
: NanPropagation_t::PROPAGATE_NAN;
|
||||
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
? ReduceTensorIndices_t::NO_INDICES
|
||||
: ReduceTensorIndices_t::FLATTENED_INDICES;
|
||||
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
static_assert(num_invariantDims > 0, "This kernel requires at least one invariant (non-reduced) dimension!");
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
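// Indices are produced only when the reduce operation is indexable (an order-based
// reduction such as MIN/MAX/AMAX) and the caller explicitly requested them.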
|
||||
|
||||
constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable
|
||||
|
||||
// helper functions using variadic template arguments
|
||||
template <index_t... Ns>
|
||||
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(static_cast<index_t>(lengths[Ns])...);
|
||||
};
|
||||
|
||||
template <index_t arraySize>
|
||||
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
|
||||
{
|
||||
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
|
||||
|
||||
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
|
||||
|
||||
return make_tuple_from_array_and_index_seq(lengths, index_seq);
|
||||
};
|
||||
|
||||
template <index_t... Ns>
|
||||
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(Ns...);
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
|
||||
int BlkGroupSize,
|
||||
int inLength0,
|
||||
int inLength1,
|
||||
int inLength2,
|
||||
int inLength3,
|
||||
int inLength4,
|
||||
int inLength5,
|
||||
int inStride0,
|
||||
int inStride1,
|
||||
int inStride2,
|
||||
int inStride3,
|
||||
int inStride4,
|
||||
int inStride5,
|
||||
int outStride0,
|
||||
int outStride1,
|
||||
int outStride2,
|
||||
int outStride3,
|
||||
int outStride4,
|
||||
int outStride5,
|
||||
void* __restrict__ ws_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
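    // Workspace layout: the 2-D source descriptor is written at byte offset 0 and the
    // 1-D destination descriptor at byte offset 2048; the reduction kernel reads them
    // back from the same offsets.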
|
||||
|
||||
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
|
||||
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
|
||||
const int dstStrides[6] = {
|
||||
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
|
||||
|
||||
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
|
||||
const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
|
||||
|
||||
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{});
|
||||
const auto invariantDimLengths =
|
||||
make_tuple_from_array_and_index_seq(srcLengths, invariantDims{});
|
||||
|
||||
auto src2dDesc =
|
||||
transform_tensor_descriptor(srcDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(toReduceDimLengths)),
|
||||
make_tuple(invariantDims{}, toReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
auto dst1dDesc = transform_tensor_descriptor(
|
||||
dstDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto invariantLen = src2dDesc.GetLength(Number<0>{});
|
||||
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
constexpr auto copySliceLen = GredThreadBufferLength;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
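        // Pad dim-0 so that every thread (GridSize * BlockSize in total) owns one output
        // element, and pad dim-1 up to a multiple of the per-thread slice length so each
        // thread always copies a full slice.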
|
||||
const auto srcPad1 = GridSize * BlockSize - invariantLen;
|
||||
const auto srcPad2 =
|
||||
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
|
||||
make_pad_transform(toReduceLen, 0, srcPad2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if constexpr(dst1d_need_padding)
|
||||
{
|
||||
const auto dstPad = GridSize * BlockSize - invariantLen;
|
||||
auto dst1dDesc_2 =
|
||||
transform_tensor_descriptor(dst1dDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
|
||||
}
|
||||
};
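
// The reduction kernel has to reinterpret the descriptor bytes stored in the workspace,
// which requires knowing their C++ types at compile time. The reference descriptors below
// are built from dummy lengths (8) purely to reproduce those types; the actual lengths and
// strides do not matter here.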
|
||||
|
||||
template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
|
||||
struct get_ref_desc_types
|
||||
{
|
||||
static constexpr auto ref_toReduceDimLengths =
|
||||
typename uniform_sequence_gen<toReduceDims::Size(), 8>::type{};
|
||||
static constexpr auto ref_invariantDimLengths =
|
||||
typename uniform_sequence_gen<invariantDims::Size(), 8>::type{};
|
||||
|
||||
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
|
||||
static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 8>::type{};
|
||||
|
||||
    // strides do not need to be accurate here; any strides produce the expected reference type
|
||||
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
|
||||
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
|
||||
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(
|
||||
make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
|
||||
|
||||
static constexpr auto ref_src2dDesc = transform_tensor_descriptor(
|
||||
ref_srcDesc,
|
||||
make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)),
|
||||
make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))),
|
||||
make_tuple(invariantDims{}, toReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
|
||||
ref_dstDesc,
|
||||
make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
|
||||
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
// used by the DirectThreadWise and DirectWarpWise method
|
||||
using refType_src2dDesc_padded_12 =
|
||||
decltype(transform_tensor_descriptor(ref_src2dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
|
||||
make_pad_transform(ref_toReduceLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{})));
|
||||
|
||||
using refType_dst1dDesc_padded =
|
||||
decltype(transform_tensor_descriptor(ref_dst1dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{})));
|
||||
|
||||
using refType_src2dDesc = decltype(ref_src2dDesc);
|
||||
using refType_dst1dDesc = decltype(ref_dst1dDesc);
|
||||
};
|
||||
|
||||
using refType_src2dDesc =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc;
|
||||
using refType_dst1dDesc =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
|
||||
using refType_src2dDesc_padded_12 =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
|
||||
refType_src2dDesc_padded_12;
|
||||
using refType_dst1dDesc_padded =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
|
||||
refType_dst1dDesc_padded;
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
|
||||
};
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
|
||||
int BlkGroupSize,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
(void)ws_buf2_bytes_offset;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
true,
|
||||
true,
|
||||
GredThreadBufferLength>;
|
||||
|
||||
constexpr int RunId = need_indices ? 2 : 1;
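    // RunId selects the kernel variant: 1 reduces values only, while 2 also emits indices.
    // (The block-wise second-pass kernels later in this patch use 3, which consumes indices
    // already stored in the workspace by the first pass.)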
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(p_src_global),
|
||||
beta,
|
||||
static_cast<dstDataType* const __restrict__>(p_dst_global),
|
||||
static_cast<const int* const __restrict__>(nullptr),
|
||||
static_cast<int* const __restrict__>(indices_global));
|
||||
};
|
||||
@@ -1,285 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include "config.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "gridwise_generic_2d_reduction_direct_warpwise.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using srcDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
|
||||
using dstDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
|
||||
using compType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
? NanPropagation_t::NOT_PROPAGATE_NAN
|
||||
: NanPropagation_t::PROPAGATE_NAN;
|
||||
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
? ReduceTensorIndices_t::NO_INDICES
|
||||
: ReduceTensorIndices_t::FLATTENED_INDICES;
|
||||
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable
|
||||
|
||||
// helper functions using variadic template arguments
|
||||
template <index_t... Ns>
|
||||
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(static_cast<index_t>(lengths[Ns])...);
|
||||
};
|
||||
|
||||
template <index_t arraySize>
|
||||
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
|
||||
{
|
||||
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
|
||||
|
||||
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
|
||||
|
||||
return make_tuple_from_array_and_index_seq(lengths, index_seq);
|
||||
};
|
||||
|
||||
template <index_t... Ns>
|
||||
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(Ns...);
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
|
||||
int BlkGroupSize,
|
||||
int inLength0,
|
||||
int inLength1,
|
||||
int inLength2,
|
||||
int inLength3,
|
||||
int inLength4,
|
||||
int inLength5,
|
||||
int inStride0,
|
||||
int inStride1,
|
||||
int inStride2,
|
||||
int inStride3,
|
||||
int inStride4,
|
||||
int inStride5,
|
||||
void* __restrict__ ws_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
|
||||
|
||||
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
|
||||
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
|
||||
|
||||
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
|
||||
const auto tupleDstLengths = make_tuple(1);
|
||||
const auto tupleDstStrides = make_tuple(1);
|
||||
|
||||
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
const auto one_dim_srcDesc = transform_tensor_descriptor(
|
||||
srcDesc,
|
||||
make_tuple(make_merge_transform(tupleSrcLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
auto src2dDesc = transform_tensor_descriptor(
|
||||
one_dim_srcDesc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(1, one_dim_srcDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
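
    // This variant reduces all dimensions: the source is first flattened to 1-D and then
    // viewed as a 1 x N matrix so the generic (invariant x to-reduce) 2-D reduction applies.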
|
||||
|
||||
constexpr int invariantLen = 1;
|
||||
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
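        // The warp-wise method assigns one output element per warp, so dim-0 is padded up to
        // the number of warps in the grid (GridSize * BlockSize / warpSize) and dim-1 up to a
        // multiple of the per-warp slice length (warpSize * GredAccessesPerThreadInWarp).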
|
||||
const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen;
|
||||
const auto srcPad2 =
|
||||
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
|
||||
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
|
||||
make_pad_transform(toReduceLen, 0, srcPad2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if constexpr(dst1d_need_padding)
|
||||
{
|
||||
const auto dstPad = GridSize * BlockSize / warpSize - invariantLen;
|
||||
auto dst1dDesc_2 =
|
||||
transform_tensor_descriptor(dstDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t srcDims>
|
||||
struct get_ref_desc_types
|
||||
{
|
||||
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
|
||||
|
||||
    // strides do not need to be accurate here; any strides produce the expected reference type
|
||||
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
|
||||
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
|
||||
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));
|
||||
|
||||
static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
|
||||
ref_srcDesc,
|
||||
make_tuple(make_merge_transform(make_tuple_from_seq(ref_srcLengths))),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
static constexpr auto ref_src2dDesc =
|
||||
transform_tensor_descriptor(ref_one_dim_srcDesc,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(1, ref_one_dim_srcDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
|
||||
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
|
||||
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
// used by the DirectThreadWise and DirectWarpWise method
|
||||
using refType_src2dDesc_padded_12 =
|
||||
decltype(transform_tensor_descriptor(ref_src2dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
|
||||
make_pad_transform(ref_toReduceLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{})));
|
||||
|
||||
using refType_dst1dDesc_padded =
|
||||
decltype(transform_tensor_descriptor(ref_dstDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{})));
|
||||
|
||||
using refType_src2dDesc = decltype(ref_src2dDesc);
|
||||
using refType_dst1dDesc = decltype(ref_dstDesc);
|
||||
};
|
||||
|
||||
using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
|
||||
using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
|
||||
using refType_src2dDesc_padded_12 =
    typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_12;
|
||||
using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
|
||||
};
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
|
||||
int BlkGroupSize,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
(void)ws_buf2_bytes_offset;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce =
|
||||
GridwiseReduction_xy_to_x_direct_warpwise<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
true,
|
||||
true,
|
||||
GredAccessesPerThreadInWarp>;
|
||||
|
||||
constexpr int RunId = need_indices ? 2 : 1;
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(p_src_global),
|
||||
beta,
|
||||
static_cast<dstDataType* const __restrict__>(p_dst_global),
|
||||
static_cast<const int* const __restrict__>(nullptr),
|
||||
static_cast<int* const __restrict__>(indices_global));
|
||||
};
|
||||
@@ -1,320 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include "config.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "gridwise_generic_2d_reduction_direct_warpwise.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using srcDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
|
||||
using dstDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
|
||||
using compType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
|
||||
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
|
||||
|
||||
constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS;
|
||||
constexpr index_t num_invariantDims = srcDims - num_toReduceDims;
|
||||
|
||||
using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
|
||||
using toReduceDims = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
? NanPropagation_t::NOT_PROPAGATE_NAN
|
||||
: NanPropagation_t::PROPAGATE_NAN;
|
||||
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
? ReduceTensorIndices_t::NO_INDICES
|
||||
: ReduceTensorIndices_t::FLATTENED_INDICES;
|
||||
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
static_assert(num_invariantDims > 0, "This kernel requires at least one invariant (non-reduced) dimension!");
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable
|
||||
|
||||
// helper functions using variadic template arguments
|
||||
template <index_t... Ns>
|
||||
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(static_cast<index_t>(lengths[Ns])...);
|
||||
};
|
||||
|
||||
template <index_t arraySize>
|
||||
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
|
||||
{
|
||||
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
|
||||
|
||||
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
|
||||
|
||||
return make_tuple_from_array_and_index_seq(lengths, index_seq);
|
||||
};
|
||||
|
||||
template <index_t... Ns>
|
||||
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(Ns...);
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
|
||||
int BlkGroupSize,
|
||||
int inLength0,
|
||||
int inLength1,
|
||||
int inLength2,
|
||||
int inLength3,
|
||||
int inLength4,
|
||||
int inLength5,
|
||||
int inStride0,
|
||||
int inStride1,
|
||||
int inStride2,
|
||||
int inStride3,
|
||||
int inStride4,
|
||||
int inStride5,
|
||||
int outStride0,
|
||||
int outStride1,
|
||||
int outStride2,
|
||||
int outStride3,
|
||||
int outStride4,
|
||||
int outStride5,
|
||||
void* __restrict__ ws_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
|
||||
|
||||
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
|
||||
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
|
||||
const int dstStrides[6] = {
|
||||
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
|
||||
|
||||
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
|
||||
const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
|
||||
|
||||
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
const auto toReduceDimLengths = make_tuple_from_array_and_index_seq(srcLengths, toReduceDims{});
|
||||
const auto invariantDimLengths =
|
||||
make_tuple_from_array_and_index_seq(srcLengths, invariantDims{});
|
||||
|
||||
auto src2dDesc =
|
||||
transform_tensor_descriptor(srcDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(toReduceDimLengths)),
|
||||
make_tuple(invariantDims{}, toReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
auto dst1dDesc = transform_tensor_descriptor(
|
||||
dstDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto invariantLen = src2dDesc.GetLength(Number<0>{});
|
||||
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
|
||||
const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen;
|
||||
const auto srcPad2 =
|
||||
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
|
||||
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
|
||||
make_pad_transform(toReduceLen, 0, srcPad2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if constexpr(dst1d_need_padding)
|
||||
{
|
||||
const auto dstPad = GridSize * BlockSize / warpSize - invariantLen;
|
||||
auto dst1dDesc_2 =
|
||||
transform_tensor_descriptor(dst1dDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
|
||||
struct get_ref_desc_types
|
||||
{
|
||||
static constexpr auto ref_toReduceDimLengths =
|
||||
typename uniform_sequence_gen<toReduceDims::Size(), 8>::type{};
|
||||
static constexpr auto ref_invariantDimLengths =
|
||||
typename uniform_sequence_gen<invariantDims::Size(), 8>::type{};
|
||||
|
||||
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
|
||||
static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 8>::type{};
|
||||
|
||||
    // strides do not need to be accurate here; any strides produce the expected reference type
|
||||
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
|
||||
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
|
||||
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(
|
||||
make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
|
||||
|
||||
static constexpr auto ref_src2dDesc = transform_tensor_descriptor(
|
||||
ref_srcDesc,
|
||||
make_tuple(make_merge_transform(make_tuple_from_seq(ref_invariantDimLengths)),
|
||||
make_merge_transform(make_tuple_from_seq(ref_toReduceDimLengths))),
|
||||
make_tuple(invariantDims{}, toReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
|
||||
ref_dstDesc,
|
||||
make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
|
||||
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
// used by the DirectThreadWise and DirectWarpWise method
|
||||
using refType_src2dDesc_padded_12 =
|
||||
decltype(transform_tensor_descriptor(ref_src2dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
|
||||
make_pad_transform(ref_toReduceLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{})));
|
||||
|
||||
using refType_dst1dDesc_padded =
|
||||
decltype(transform_tensor_descriptor(ref_dst1dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{})));
|
||||
|
||||
using refType_src2dDesc = decltype(ref_src2dDesc);
|
||||
using refType_dst1dDesc = decltype(ref_dst1dDesc);
|
||||
};
|
||||
|
||||
using refType_src2dDesc =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc;
|
||||
using refType_dst1dDesc =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
|
||||
using refType_src2dDesc_padded_12 =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
|
||||
refType_src2dDesc_padded_12;
|
||||
using refType_dst1dDesc_padded =
|
||||
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
|
||||
refType_dst1dDesc_padded;
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
|
||||
};
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
|
||||
int BlkGroupSize,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
(void)ws_buf2_bytes_offset;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce =
|
||||
GridwiseReduction_xy_to_x_direct_warpwise<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
true,
|
||||
true,
|
||||
GredAccessesPerThreadInWarp>;
|
||||
|
||||
constexpr int RunId = need_indices ? 2 : 1;
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(p_src_global),
|
||||
beta,
|
||||
static_cast<dstDataType* const __restrict__>(p_dst_global),
|
||||
static_cast<const int* const __restrict__>(nullptr),
|
||||
static_cast<int* const __restrict__>(indices_global));
|
||||
};
|
||||
@@ -1,205 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include "config.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "gridwise_generic_2d_reduction_blockwise.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using srcDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
|
||||
using dstDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
|
||||
using compType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
? NanPropagation_t::NOT_PROPAGATE_NAN
|
||||
: NanPropagation_t::PROPAGATE_NAN;
|
||||
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
? ReduceTensorIndices_t::NO_INDICES
|
||||
: ReduceTensorIndices_t::FLATTENED_INDICES;
|
||||
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable
|
||||
|
||||
extern "C" __global__ void
|
||||
gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global)
|
||||
{
|
||||
(void)GridSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
|
||||
|
||||
const auto tupleDstLengths = make_tuple(1);
|
||||
const auto tupleDstStrides = make_tuple(1);
|
||||
|
||||
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
const index_t invariantLen = dstDesc.GetLength(Number<0>{});
|
||||
const index_t toReduceLen = BlkGroupSize;
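
    // Second pass of a multi-block reduction: the first pass left BlkGroupSize partial
    // results per output element in the workspace, so that is the length left to reduce.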
|
||||
|
||||
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
|
||||
|
||||
constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
|
||||
const auto srcPad =
|
||||
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
|
||||
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pass_through_transform(invariantLen),
|
||||
make_pad_transform(toReduceLen, 0, srcPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
|
||||
};
|
||||
|
||||
struct get_ref_desc_types
|
||||
{
|
||||
static constexpr auto ref_tupleDstLengths = make_tuple(8);
|
||||
static constexpr auto ref_dstDesc =
|
||||
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
|
||||
|
||||
static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{});
|
||||
static constexpr index_t ref_toReduceLen = 8;
|
||||
|
||||
static constexpr auto ref_src2dDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
|
||||
|
||||
using refType_src2dDesc = decltype(ref_src2dDesc);
|
||||
using refType_dst1dDesc = decltype(ref_dstDesc);
|
||||
|
||||
// used by the BlockWise and MultiBlock method
|
||||
using refType_src2dDesc_padded_34 = decltype(
|
||||
transform_tensor_descriptor(ref_src2dDesc,
|
||||
make_tuple(make_pass_through_transform(ref_invariantLen),
|
||||
make_pad_transform(ref_toReduceLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{})));
|
||||
|
||||
using refType_dst1dDesc_padded =
|
||||
decltype(transform_tensor_descriptor(ref_dstDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{})));
|
||||
};
|
||||
|
||||
using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc;
|
||||
using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc;
|
||||
using refType_src2dDesc_padded_34 = typename get_ref_desc_types::refType_src2dDesc_padded_34;
|
||||
using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded;
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_src2dDesc_padded_34*>(p_src2dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
|
||||
};
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)p_src_global;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
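    // Workspace layout: the first 4096 bytes hold the two descriptors, the partial-result
    // buffer (ws_buf1) follows immediately, and the optional indices buffer (ws_buf2), when
    // present, starts ws_buf2_bytes_offset bytes into ws_buf1.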
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
false,
|
||||
true,
|
||||
GredAccessesPerThreadInBlock>;
|
||||
|
||||
void* const ws_buf2_global =
|
||||
ws_buf2_bytes_offset > 0
|
||||
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
|
||||
: nullptr;
|
||||
|
||||
constexpr int RunId = need_indices ? 3 : 1;
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
|
||||
beta,
|
||||
static_cast<dstDataType* const __restrict__>(p_dst_global),
|
||||
static_cast<const int* const __restrict__>(ws_buf2_global),
|
||||
static_cast<int* const __restrict__>(indices_global));
|
||||
};
|
||||
@@ -1,263 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include "config.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "gridwise_generic_2d_reduction_blockwise.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using srcDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
|
||||
using dstDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
|
||||
using compType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
? NanPropagation_t::NOT_PROPAGATE_NAN
|
||||
: NanPropagation_t::PROPAGATE_NAN;
|
||||
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
? ReduceTensorIndices_t::NO_INDICES
|
||||
: ReduceTensorIndices_t::FLATTENED_INDICES;
|
||||
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable
|
||||
|
||||
// helper functions using variadic template arguments
|
||||
template <index_t... Ns>
|
||||
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(static_cast<index_t>(lengths[Ns])...);
|
||||
};
|
||||
|
||||
template <index_t arraySize>
|
||||
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
|
||||
{
|
||||
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
|
||||
|
||||
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
|
||||
|
||||
return make_tuple_from_array_and_index_seq(lengths, index_seq);
|
||||
};
|
||||
|
||||
template <index_t... Ns>
|
||||
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(Ns...);
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
|
||||
int BlkGroupSize,
|
||||
int outLength0,
|
||||
int outLength1,
|
||||
int outLength2,
|
||||
int outLength3,
|
||||
int outLength4,
|
||||
int outLength5,
|
||||
int outStride0,
|
||||
int outStride1,
|
||||
int outStride2,
|
||||
int outStride3,
|
||||
int outStride4,
|
||||
int outStride5,
|
||||
void* __restrict__ ws_global)
|
||||
{
|
||||
(void)GridSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
|
||||
|
||||
const int dstLengths[6] = {
|
||||
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
|
||||
const int dstStrides[6] = {
|
||||
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
|
||||
|
||||
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
|
||||
|
||||
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
auto dst1dDesc = transform_tensor_descriptor(
|
||||
dstDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const index_t invariantLen = dst1dDesc.GetLength(Number<0>{});
|
||||
const index_t toReduceLen = BlkGroupSize;
|
||||
|
||||
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
|
||||
|
||||
constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
|
||||
const auto srcPad =
|
||||
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
|
||||
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pass_through_transform(invariantLen),
|
||||
make_pad_transform(toReduceLen, 0, srcPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
|
||||
};
|
||||
|
||||
template <index_t dstDims>
|
||||
struct get_ref_desc_types
|
||||
{
|
||||
static constexpr auto ref_tupleDstLengths =
|
||||
make_tuple_from_seq(typename uniform_sequence_gen<dstDims, 8>::type{});
|
||||
static constexpr auto ref_dstDesc =
|
||||
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
|
||||
|
||||
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
|
||||
ref_dstDesc,
|
||||
make_tuple(make_merge_transform(ref_tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
static constexpr index_t ref_invariantLen = ref_dst1dDesc.GetLength(Number<0>{});
|
||||
static constexpr index_t ref_toReduceLen = 8;
|
||||
|
||||
static constexpr auto ref_src2dDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
|
||||
|
||||
using refType_src2dDesc = decltype(ref_src2dDesc);
|
||||
using refType_dst1dDesc = decltype(ref_dst1dDesc);
|
||||
|
||||
// used by the BlockWise and MultiBlock method
|
||||
using refType_src2dDesc_padded_34 = decltype(
|
||||
transform_tensor_descriptor(ref_src2dDesc,
|
||||
make_tuple(make_pass_through_transform(ref_invariantLen),
|
||||
make_pad_transform(ref_toReduceLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{})));
|
||||
|
||||
using refType_dst1dDesc_padded =
|
||||
decltype(transform_tensor_descriptor(ref_dst1dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{})));
|
||||
};
|
||||
|
||||
using refType_src2dDesc = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
|
||||
using refType_dst1dDesc = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
|
||||
using refType_src2dDesc_padded_34 =
|
||||
typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_34;
|
||||
using refType_dst1dDesc_padded = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_src2dDesc_padded_34*>(p_src2dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
|
||||
};
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)p_src_global;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
|
||||
|
||||
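    // [Annotation added for clarity] Workspace layout assumed here: the src2d descriptor at
    // byte offset 0, the dst1d descriptor at 2048, the partial-result buffer (ws_buf1) at 4096,
    // and the index buffer (ws_buf2) at ws_buf2_bytes_offset past ws_buf1 when present.
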
    const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
    const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);

    using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise<BlockSize,
                                                                   srcDataType,
                                                                   dstDataType,
                                                                   compType,
                                                                   decltype(src2dDesc),
                                                                   decltype(dst1dDesc),
                                                                   op,
                                                                   nanPropaOpt,
                                                                   reduceIndicesOpt,
                                                                   false,
                                                                   true,
                                                                   GredAccessesPerThreadInBlock>;

    void* const ws_buf2_global =
        ws_buf2_bytes_offset > 0
            ? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
            : nullptr;

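    // [Annotation added for clarity] RunId selects the kernel variant: 1 is the plain reduction,
    // 3 is the indexed variant taken when need_indices is true (indexable op with indices
    // requested).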
    constexpr int RunId = need_indices ? 3 : 1;
    gridwise_2d_reduce::template Run<RunId>(
        src2dDesc,
        dst1dDesc,
        origReduceLen,
        alpha,
        static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
        beta,
        static_cast<dstDataType* const __restrict__>(p_dst_global),
        static_cast<const int* const __restrict__>(ws_buf2_global),
        static_cast<int* const __restrict__>(indices_global));
};

@@ -1,222 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include "config.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "gridwise_generic_2d_reduction_direct_threadwise.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using srcDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
|
||||
using dstDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
|
||||
using compType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
|
||||
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
? NanPropagation_t::NOT_PROPAGATE_NAN
|
||||
: NanPropagation_t::PROPAGATE_NAN;
|
||||
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
? ReduceTensorIndices_t::NO_INDICES
|
||||
: ReduceTensorIndices_t::FLATTENED_INDICES;
|
||||
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable
|
||||
|
||||
extern "C" __global__ void
|
||||
gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
|
||||
|
||||
const auto tupleDstLengths = make_tuple(1);
|
||||
const auto tupleDstStrides = make_tuple(1);
|
||||
|
||||
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
const index_t invariantLen = dstDesc.GetLength(Number<0>{});
|
||||
const index_t toReduceLen = BlkGroupSize;
|
||||
|
||||
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
|
||||
|
||||
constexpr auto copySliceLen = GredThreadBufferLength;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
|
||||
const auto srcPad1 = GridSize * BlockSize - invariantLen;
|
||||
const auto srcPad2 =
|
||||
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
|
||||
make_pad_transform(toReduceLen, 0, srcPad2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if constexpr(dst1d_need_padding)
|
||||
{
|
||||
const auto dstPad = GridSize * BlockSize - invariantLen;
|
||||
auto dst1dDesc_2 =
|
||||
transform_tensor_descriptor(dstDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
|
||||
}
|
||||
};
|
||||
|
||||
struct get_ref_desc_types
|
||||
{
|
||||
static constexpr auto ref_tupleDstLengths = make_tuple(8);
|
||||
static constexpr auto ref_dstDesc =
|
||||
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
|
||||
|
||||
static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{});
|
||||
static constexpr index_t ref_toReduceLen = 8;
|
||||
|
||||
static constexpr auto ref_src2dDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
|
||||
|
||||
using refType_src2dDesc = decltype(ref_src2dDesc);
|
||||
using refType_dst1dDesc = decltype(ref_dstDesc);
|
||||
|
||||
// used by the DirectThreadWise and DirectWarpWise method
|
||||
using refType_src2dDesc_padded_12 =
|
||||
decltype(transform_tensor_descriptor(ref_src2dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
|
||||
make_pad_transform(ref_toReduceLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{})));
|
||||
|
||||
using refType_dst1dDesc_padded =
|
||||
decltype(transform_tensor_descriptor(ref_dstDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{})));
|
||||
};
|
||||
|
||||
using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc;
|
||||
using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc;
|
||||
using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12;
|
||||
using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded;
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
|
||||
};
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)p_src_global;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
false,
|
||||
true,
|
||||
GredThreadBufferLength>;
|
||||
|
||||
void* const ws_buf2_global =
|
||||
ws_buf2_bytes_offset > 0
|
||||
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
|
||||
: nullptr;
|
||||
|
||||
constexpr int RunId = need_indices ? 3 : 1;
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
|
||||
beta,
|
||||
static_cast<dstDataType* const __restrict__>(p_dst_global),
|
||||
static_cast<const int* const __restrict__>(ws_buf2_global),
|
||||
static_cast<int* const __restrict__>(indices_global));
|
||||
};
|
||||
@@ -1,277 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include "config.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "gridwise_generic_2d_reduction_direct_threadwise.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using srcDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
|
||||
using dstDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
|
||||
using compType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
? NanPropagation_t::NOT_PROPAGATE_NAN
|
||||
: NanPropagation_t::PROPAGATE_NAN;
|
||||
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
? ReduceTensorIndices_t::NO_INDICES
|
||||
: ReduceTensorIndices_t::FLATTENED_INDICES;
|
||||
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable
|
||||
|
||||
// helper functions using variadic template arguments
|
||||
template <index_t... Ns>
|
||||
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(static_cast<index_t>(lengths[Ns])...);
|
||||
};
|
||||
|
||||
template <index_t arraySize>
|
||||
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
|
||||
{
|
||||
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
|
||||
|
||||
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
|
||||
|
||||
return make_tuple_from_array_and_index_seq(lengths, index_seq);
|
||||
};
|
||||
|
||||
template <index_t... Ns>
|
||||
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(Ns...);
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
|
||||
int BlkGroupSize,
|
||||
int outLength0,
|
||||
int outLength1,
|
||||
int outLength2,
|
||||
int outLength3,
|
||||
int outLength4,
|
||||
int outLength5,
|
||||
int outStride0,
|
||||
int outStride1,
|
||||
int outStride2,
|
||||
int outStride3,
|
||||
int outStride4,
|
||||
int outStride5,
|
||||
void* __restrict__ ws_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
|
||||
|
||||
const int dstLengths[6] = {
|
||||
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
|
||||
const int dstStrides[6] = {
|
||||
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
|
||||
|
||||
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
|
||||
|
||||
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
auto dst1dDesc = transform_tensor_descriptor(
|
||||
dstDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const index_t invariantLen = dst1dDesc.GetLength(Number<0>{});
|
||||
const index_t toReduceLen = BlkGroupSize;
|
||||
|
||||
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
|
||||
|
||||
constexpr auto copySliceLen = GredThreadBufferLength;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
|
||||
const auto srcPad1 = GridSize * BlockSize - invariantLen;
|
||||
const auto srcPad2 =
|
||||
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
|
||||
make_pad_transform(toReduceLen, 0, srcPad2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if constexpr(dst1d_need_padding)
|
||||
{
|
||||
const auto dstPad = GridSize * BlockSize - invariantLen;
|
||||
auto dst1dDesc_2 =
|
||||
transform_tensor_descriptor(dst1dDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t dstDims>
|
||||
struct get_ref_desc_types
|
||||
{
|
||||
static constexpr auto ref_tupleDstLengths =
|
||||
make_tuple_from_seq(typename uniform_sequence_gen<dstDims, 8>::type{});
|
||||
static constexpr auto ref_dstDesc =
|
||||
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
|
||||
|
||||
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
|
||||
ref_dstDesc,
|
||||
make_tuple(make_merge_transform(ref_tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
static constexpr index_t ref_invariantLen = ref_dst1dDesc.GetLength(Number<0>{});
|
||||
static constexpr index_t ref_toReduceLen = 8;
|
||||
|
||||
static constexpr auto ref_src2dDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
|
||||
|
||||
using refType_src2dDesc = decltype(ref_src2dDesc);
|
||||
using refType_dst1dDesc = decltype(ref_dst1dDesc);
|
||||
|
||||
// used by the DirectThreadWise and DirectWarpWise method
|
||||
using refType_src2dDesc_padded_12 =
|
||||
decltype(transform_tensor_descriptor(ref_src2dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
|
||||
make_pad_transform(ref_toReduceLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{})));
|
||||
|
||||
using refType_dst1dDesc_padded =
|
||||
decltype(transform_tensor_descriptor(ref_dst1dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{})));
|
||||
};
|
||||
|
||||
using refType_src2dDesc = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
|
||||
using refType_dst1dDesc = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
|
||||
using refType_src2dDesc_padded_12 =
|
||||
typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_12;
|
||||
using refType_dst1dDesc_padded = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
|
||||
};
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)p_src_global;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
false,
|
||||
true,
|
||||
GredThreadBufferLength>;
|
||||
|
||||
void* const ws_buf2_global =
|
||||
ws_buf2_bytes_offset > 0
|
||||
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
|
||||
: nullptr;
|
||||
|
||||
constexpr int RunId = need_indices ? 3 : 1;
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
|
||||
beta,
|
||||
static_cast<dstDataType* const __restrict__>(p_dst_global),
|
||||
static_cast<const int* const __restrict__>(ws_buf2_global),
|
||||
static_cast<int* const __restrict__>(indices_global));
|
||||
};
|
||||
@@ -1,221 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include "config.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "gridwise_generic_2d_reduction_direct_warpwise.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using srcDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
|
||||
using dstDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
|
||||
using compType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
? NanPropagation_t::NOT_PROPAGATE_NAN
|
||||
: NanPropagation_t::PROPAGATE_NAN;
|
||||
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
? ReduceTensorIndices_t::NO_INDICES
|
||||
: ReduceTensorIndices_t::FLATTENED_INDICES;
|
||||
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable
|
||||
|
||||
extern "C" __global__ void
|
||||
gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
|
||||
|
||||
const auto tupleDstLengths = make_tuple(1);
|
||||
const auto tupleDstStrides = make_tuple(1);
|
||||
|
||||
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
const index_t invariantLen = dstDesc.GetLength(Number<0>{});
|
||||
const index_t toReduceLen = BlkGroupSize;
|
||||
|
||||
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
|
||||
|
||||
constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
|
||||
const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen;
|
||||
const auto srcPad2 =
|
||||
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
|
||||
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
|
||||
make_pad_transform(toReduceLen, 0, srcPad2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if constexpr(dst1d_need_padding)
|
||||
{
|
||||
const auto dstPad = GridSize * BlockSize / warpSize - invariantLen;
|
||||
auto dst1dDesc_2 =
|
||||
transform_tensor_descriptor(dstDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
|
||||
}
|
||||
};
|
||||
|
||||
struct get_ref_desc_types
|
||||
{
|
||||
static constexpr auto ref_tupleDstLengths = make_tuple(8);
|
||||
static constexpr auto ref_dstDesc =
|
||||
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
|
||||
|
||||
static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{});
|
||||
static constexpr index_t ref_toReduceLen = 8;
|
||||
|
||||
static constexpr auto ref_src2dDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
|
||||
|
||||
using refType_src2dDesc = decltype(ref_src2dDesc);
|
||||
using refType_dst1dDesc = decltype(ref_dstDesc);
|
||||
|
||||
// used by the DirectThreadWise and DirectWarpWise method
|
||||
using refType_src2dDesc_padded_12 =
|
||||
decltype(transform_tensor_descriptor(ref_src2dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
|
||||
make_pad_transform(ref_toReduceLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{})));
|
||||
|
||||
using refType_dst1dDesc_padded =
|
||||
decltype(transform_tensor_descriptor(ref_dstDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{})));
|
||||
};
|
||||
|
||||
using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc;
|
||||
using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc;
|
||||
using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12;
|
||||
using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded;
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
|
||||
};
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)p_src_global;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce =
|
||||
GridwiseReduction_xy_to_x_direct_warpwise<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
false,
|
||||
true,
|
||||
GredAccessesPerThreadInWarp>;
|
||||
|
||||
void* const ws_buf2_global =
|
||||
ws_buf2_bytes_offset > 0
|
||||
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
|
||||
: nullptr;
|
||||
|
||||
constexpr int RunId = need_indices ? 3 : 1;
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
|
||||
beta,
|
||||
static_cast<dstDataType* const __restrict__>(p_dst_global),
|
||||
static_cast<const int* const __restrict__>(ws_buf2_global),
|
||||
static_cast<int* const __restrict__>(indices_global));
|
||||
};
|
||||
@@ -1,279 +0,0 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2021 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include "config.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "gridwise_generic_2d_reduction_direct_warpwise.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using srcDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
|
||||
using dstDataType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
|
||||
using compType =
|
||||
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
? NanPropagation_t::NOT_PROPAGATE_NAN
|
||||
: NanPropagation_t::PROPAGATE_NAN;
|
||||
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
? ReduceTensorIndices_t::NO_INDICES
|
||||
: ReduceTensorIndices_t::FLATTENED_INDICES;
|
||||
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable
|
||||
|
||||
// helper functions using variadic template arguments
|
||||
template <index_t... Ns>
|
||||
__device__ static auto make_tuple_from_array_and_index_seq(const int* lengths, Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(static_cast<index_t>(lengths[Ns])...);
|
||||
};
|
||||
|
||||
template <index_t arraySize>
|
||||
__device__ static auto make_tuple_from_array(const int* lengths, Number<arraySize>)
|
||||
{
|
||||
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
|
||||
|
||||
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
|
||||
|
||||
return make_tuple_from_array_and_index_seq(lengths, index_seq);
|
||||
};
|
||||
|
||||
template <index_t... Ns>
|
||||
__device__ static constexpr auto make_tuple_from_seq(Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(Ns...);
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
|
||||
int BlkGroupSize,
|
||||
int outLength0,
|
||||
int outLength1,
|
||||
int outLength2,
|
||||
int outLength3,
|
||||
int outLength4,
|
||||
int outLength5,
|
||||
int outStride0,
|
||||
int outStride1,
|
||||
int outStride2,
|
||||
int outStride3,
|
||||
int outStride4,
|
||||
int outStride5,
|
||||
void* __restrict__ ws_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
|
||||
|
||||
const int dstLengths[6] = {
|
||||
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
|
||||
const int dstStrides[6] = {
|
||||
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
|
||||
|
||||
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
|
||||
|
||||
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
auto dst1dDesc = transform_tensor_descriptor(
|
||||
dstDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const index_t invariantLen = dst1dDesc.GetLength(Number<0>{});
|
||||
const index_t toReduceLen = BlkGroupSize;
|
||||
|
||||
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
|
||||
|
||||
constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
|
||||
const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen;
|
||||
const auto srcPad2 =
|
||||
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
|
||||
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
|
||||
make_pad_transform(toReduceLen, 0, srcPad2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if constexpr(dst1d_need_padding)
|
||||
{
|
||||
const auto dstPad = GridSize * BlockSize / warpSize - invariantLen;
|
||||
auto dst1dDesc_2 =
|
||||
transform_tensor_descriptor(dst1dDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t dstDims>
|
||||
struct get_ref_desc_types
|
||||
{
|
||||
static constexpr auto ref_tupleDstLengths =
|
||||
make_tuple_from_seq(typename uniform_sequence_gen<dstDims, 8>::type{});
|
||||
static constexpr auto ref_dstDesc =
|
||||
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
|
||||
|
||||
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
|
||||
ref_dstDesc,
|
||||
make_tuple(make_merge_transform(ref_tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
static constexpr index_t ref_invariantLen = ref_dst1dDesc.GetLength(Number<0>{});
|
||||
static constexpr index_t ref_toReduceLen = 8;
|
||||
|
||||
static constexpr auto ref_src2dDesc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
|
||||
|
||||
using refType_src2dDesc = decltype(ref_src2dDesc);
|
||||
using refType_dst1dDesc = decltype(ref_dst1dDesc);
|
||||
|
||||
// used by the DirectThreadWise and DirectWarpWise method
|
||||
using refType_src2dDesc_padded_12 =
|
||||
decltype(transform_tensor_descriptor(ref_src2dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
|
||||
make_pad_transform(ref_toReduceLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{})));
|
||||
|
||||
using refType_dst1dDesc_padded =
|
||||
decltype(transform_tensor_descriptor(ref_dst1dDesc,
|
||||
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{})));
|
||||
};
|
||||
|
||||
using refType_src2dDesc = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
|
||||
using refType_dst1dDesc = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
|
||||
using refType_src2dDesc_padded_12 =
|
||||
typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_12;
|
||||
using refType_dst1dDesc_padded = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
|
||||
};
|
||||
|
||||
template <bool need_padding>
|
||||
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
|
||||
{
|
||||
if constexpr(need_padding)
|
||||
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
|
||||
else
|
||||
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
|
||||
};
|
||||
|
||||
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)p_src_global;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce =
|
||||
GridwiseReduction_xy_to_x_direct_warpwise<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
false,
|
||||
true,
|
||||
GredAccessesPerThreadInWarp>;
|
||||
|
||||
void* const ws_buf2_global =
|
||||
ws_buf2_bytes_offset > 0
|
||||
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
|
||||
: nullptr;
|
||||
|
||||
constexpr int RunId = need_indices ? 3 : 1;
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
|
||||
beta,
|
||||
static_cast<dstDataType* const __restrict__>(p_dst_global),
|
||||
static_cast<const int* const __restrict__>(ws_buf2_global),
|
||||
static_cast<int* const __restrict__>(indices_global));
|
||||
};
|
||||
@@ -111,7 +111,35 @@ set(DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
|
||||
)
|
||||
|
||||
# device_reduce_instance
set(DEVICE_REDUCE_INSTANCE_SOURCE
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f16_f16_f16.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f16_f32_f16.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f32_f32_f32.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f32_f64_f32.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f64_f64_f64.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f16_f16_f16.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f16_f32_f16.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f32_f32_f32.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f32_f64_f32.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f64_f64_f64.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f16_f16_f16.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f32_f32_f16.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f32_f32_f32.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f64_f64_f32.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f64_f64_f64.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.cpp;
)

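# [Annotation added for clarity] As the list above shows, the reduce instance library is built
# from one source file per reduction method (blockwise, threadwise, blockwise_second_call,
# multiblock_atomic_add, multiblock_partial_reduce) and per datatype combination.
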
add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE})
|
||||
add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE})
|
||||
add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE})
|
||||
add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE})
|
||||
add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE})
|
||||
@@ -120,8 +148,8 @@ add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURC
|
||||
add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE})
|
||||
add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_bwd_data_instance SHARED ${DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE})
|
||||
add_library(device_reduce_instance SHARED ${DEVICE_REDUCE_INSTANCE_SOURCE})
|
||||
|
||||
target_include_directories(device_gemm_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_gemm_bias_2d_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
@@ -134,6 +162,7 @@ target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $<
|
||||
target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_conv2d_fwd_bias_relu_atomic_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_conv2d_bwd_data_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_reduce_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
|
||||
target_compile_features(device_gemm_instance PUBLIC)
|
||||
target_compile_features(device_gemm_bias_2d_instance PUBLIC)
|
||||
@@ -146,6 +175,7 @@ target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC)
|
||||
target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC)
|
||||
target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC)
|
||||
target_compile_features(device_conv2d_bwd_data_instance PUBLIC)
|
||||
target_compile_features(device_reduce_instance PUBLIC)
|
||||
|
||||
set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_gemm_bias_2d_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
@@ -158,6 +188,7 @@ set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_I
|
||||
set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_conv2d_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
install(TARGETS device_gemm_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_gemm_bias_2d_instance LIBRARY DESTINATION lib)
|
||||
@@ -170,3 +201,4 @@ install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_bwd_data_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_reduce_instance LIBRARY DESTINATION lib)
|
||||
|
||||
@@ -549,8 +549,11 @@ struct
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
input_spatial_lengths_{input_spatial_lengths},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
output_spatial_lengths_{output_spatial_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
@@ -625,8 +628,11 @@ struct
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::vector<index_t> input_spatial_lengths_;
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> output_spatial_lengths_;
|
||||
std::vector<index_t> conv_filter_strides_;
|
||||
std::vector<index_t> conv_filter_dilations_;
|
||||
std::vector<index_t> input_left_pads_;
|
||||
std::vector<index_t> input_right_pads_;
|
||||
};
|
||||
@@ -638,6 +644,28 @@ struct
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
<< "K " << arg.Conv_K_ << ", "
|
||||
<< "C " << arg.Conv_C_ << ", " << std::endl;
|
||||
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
|
||||
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
|
||||
<< arg.input_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
|
||||
<< arg.output_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
|
||||
<< arg.conv_filter_strides_[1] << ", " << std::endl;
|
||||
std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
|
||||
<< arg.conv_filter_dilations_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
|
||||
<< arg.input_left_pads_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
|
||||
<< arg.input_right_pads_[1] << ", " << std::endl;
|
||||
}
|
||||
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
@@ -656,6 +684,7 @@ struct
|
||||
std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
|
||||
<< ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
|
||||
@@ -526,8 +526,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
input_spatial_lengths_{input_spatial_lengths},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
output_spatial_lengths_{output_spatial_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
@@ -590,8 +593,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::vector<index_t> input_spatial_lengths_;
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> output_spatial_lengths_;
|
||||
std::vector<index_t> conv_filter_strides_;
|
||||
std::vector<index_t> conv_filter_dilations_;
|
||||
std::vector<index_t> input_left_pads_;
|
||||
std::vector<index_t> input_right_pads_;
|
||||
};
|
||||
@@ -603,6 +609,28 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
<< "K " << arg.Conv_K_ << ", "
|
||||
<< "C " << arg.Conv_C_ << ", " << std::endl;
|
||||
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
|
||||
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
|
||||
<< arg.input_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
|
||||
<< arg.output_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
|
||||
<< arg.conv_filter_strides_[1] << ", " << std::endl;
|
||||
std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
|
||||
<< arg.conv_filter_dilations_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
|
||||
<< arg.input_left_pads_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
|
||||
<< arg.input_right_pads_[1] << ", " << std::endl;
|
||||
}
|
||||
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
@@ -618,6 +646,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
|
||||
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
|
||||
@@ -498,8 +498,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
input_spatial_lengths_{input_spatial_lengths},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
output_spatial_lengths_{output_spatial_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
@@ -551,8 +554,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::vector<index_t> input_spatial_lengths_;
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> output_spatial_lengths_;
|
||||
std::vector<index_t> conv_filter_strides_;
|
||||
std::vector<index_t> conv_filter_dilations_;
|
||||
std::vector<index_t> input_left_pads_;
|
||||
std::vector<index_t> input_right_pads_;
|
||||
};
|
||||
@@ -564,6 +570,28 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
<< "K " << arg.Conv_K_ << ", "
|
||||
<< "C " << arg.Conv_C_ << ", " << std::endl;
|
||||
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
|
||||
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
|
||||
<< arg.input_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
|
||||
<< arg.output_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
|
||||
<< arg.conv_filter_strides_[1] << ", " << std::endl;
|
||||
std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
|
||||
<< arg.conv_filter_dilations_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
|
||||
<< arg.input_left_pads_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
|
||||
<< arg.input_right_pads_[1] << ", " << std::endl;
|
||||
}
|
||||
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
@@ -598,6 +626,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
|
||||
.GetLength(I5)
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
|
||||
@@ -452,6 +452,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
@@ -464,6 +465,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
|
||||
device_operation/include/device_pool2d_fwd.hpp (new file, 38 lines)
@@ -0,0 +1,38 @@
|
||||
#ifndef DEVICE_POOL2D_FWD_HPP
|
||||
#define DEVICE_POOL2D_FWD_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <array>
|
||||
#include "device_base.hpp"
|
||||
#include "reduction_enums.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <ck::ReduceTensorOp_t ReduceOpId>
|
||||
struct DevicePool2dFwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* in_dev,
|
||||
void* out_dev,
|
||||
void* out_indices_dev,
|
||||
ck::index_t N,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, 2> input_spatial_lengths,
|
||||
std::array<ck::index_t, 2> window_spatial_lengths,
|
||||
std::array<ck::index_t, 2> output_spatial_lengths,
|
||||
std::array<ck::index_t, 2> window_strides,
|
||||
std::array<ck::index_t, 2> input_left_pads,
|
||||
std::array<ck::index_t, 2> input_right_pads) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <ck::ReduceTensorOp_t ReduceOpId>
|
||||
using DevicePool2dFwdPtr = std::unique_ptr<DevicePool2dFwd<ReduceOpId>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
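Illustrative usage sketch for the DevicePool2dFwd interface above (not part of the change itself). It assumes the concrete NHWC instance from the next file, half/float data types, a 2x2 stride-2 max-pooling window with no padding, device buffers in_dev/out_dev already allocated, and the integers N, C, Hi, Wi, Ho, Wo set up by the caller; the tuning values are hypothetical, only the template-parameter order is taken from the header below.

// Hypothetical tuning parameters for a max-pool forward device op.
using MaxPool2d = ck::tensor_operation::device::
    DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<ck::half_t,                // InDataType
                                                      ck::half_t,                // OutDataType
                                                      float,                     // AccDataType
                                                      ck::ReduceTensorOp_t::MAX, // ReduceOpId
                                                      false,                     // NeedIndices
                                                      256,                       // BlockSize
                                                      256,                       // ReduceMThreadClusterSize
                                                      1,                         // ReduceKThreadClusterSize
                                                      1,                         // ReduceMThreadSliceSize
                                                      1,                         // ReduceKThreadSliceSize
                                                      1>;                        // InSrcOutDstVectorSize

MaxPool2d pool;

auto arg = pool.MakeArgumentPointer(in_dev, out_dev, /*out_indices_dev=*/nullptr,
                                    N, C,
                                    {Hi, Wi}, /*window=*/{2, 2}, {Ho, Wo},
                                    /*window_strides=*/{2, 2},
                                    /*input_left_pads=*/{0, 0}, /*input_right_pads=*/{0, 0});

if(pool.IsSupportedArgument(arg.get()))
{
    float avg_ms = pool.MakeInvokerPointer()->Run(arg.get(), /*nrepeat=*/1);
}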
device_operation/include/device_pool2d_fwd_nhwc_nhwc.hpp (new file, 327 lines)
@@ -0,0 +1,327 @@
|
||||
#ifndef DEVICE_POOL2D_FWD_NHWC_NHWC_HPP
|
||||
#define DEVICE_POOL2D_FWD_NHWC_NHWC_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device_pool2d_fwd.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "gridwise_2d_reduction_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
ck::ReduceTensorOp_t ReduceOpId,
|
||||
bool NeedIndices,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t ReduceMThreadClusterSize,
|
||||
ck::index_t ReduceKThreadClusterSize,
|
||||
ck::index_t ReduceMThreadSliceSize,
|
||||
ck::index_t ReduceKThreadSliceSize,
|
||||
ck::index_t InSrcOutDstVectorSize>
|
||||
struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd<ReduceOpId>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
|
||||
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
|
||||
AccElementwiseOperation;
|
||||
|
||||
static constexpr bool BetaIsZero = true;
|
||||
|
||||
static constexpr index_t InSrcOutDstVectorDim =
|
||||
0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
|
||||
// not reduced.
|
||||
|
||||
static constexpr ck::index_t ReduceM_BlockTileSize =
|
||||
ReduceMThreadClusterSize * ReduceMThreadSliceSize;
|
||||
static constexpr ck::index_t ReduceK_BlockTileSize =
|
||||
ReduceKThreadClusterSize * ReduceKThreadSliceSize;
|
||||
|
||||
static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, 2> input_spatial_lengths,
|
||||
std::array<ck::index_t, 2> window_spatial_lengths,
|
||||
std::array<ck::index_t, 2> output_spatial_lengths,
|
||||
std::array<ck::index_t, 2> window_strides,
|
||||
std::array<ck::index_t, 2> input_left_pads,
|
||||
std::array<ck::index_t, 2> input_right_pads)
|
||||
{
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t Y = window_spatial_lengths[0];
|
||||
const index_t X = window_spatial_lengths[1];
|
||||
|
||||
const index_t ConvStrideH = window_strides[0];
|
||||
const index_t ConvStrideW = window_strides[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t ReduceMRaw = N * Ho * Wo * C;
|
||||
const index_t ReduceMPad =
|
||||
math::integer_least_multiple(ReduceMRaw, ReduceM_BlockTileSize) - ReduceMRaw;
|
||||
|
||||
const index_t ReduceKRaw = Y * X;
|
||||
const index_t ReduceKPad =
|
||||
math::integer_least_multiple(ReduceKRaw, ReduceK_BlockTileSize) - ReduceKRaw;
|
||||
|
||||
// A[ReduceM, ReduceK]
|
||||
const auto in_grid_desc_n_hi_wi_c =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor(
|
||||
in_grid_desc_n_hi_wi_c,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor(
|
||||
in_grid_desc_n_hip_wip_c,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_grid_desc_reducemraw_reducekraw =
|
||||
transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)),
|
||||
make_merge_transform(make_tuple(Y, X))),
|
||||
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor(
|
||||
in_grid_desc_reducemraw_reducekraw,
|
||||
make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad),
|
||||
make_right_pad_transform(ReduceKRaw, ReduceKPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// B[ReduceM]
|
||||
const auto out_grid_desc_reducemraw =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo * C));
|
||||
|
||||
const auto out_grid_desc_reducem = transform_tensor_descriptor(
|
||||
out_grid_desc_reducemraw,
|
||||
make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
|
||||
}
|
||||
|
||||
using ABGridDescs = decltype(
|
||||
MakeABGridDescriptor_A_M_K_B_M(1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
|
||||
|
||||
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
|
||||
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
|
||||
|
||||
// TODO
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const InDataType* p_in_dev,
|
||||
OutDataType* p_out_dev,
|
||||
int* p_out_indices_dev,
|
||||
ck::index_t N,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, 2>& input_spatial_lengths,
|
||||
std::array<ck::index_t, 2>& window_spatial_lengths,
|
||||
std::array<ck::index_t, 2>& output_spatial_lengths,
|
||||
std::array<ck::index_t, 2>& window_strides,
|
||||
std::array<ck::index_t, 2>& input_left_pads,
|
||||
std::array<ck::index_t, 2>& input_right_pads)
|
||||
: p_in_dev_{p_in_dev},
|
||||
p_out_dev_{p_out_dev},
|
||||
p_out_indices_dev_{p_out_indices_dev},
|
||||
a_grid_desc_m_k_{},
|
||||
b_grid_desc_m_{}
|
||||
{
|
||||
const auto descs = MakeABGridDescriptor_A_M_K_B_M(N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
window_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
window_strides,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
a_grid_desc_m_k_ = descs[I0];
|
||||
b_grid_desc_m_ = descs[I1];
|
||||
|
||||
invariant_lowest_length_ = C;
|
||||
reduce_lowest_length_ = window_spatial_lengths[1];
|
||||
|
||||
// TODO: is this correct?
|
||||
if constexpr(ReduceOpId == ck::ReduceTensorOp_t::AVG)
|
||||
{
|
||||
ck::index_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
|
||||
in_element_op_ = InElementwiseOperation{divider};
|
||||
acc_element_op_ = AccElementwiseOperation{divider};
|
||||
}
|
||||
}
|
||||
|
||||
const InDataType* p_in_dev_;
|
||||
OutDataType* p_out_dev_;
|
||||
int* p_out_indices_dev_;
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_M b_grid_desc_m_;
|
||||
InElementwiseOperation in_element_op_;
|
||||
AccElementwiseOperation acc_element_op_;
|
||||
|
||||
// for checking vector load/store
|
||||
ck::index_t invariant_lowest_length_;
|
||||
ck::index_t reduce_lowest_length_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_M,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
false, // propagate_nan
|
||||
BetaIsZero,
|
||||
BlockSize,
|
||||
ReduceMThreadClusterSize,
|
||||
ReduceKThreadClusterSize,
|
||||
ReduceMThreadSliceSize,
|
||||
ReduceKThreadSliceSize,
|
||||
InSrcOutDstVectorDim,
|
||||
InSrcOutDstVectorSize,
|
||||
InSrcOutDstVectorSize>;
|
||||
|
||||
const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
|
||||
NeedIndices,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_M,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
|
||||
ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0);
|
||||
|
||||
const index_t grid_size = (ReduceM / ReduceM_BlockTileSize);
|
||||
|
||||
return launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_m_,
|
||||
arg.in_element_op_,
|
||||
arg.acc_element_op_,
|
||||
float(1),
|
||||
arg.p_in_dev_,
|
||||
float(0),
|
||||
arg.p_out_dev_,
|
||||
arg.p_out_indices_dev_);
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0)
|
||||
{
|
||||
return (false);
|
||||
}
|
||||
|
||||
return (true);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_dev,
|
||||
void* p_out_dev,
|
||||
void* p_out_indices_dev,
|
||||
ck::index_t N,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, 2> input_spatial_lengths,
|
||||
std::array<ck::index_t, 2> window_spatial_lengths,
|
||||
std::array<ck::index_t, 2> output_spatial_lengths,
|
||||
std::array<ck::index_t, 2> window_strides,
|
||||
std::array<ck::index_t, 2> input_left_pads,
|
||||
std::array<ck::index_t, 2> input_right_pads) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
|
||||
static_cast<OutDataType*>(p_out_dev),
|
||||
static_cast<int*>(p_out_indices_dev),
|
||||
N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
window_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
window_strides,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<" << BlockSize << ",";
|
||||
str << "M_C" << ReduceMThreadClusterSize << "_S" << ReduceMThreadSliceSize << ",";
|
||||
str << "K_C" << ReduceKThreadClusterSize << "_S" << ReduceKThreadSliceSize << ",";
|
||||
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
}; // struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
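A small worked example of the mapping used by MakeABGridDescriptor_A_M_K_B_M and Invoker::Run above: the pooling problem is flattened into a 2-D reduction with one row per output value (ReduceM = N * Ho * Wo * C) and one column per window position (ReduceK = Y * X), then padded up to the block tile sizes. The concrete shape and tile size below are assumed for illustration only.

// Assumed values: N=2, C=64, Ho=Wo=56, 3x3 window, ReduceM_BlockTileSize = 256.
const ck::index_t ReduceMRaw = 2 * 56 * 56 * 64; // 401408 output elements -> rows
const ck::index_t ReduceKRaw = 3 * 3;            // 9 window positions per output -> columns
const ck::index_t ReduceMPad =
    ck::math::integer_least_multiple(ReduceMRaw, 256) - ReduceMRaw; // 0, already a multiple
const ck::index_t grid_size = (ReduceMRaw + ReduceMPad) / 256;      // 1568 thread blocks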
device_operation/include/device_reduce.hpp (new file, 58 lines)
@@ -0,0 +1,58 @@
|
||||
#ifndef DEVICE_REDUCE_HPP
|
||||
#define DEVICE_REDUCE_HPP
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "reduction_enums.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InElementwiseOperation, typename AccElementwiseOperation>
|
||||
struct DeviceReduce : public BaseOperator
|
||||
{
|
||||
virtual size_t GetWorkspaceSizeInBytes(const std::vector<int>& inLengths)
|
||||
{
|
||||
(void)inLengths;
|
||||
|
||||
return (0);
|
||||
};
|
||||
|
||||
virtual bool HasFurtherCall() { return (false); };
|
||||
|
||||
virtual std::vector<int> GetWorkspace2dLengths(const BaseArgument* argPtr)
|
||||
{
|
||||
(void)argPtr;
|
||||
return (std::vector<int>{0, 0});
|
||||
};
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
void* out_dev,
|
||||
void* out_indices_dev,
|
||||
void* workspace_dev,
|
||||
const InElementwiseOperation& inElementwiseOp,
|
||||
const AccElementwiseOperation& accElementwiseOp) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InElementwiseOperation, typename AccElementwiseOperation>
|
||||
using DeviceReducePtr =
|
||||
std::unique_ptr<DeviceReduce<InElementwiseOperation, AccElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
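A minimal sketch of how a caller drives this interface through a DeviceReducePtr (for example one filled in by the instance factories added further below). The lengths, strides, buffers and element-wise operation objects (reduce_ptr, in_dev, out_dev, in_elementwise_op, acc_elementwise_op) are placeholders, and IsSupportedArgument is assumed to be reachable through the BaseOperator interface since the concrete instances override it. The case shown is a packed rank-4 tensor reduced over dims {0, 1, 2} with alpha = 1, beta = 0 and no index output.

std::vector<int> in_lengths  = {16, 8, 8, 64};
std::vector<int> in_strides  = {4096, 512, 64, 1}; // packed strides for the lengths above
std::vector<int> out_lengths = {64};
std::vector<int> out_strides = {1};

auto arg = reduce_ptr->MakeArgumentPointer(in_lengths, in_strides,
                                           out_lengths, out_strides,
                                           /*alpha=*/1.0f, /*beta=*/0.0f,
                                           in_dev, out_dev,
                                           /*out_indices_dev=*/nullptr,
                                           /*workspace_dev=*/nullptr,
                                           in_elementwise_op, acc_elementwise_op);

if(reduce_ptr->IsSupportedArgument(arg.get()))
{
    float avg_ms = reduce_ptr->MakeInvokerPointer()->Run(arg.get(), /*nrepeat=*/10);
}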
device_operation/include/device_reduce_blockwise.hpp (new file, 354 lines)
@@ -0,0 +1,354 @@
|
||||
#ifndef DEVICE_REDUCE_BLOCKWISE_HPP
|
||||
#define DEVICE_REDUCE_BLOCKWISE_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_reduce.hpp"
|
||||
#include "device_reduce_common.hpp"
|
||||
#include "gridwise_2d_reduction_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
bool NeedIndices,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
"Invalid thread cluster size assignments!");
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
static constexpr bool BetaIsZero = NeedIndices;
|
||||
|
||||
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
|
||||
|
||||
static constexpr index_t srcDims = Rank;
|
||||
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
|
||||
static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
|
||||
|
||||
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides)
|
||||
{
|
||||
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
|
||||
|
||||
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
|
||||
const auto in_grid_desc_m_k = [&]() {
|
||||
if constexpr(reduceAllDims)
|
||||
{
|
||||
const auto one_dim_inDesc = transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(tupleSrcLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return transform_tensor_descriptor(one_dim_inDesc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
1, one_dim_inDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto toReduceDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
|
||||
const auto invariantDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(toReduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
|
||||
const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen;
|
||||
|
||||
auto in_grid_desc_m_k_padded =
|
||||
transform_tensor_descriptor(in_grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(outerLen, inPad_M),
|
||||
make_right_pad_transform(innerLen, inPad_K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (in_grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides)
|
||||
{
|
||||
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
|
||||
|
||||
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
auto out_grid_desc_m = transform_tensor_descriptor(
|
||||
outDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
|
||||
|
||||
const auto inPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
|
||||
|
||||
auto out_grid_desc_m_padded =
|
||||
transform_tensor_descriptor(out_grid_desc_m,
|
||||
make_tuple(make_right_pad_transform(outerLen, inPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return (out_grid_desc_m_padded);
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
float alpha,
|
||||
float beta,
|
||||
const InDataType* in_dev,
|
||||
OutDataType* out_dev,
|
||||
IndexDataType* out_indices_dev,
|
||||
AccDataType* workspace_dev,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op)
|
||||
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
|
||||
{
|
||||
(void)workspace_dev;
|
||||
|
||||
inLengths_ = inLengths;
|
||||
inStrides_ = inStrides;
|
||||
outLengths_ = outLengths;
|
||||
outStrides_ = outStrides;
|
||||
|
||||
in_elementwise_op_ = in_elementwise_op;
|
||||
acc_elementwise_op_ = acc_elementwise_op;
|
||||
|
||||
alpha_ = static_cast<AccDataType>(alpha);
|
||||
beta_ = static_cast<OutDataType>(beta);
|
||||
|
||||
std::tie(invariant_total_length, reduce_total_length) =
|
||||
get_2d_lengths<Rank, ReduceDims>(inLengths);
|
||||
|
||||
if constexpr(InvariantDims::Size() == 0)
|
||||
invariant_lowest_length = 1;
|
||||
else
|
||||
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
|
||||
|
||||
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
|
||||
|
||||
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
|
||||
M_BlockTileSize;
|
||||
}
|
||||
|
||||
std::vector<int> inLengths_;
|
||||
std::vector<int> inStrides_;
|
||||
std::vector<int> outLengths_;
|
||||
std::vector<int> outStrides_;
|
||||
|
||||
AccDataType alpha_;
|
||||
OutDataType beta_;
|
||||
|
||||
const InDataType* in_dev_;
|
||||
OutDataType* out_dev_;
|
||||
IndexDataType* out_indices_dev_;
|
||||
|
||||
InElementwiseOperation in_elementwise_op_;
|
||||
AccElementwiseOperation acc_elementwise_op_;
|
||||
|
||||
int invariant_lowest_length;
|
||||
int reduce_lowest_length;
|
||||
size_t invariant_total_length;
|
||||
size_t reduce_total_length;
|
||||
|
||||
size_t gridSize;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
const auto in_grid_desc_m_k =
|
||||
DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
|
||||
const auto out_grid_desc_m =
|
||||
DeviceReduceBlockWise::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
|
||||
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
|
||||
using OutGridDesc_M = decltype(out_grid_desc_m);
|
||||
|
||||
using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
BetaIsZero,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize>;
|
||||
|
||||
float avg_time = 0;
|
||||
|
||||
const auto kernel = kernel_reduce_blockwise<GridwiseReduce,
|
||||
NeedIndices,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
|
||||
avg_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
arg.in_elementwise_op_,
|
||||
arg.acc_elementwise_op_,
|
||||
arg.alpha_,
|
||||
arg.in_dev_,
|
||||
arg.beta_,
|
||||
arg.out_dev_,
|
||||
nullptr,
|
||||
arg.out_indices_dev_);
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if constexpr(InSrcVectorDim == 0)
|
||||
{
|
||||
if constexpr(InvariantDims::Size() == 0)
|
||||
return (false);
|
||||
|
||||
if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
|
||||
return (false);
|
||||
|
||||
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
|
||||
return (false);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
|
||||
return (false);
|
||||
|
||||
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
|
||||
return (false);
|
||||
};
|
||||
|
||||
// To improve
|
||||
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
|
||||
return (false);
|
||||
|
||||
// cases with very small reduce_total_length should be handled by the ThreadWise method
|
||||
if(pArg->reduce_total_length / KThreadSliceSize < 2)
|
||||
return (false);
|
||||
|
||||
return (true);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
void* out_dev,
|
||||
void* out_indices_dev,
|
||||
void* workspace_dev,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(inLengths,
|
||||
inStrides,
|
||||
outLengths,
|
||||
outStrides,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
static_cast<OutDataType*>(out_dev),
|
||||
static_cast<IndexDataType*>(out_indices_dev),
|
||||
static_cast<AccDataType*>(workspace_dev),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceReduceBlockWise<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
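One hand-written instantiation of DeviceReduceBlockWise, spelled out for illustration; the instance factory added further below generates such combinations from configuration tuples. A MAX reduction of a rank-4 float tensor over dims {0, 1, 2} is assumed, and the tuning values are assumptions chosen only to satisfy the BlockSize == MThreadClusterSize * KThreadClusterSize static_assert.

using namespace ck;
using namespace ck::tensor_operation::device;

using ReduceOp = typename reduce_binary_operator<float, ReduceTensorOp_t::MAX>::opType;
using InOp     = typename reduce_unary_operator<float, ReduceTensorOp_t::MAX, true, true>::InElementwiseOperation;
using AccOp    = typename reduce_unary_operator<float, ReduceTensorOp_t::MAX, true, true>::AccElementwiseOperation;

using MaxReduceInstance = DeviceReduceBlockWise<float, float, float,  // In/Acc/Out data types
                                                4, Sequence<0, 1, 2>, // Rank, ReduceDims
                                                ReduceOp, InOp, AccOp,
                                                false,                // PropagateNan
                                                false,                // NeedIndices
                                                256,                  // BlockSize = 32 * 8
                                                32, 8,                // M/K thread cluster sizes
                                                4, 1,                 // M/K thread slice sizes
                                                0,                    // InSrcVectorDim
                                                4, 4>;                // InSrcVectorSize, OutDstVectorSize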
device_operation/include/device_reduce_blockwise_second_call.hpp (new file, 317 lines)
@@ -0,0 +1,317 @@
|
||||
#ifndef DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP
|
||||
#define DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_reduce.hpp"
|
||||
#include "device_reduce_common.hpp"
|
||||
#include "gridwise_2d_reduction_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
bool NeedIndices,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct DeviceReduceBlockWiseSecondCall
|
||||
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
"Invalid thread cluster size assignments!");
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
static constexpr bool BetaIsZero = NeedIndices;
|
||||
|
||||
static_assert(
|
||||
std::is_same<InDataType, AccDataType>::value,
|
||||
"InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!");
|
||||
|
||||
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
|
||||
|
||||
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
|
||||
|
||||
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides)
|
||||
{
|
||||
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<2>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<2>{});
|
||||
|
||||
const auto in_grid_desc_m_k =
|
||||
make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
|
||||
const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
|
||||
const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen;
|
||||
|
||||
auto in_grid_desc_m_k_padded =
|
||||
transform_tensor_descriptor(in_grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(outerLen, inPad_M),
|
||||
make_right_pad_transform(innerLen, inPad_K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (in_grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides)
|
||||
{
|
||||
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
|
||||
|
||||
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
auto out_grid_desc_m = transform_tensor_descriptor(
|
||||
outDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
|
||||
|
||||
const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
|
||||
|
||||
auto out_grid_desc_m_padded =
|
||||
transform_tensor_descriptor(out_grid_desc_m,
|
||||
make_tuple(make_right_pad_transform(outerLen, outPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return (out_grid_desc_m_padded);
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
float alpha,
|
||||
float beta,
|
||||
const InDataType* in_dev,
|
||||
OutDataType* out_dev,
|
||||
IndexDataType* out_indices_dev,
|
||||
AccDataType* workspace_dev,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op)
|
||||
: in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
|
||||
{
|
||||
inLengths_ = inLengths;
|
||||
inStrides_ = inStrides;
|
||||
outLengths_ = outLengths;
|
||||
outStrides_ = outStrides;
|
||||
|
||||
in_elementwise_op_ = in_elementwise_op;
|
||||
acc_elementwise_op_ = acc_elementwise_op;
|
||||
|
||||
alpha_ = static_cast<AccDataType>(alpha);
|
||||
beta_ = static_cast<OutDataType>(beta);
|
||||
|
||||
invariant_total_length = inLengths[0];
|
||||
reduce_total_length = inLengths[1];
|
||||
|
||||
invariant_lowest_length = inLengths[0];
|
||||
reduce_lowest_length = inLengths[1];
|
||||
|
||||
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
|
||||
M_BlockTileSize;
|
||||
|
||||
size_t ws_buf2_bytes_offset = math::integer_least_multiple(
|
||||
invariant_total_length * reduce_total_length * sizeof(AccDataType), 64);
|
||||
|
||||
if constexpr(NeedIndices)
|
||||
workspace_indices_dev_ = reinterpret_cast<index_t*>(
|
||||
reinterpret_cast<char*>(workspace_dev) + ws_buf2_bytes_offset);
|
||||
else
|
||||
workspace_indices_dev_ = nullptr;
|
||||
}
|
||||
|
||||
std::vector<int> inLengths_;
|
||||
std::vector<int> inStrides_;
|
||||
std::vector<int> outLengths_;
|
||||
std::vector<int> outStrides_;
|
||||
|
||||
AccDataType alpha_;
|
||||
OutDataType beta_;
|
||||
|
||||
const InDataType* in_dev_;
|
||||
OutDataType* out_dev_;
|
||||
IndexDataType* out_indices_dev_;
|
||||
IndexDataType* workspace_indices_dev_;
|
||||
|
||||
InElementwiseOperation in_elementwise_op_;
|
||||
AccElementwiseOperation acc_elementwise_op_;
|
||||
|
||||
int invariant_lowest_length;
|
||||
int reduce_lowest_length;
|
||||
size_t invariant_total_length;
|
||||
size_t reduce_total_length;
|
||||
|
||||
size_t gridSize;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor(
|
||||
arg.inLengths_, arg.inStrides_);
|
||||
const auto out_grid_desc_m = DeviceReduceBlockWiseSecondCall::MakeDst1dDescriptor(
|
||||
arg.outLengths_, arg.outStrides_);
|
||||
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
|
||||
using OutGridDesc_M = decltype(out_grid_desc_m);
|
||||
|
||||
using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
BetaIsZero,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize>;
|
||||
|
||||
float avg_time = 0;
|
||||
|
||||
const auto kernel = kernel_reduce_blockwise_second_call<GridwiseReduce,
|
||||
NeedIndices,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
|
||||
avg_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
arg.in_elementwise_op_,
|
||||
arg.acc_elementwise_op_,
|
||||
arg.alpha_,
|
||||
arg.in_dev_,
|
||||
arg.beta_,
|
||||
arg.out_dev_,
|
||||
arg.workspace_indices_dev_,
|
||||
arg.out_indices_dev_);
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if constexpr(InSrcVectorDim == 0)
|
||||
return (false);
|
||||
|
||||
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
|
||||
return (false);
|
||||
|
||||
// To improve
|
||||
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
|
||||
return (false);
|
||||
|
||||
// cases with very small reduce_total_length should be handled by the ThreadWise method
|
||||
if(pArg->reduce_total_length / KThreadSliceSize < 2)
|
||||
return (false);
|
||||
|
||||
return (true);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
void* out_dev,
|
||||
void* out_indices_dev,
|
||||
void* workspace_dev,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(inLengths,
|
||||
inStrides,
|
||||
outLengths,
|
||||
outStrides,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
static_cast<OutDataType*>(out_dev),
|
||||
static_cast<IndexDataType*>(out_indices_dev),
|
||||
static_cast<AccDataType*>(workspace_dev),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceReduceBlockWiseSecondCall<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
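A sketch of the two-pass flow this SecondCall operator serves, reconstructed from the HasFurtherCall / GetWorkspace2dLengths hooks declared in device_reduce.hpp; the orchestration, the variable names (first_reduce, second_reduce, workspace_dev, etc.) and the packed workspace strides are assumptions, not code from this change.

if(first_reduce->HasFurtherCall())
{
    // The first pass (e.g. a multi-block partial reduce) has written per-block partial
    // results into workspace_dev as a 2-D [invariant, k-blocks] buffer.
    std::vector<int> ws_lengths = first_reduce->GetWorkspace2dLengths(first_arg.get());
    std::vector<int> ws_strides = {ws_lengths[1], 1}; // assumed packed row-major layout

    auto second_arg = second_reduce->MakeArgumentPointer(ws_lengths, ws_strides,
                                                         out_lengths, out_strides,
                                                         /*alpha=*/1.0f, /*beta=*/0.0f,
                                                         workspace_dev, out_dev,
                                                         out_indices_dev, workspace_dev,
                                                         in_elementwise_op, acc_elementwise_op);

    second_reduce->MakeInvokerPointer()->Run(second_arg.get(), nrepeat);
}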
device_operation/include/device_reduce_common.hpp (new file, 81 lines)
@@ -0,0 +1,81 @@
|
||||
#ifndef DEVICE_REDUCE_COMMON_HPP
|
||||
#define DEVICE_REDUCE_COMMON_HPP
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// template <typename preUnaryOpType, typename posUnaryOpType>
|
||||
// using DeviceReducePtr = std::unique_ptr<DeviceReduce<preUnaryOpType, posUnaryOpType>>;
|
||||
|
||||
template <int Rank, typename ReduceDims>
|
||||
std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
|
||||
{
|
||||
static_assert(Rank <= 6, "bigger Rank size not supported!");
|
||||
|
||||
size_t tensor_total_length = 1;
|
||||
size_t reduce_total_length = 1;
|
||||
|
||||
static_for<0, ReduceDims::Size(), 1>{}(
|
||||
[&](auto i) { reduce_total_length *= inLengths[ReduceDims::At(i)]; });
|
||||
|
||||
static_for<0, Rank, 1>{}([&](auto i) { tensor_total_length *= inLengths[i.value]; });
|
||||
|
||||
return std::make_pair(tensor_total_length / reduce_total_length, reduce_total_length);
|
||||
};
|
||||
|
||||
template <int x, typename Seq>
|
||||
constexpr bool belong()
|
||||
{
|
||||
bool inside = false;
|
||||
|
||||
static_for<0, Seq::Size(), 1>{}([&](auto i) { inside = (inside || (x == Seq::At(i))); });
|
||||
|
||||
return (inside);
|
||||
};
|
||||
|
||||
template <int Rank, typename ReduceDims, int start = 0>
|
||||
constexpr auto get_invariant_dims()
|
||||
{
|
||||
static_assert(Rank <= 6, "bigger Rank size not supported!");
|
||||
|
||||
if constexpr(start >= Rank)
|
||||
return Sequence<>{};
|
||||
else
|
||||
{
|
||||
if constexpr(!belong<start, ReduceDims>())
|
||||
return merge_sequences(Sequence<start>{},
|
||||
get_invariant_dims<Rank, ReduceDims, start + 1>());
|
||||
else
|
||||
return get_invariant_dims<Rank, ReduceDims, start + 1>();
|
||||
};
|
||||
};
|
||||
|
||||
// helper functions using variadic template arguments
|
||||
template <index_t... Ns>
|
||||
static auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
|
||||
{
|
||||
return make_tuple(static_cast<index_t>(lengths[Ns])...);
|
||||
};
|
||||
|
||||
template <index_t arraySize>
|
||||
static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arraySize>)
|
||||
{
|
||||
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
|
||||
|
||||
constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
|
||||
|
||||
return make_tuple_from_array_and_index_seq(lengths, index_seq);
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
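A small worked example for the helpers above (requires <vector> and <cassert>; the shape is illustrative): a rank-4 tensor with lengths {16, 8, 8, 64} reduced over dims {0, 1, 2} leaves only dim 3 invariant.

using ExampleReduceDims    = ck::Sequence<0, 1, 2>;
using ExampleInvariantDims =
    decltype(ck::tensor_operation::device::get_invariant_dims<4, ExampleReduceDims>());

static_assert(ExampleInvariantDims::Size() == 1, "only dim 3 survives the reduction");

inline void reduce_common_example()
{
    std::vector<int> inLengths = {16, 8, 8, 64};

    // first  = product of invariant lengths = 64
    // second = product of reduced lengths   = 16 * 8 * 8 = 1024
    auto lengths_2d =
        ck::tensor_operation::device::get_2d_lengths<4, ExampleReduceDims>(inLengths);

    assert(lengths_2d.first == 64 && lengths_2d.second == 1024);
}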
device_operation/include/device_reduce_instance.hpp (new file, 28 lines)
@@ -0,0 +1,28 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_HPP
|
||||
|
||||
#include "device_reduce_instance_blockwise_f16_f16_f16.hpp"
|
||||
#include "device_reduce_instance_blockwise_f16_f32_f16.hpp"
|
||||
#include "device_reduce_instance_blockwise_f32_f32_f32.hpp"
|
||||
#include "device_reduce_instance_blockwise_f32_f64_f32.hpp"
|
||||
#include "device_reduce_instance_blockwise_f64_f64_f64.hpp"
|
||||
#include "device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp"
|
||||
#include "device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp"
|
||||
#include "device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp"
|
||||
#include "device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp"
|
||||
#include "device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp"
|
||||
#include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
|
||||
#include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
|
||||
#include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
|
||||
#include "device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp"
|
||||
#include "device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp"
|
||||
#include "device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp"
|
||||
#include "device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp"
|
||||
#include "device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp"
|
||||
#include "device_reduce_instance_threadwise_f16_f16_f16.hpp"
|
||||
#include "device_reduce_instance_threadwise_f16_f32_f16.hpp"
|
||||
#include "device_reduce_instance_threadwise_f32_f32_f32.hpp"
|
||||
#include "device_reduce_instance_threadwise_f32_f64_f32.hpp"
|
||||
#include "device_reduce_instance_threadwise_f64_f64_f64.hpp"
|
||||
|
||||
#endif
|
||||
device_operation/include/device_reduce_instance_blockwise.hpp (new file, 168 lines)
@@ -0,0 +1,168 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_HPP
|
||||
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_impl_common.hpp"
|
||||
#include "device_reduce_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
#ifdef QUICK_REDUCE_TEST
|
||||
using reduce_configuration_2_instances_blockwise = std::tuple<
|
||||
// clang-format off
|
||||
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
|
||||
ReductionConfiguration_2<0, 2, 2, 2, 1>,
|
||||
ReductionConfiguration_2<0, 1, 1, 2, 1>,
|
||||
ReductionConfiguration_2<1, 2, 1, 1, 2>,
|
||||
ReductionConfiguration_2<1, 2, 2, 1, 2>,
|
||||
ReductionConfiguration_2<0, 1, 1, 3, 1>,
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 3>
|
||||
// clang-format on
|
||||
>;
|
||||
#else
|
||||
using reduce_configuration_2_instances_blockwise = std::tuple<
|
||||
// clang-format off
|
||||
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
|
||||
ReductionConfiguration_2<0, 4, 4, 8, 1>,
|
||||
ReductionConfiguration_2<0, 4, 4, 4, 1>,
|
||||
ReductionConfiguration_2<0, 2, 2, 2, 1>,
|
||||
|
||||
ReductionConfiguration_2<1, 4, 1, 1, 8>,
|
||||
ReductionConfiguration_2<1, 4, 1, 1, 4>,
|
||||
ReductionConfiguration_2<1, 2, 1, 1, 2>,
|
||||
|
||||
// special instances
|
||||
ReductionConfiguration_2<0, 1, 1, 3, 1>,
|
||||
ReductionConfiguration_2<0, 1, 1, 5, 1>,
|
||||
ReductionConfiguration_2<0, 1, 1, 7, 1>,
|
||||
ReductionConfiguration_2<0, 1, 1, 11, 1>,
|
||||
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 3>,
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 5>,
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 7>,
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 11>
|
||||
// clang-format on
|
||||
>;
|
||||
#endif
|
||||
|
||||
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
|
||||
using deviceReduceBlockWisePtrType = DeviceReducePtr<
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
ReduceTensorOp_t ReduceOpId,
|
||||
NanPropagation_t NanOpt,
|
||||
ReduceTensorIndices_t IndicesOpt>
|
||||
void add_device_reduce_instance_blockwise(
|
||||
std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
|
||||
{
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
|
||||
AccElementwiseOperation;
|
||||
|
||||
constexpr bool Indexable =
|
||||
(ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
|
||||
ReduceOpId == ReduceTensorOp_t::AMAX);
|
||||
constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr bool PropagateNan = (NanOpt == NanPropagation_t::NOT_PROPAGATE_NAN) ? false : true;
|
||||
|
||||
static_for<0, std::tuple_size<reduce_configuration_1_instances>::value, 1>{}([&](auto i) {
|
||||
using cfg1 =
|
||||
remove_cvref_t<decltype(std::get<i.value>(reduce_configuration_1_instances{}))>;
|
||||
|
||||
static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
|
||||
[&](auto j) {
|
||||
using cfg2 = remove_cvref_t<decltype(
|
||||
std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;
|
||||
|
||||
using ReduceOpInstance = DeviceReduceBlockWise<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
ReduceDims,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
NeedIndices,
|
||||
cfg1::BlockSize_,
|
||||
cfg1::MThreadClusterSize_,
|
||||
cfg1::KThreadClusterSize_,
|
||||
cfg2::MThreadSliceSize_,
|
||||
cfg2::KThreadSliceSize_,
|
||||
cfg2::InSrcVectorDim_,
|
||||
cfg2::InSrcVectorSize_,
|
||||
cfg2::OutDstVectorSize_>;
|
||||
|
||||
device_op_instances.push_back(
|
||||
std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
#define ADD_BLOCKWISE_INST_BY_TYPE(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
template void add_device_reduce_instance_blockwise<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
|
||||
|
||||
#define ADD_BLOCKWISE_INST_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
ADD_BLOCKWISE_INST_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
|
||||
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
extern template void add_device_reduce_instance_blockwise<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector<DeviceReducePtr< \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
|
||||
AccElementwiseOperation>> & \
|
||||
device_op_instances)
|
||||
|
||||
#define ADD_BLOCKWISE_INST_REF_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
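For reference, a rough picture of what one ADD_BLOCKWISE_INST_REF_BY_ID line in the instance headers below resolves to once both macros above expand. The expansion is reconstructed by hand and lightly reformatted; the id values 2 / 0 / 0 correspond to MIN, NOT_PROPAGATE_NAN and NO_INDICES per the comments in those headers.

// ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2);
// expands to approximately:
extern template void add_device_reduce_instance_blockwise<
    half_t, half_t, half_t,                 // InDataType, AccDataType, OutDataType
    4, Sequence<0, 1, 2>,                   // Rank, ReduceDims
    static_cast<ReduceTensorOp_t>(2),       // MIN
    static_cast<NanPropagation_t>(0),       // NOT_PROPAGATE_NAN
    static_cast<ReduceTensorIndices_t>(0)>( // NO_INDICES
    std::vector<DeviceReducePtr<
        typename reduce_unary_operator<half_t, static_cast<ReduceTensorOp_t>(2), true, true>::InElementwiseOperation,
        typename reduce_unary_operator<half_t, static_cast<ReduceTensorOp_t>(2), true, true>::AccElementwiseOperation>>&
        device_op_instances);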
@@ -0,0 +1,41 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,32 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0);       //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);       //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0);       //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);       //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif
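In these instances the tensor data stays in half precision while the accumulation type is float. A tiny standalone illustration of why the wider accumulator matters for long ADD/AVG/NORM2 reductions; this assumes a compiler with the _Float16 extension and is not CK code:

#include <cstdio>

int main()
{
    // Summing 10000 copies of 0.1: once the running sum is large enough that 0.1 is
    // smaller than half an ulp of half precision, a _Float16 accumulator stops
    // growing, while a float accumulator stays close to the exact value 1000.
    _Float16 acc_half  = 0;
    float    acc_float = 0;
    for(int i = 0; i < 10000; ++i)
    {
        acc_half  = acc_half + (_Float16)0.1f;
        acc_float = acc_float + 0.1f;
    }
    std::printf("half accumulator:  %f\n", (float)acc_half); // saturates far below 1000
    std::printf("float accumulator: %f\n", acc_float);       // ~1000
    return 0;
}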
@@ -0,0 +1,50 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,32 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,50 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,167 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_HPP
|
||||
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_impl_common.hpp"
|
||||
#include "device_reduce_blockwise_second_call.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
#ifdef QUICK_REDUCE_TEST
|
||||
using reduce_configuration_2_instances_blockwise_second_call = std::tuple<
|
||||
// clang-format off
|
||||
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
|
||||
ReductionConfiguration_2<1, 2, 1, 1, 2>,
|
||||
ReductionConfiguration_2<1, 2, 2, 1, 2>,
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 3>,
|
||||
ReductionConfiguration_2<1, 1, 2, 1, 3>
|
||||
// clang-format on
|
||||
>;
|
||||
#else
|
||||
using reduce_configuration_2_instances_blockwise_second_call = std::tuple<
|
||||
// clang-format off
|
||||
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
|
||||
ReductionConfiguration_2<1, 4, 1, 1, 8>,
|
||||
ReductionConfiguration_2<1, 4, 1, 1, 4>,
|
||||
ReductionConfiguration_2<1, 2, 1, 1, 2>,
|
||||
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 3>,
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 5>,
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 7>,
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 11>
|
||||
// clang-format on
|
||||
>;
|
||||
#endif
|
||||
|
||||
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
|
||||
using deviceReduceBlockWiseSecondCallPtrType = DeviceReducePtr<
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, false, true>::InElementwiseOperation,
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, false, true>::AccElementwiseOperation>;
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
ReduceTensorOp_t ReduceOpId,
|
||||
NanPropagation_t NanOpt,
|
||||
ReduceTensorIndices_t IndicesOpt>
|
||||
void add_device_reduce_instance_blockwise_second_call(
|
||||
std::vector<deviceReduceBlockWiseSecondCallPtrType<AccDataType, ReduceOpId>>&
|
||||
device_op_instances)
|
||||
{
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, false, true>::
|
||||
InElementwiseOperation;
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, false, true>::
|
||||
AccElementwiseOperation;
|
||||
|
||||
constexpr bool Indexable =
|
||||
(ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
|
||||
ReduceOpId == ReduceTensorOp_t::AMAX);
|
||||
constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr bool PropagateNan = (NanOpt == NanPropagation_t::NOT_PROPAGATE_NAN) ? false : true;
|
||||
|
||||
static_assert(std::is_same<InDataType, AccDataType>::value,
|
||||
"InDataType and AccDataType should be the same to use "
|
||||
"add_device_reduce_instance_blockwise_second_call!");
|
||||
|
||||
static_for<0, std::tuple_size<reduce_configuration_1_instances>::value, 1>{}([&](auto i) {
|
||||
using cfg1 =
|
||||
remove_cvref_t<decltype(std::get<i.value>(reduce_configuration_1_instances{}))>;
|
||||
|
||||
static_for<0,
|
||||
std::tuple_size<reduce_configuration_2_instances_blockwise_second_call>::value,
|
||||
1>{}([&](auto j) {
|
||||
using cfg2 = remove_cvref_t<decltype(
|
||||
std::get<j.value>(reduce_configuration_2_instances_blockwise_second_call{}))>;
|
||||
|
||||
using ReduceOpInstance = DeviceReduceBlockWiseSecondCall<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
ReduceDims,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
NeedIndices,
|
||||
cfg1::BlockSize_,
|
||||
cfg1::MThreadClusterSize_,
|
||||
cfg1::KThreadClusterSize_,
|
||||
cfg2::MThreadSliceSize_,
|
||||
cfg2::KThreadSliceSize_,
|
||||
cfg2::InSrcVectorDim_,
|
||||
cfg2::InSrcVectorSize_,
|
||||
cfg2::OutDstVectorSize_>;
|
||||
|
||||
device_op_instances.push_back(std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
template void add_device_reduce_instance_blockwise_second_call<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector<deviceReduceBlockWiseSecondCallPtrType<compT, ReduceOpId>> & \
|
||||
device_op_instances)
|
||||
|
||||
#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
|
||||
#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
extern template void add_device_reduce_instance_blockwise_second_call<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector< \
|
||||
DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
|
||||
InElementwiseOperation, \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
|
||||
AccElementwiseOperation>> & \
|
||||
device_op_instances)
|
||||
|
||||
#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
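The "second call" instances above are meant to be paired with a first multi-block pass (see the MultiBlockPartialReduce instances later in this diff): the first kernel writes one partial value per workgroup into a workspace, and this blockwise kernel then reduces that workspace to the final result. A minimal host-side sketch of the same two-stage idea, plain C++ rather than the CK kernels:

#include <cstddef>
#include <vector>

// Stage 1: each "block" reduces its own chunk to one partial value (the workspace).
// Stage 2: a single pass reduces the workspace to the final result.
float two_stage_sum(const std::vector<float>& in, std::size_t block_size)
{
    std::vector<float> workspace;
    for(std::size_t start = 0; start < in.size(); start += block_size)
    {
        float partial = 0.0f;
        for(std::size_t i = start; i < in.size() && i < start + block_size; ++i)
            partial += in[i];
        workspace.push_back(partial); // what the first multi-block kernel would write
    }

    float result = 0.0f;
    for(float p : workspace) // what the blockwise "second call" kernel reduces
        result += p;
    return result;
}

This is also why add_device_reduce_instance_blockwise_second_call static_asserts that InDataType equals AccDataType: the second call consumes the workspace, which already holds accumulator-typed partial values.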
@@ -0,0 +1,41 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F16_F16_F16_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F16_F16_F16_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_blockwise_second_call.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,32 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F16_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F16_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_blockwise_second_call.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 0);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,50 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F32_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_blockwise_second_call.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,32 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F32_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_blockwise_second_call.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 0);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,50 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F64_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F64_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_blockwise_second_call.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
|
||||
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,55 @@
#ifndef DEVICE_REDUCE_INSTANCE_IMPL_COMMON_HPP
#define DEVICE_REDUCE_INSTANCE_IMPL_COMMON_HPP

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

template <int BlockSize, int MThreadClusterSize, int KThreadClusterSize>
struct ReductionConfiguration_1
{
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, "Invalid Configuration!");

    static constexpr int BlockSize_          = BlockSize;
    static constexpr int MThreadClusterSize_ = MThreadClusterSize;
    static constexpr int KThreadClusterSize_ = KThreadClusterSize;
};

template <int InSrcVectorDim,
          int InSrcVectorSize,
          int OutDstVectorSize,
          int MThreadSliceSize,
          int KThreadSliceSize>
struct ReductionConfiguration_2
{
    static constexpr int InSrcVectorDim_   = InSrcVectorDim;
    static constexpr int InSrcVectorSize_  = InSrcVectorSize;
    static constexpr int OutDstVectorSize_ = OutDstVectorSize;
    static constexpr int MThreadSliceSize_ = MThreadSliceSize;
    static constexpr int KThreadSliceSize_ = KThreadSliceSize;
};

using reduce_configuration_1_instances = std::tuple<
    // clang-format off
    // BlockSize | MThreadClusterSize | KThreadClusterSize
    ReductionConfiguration_1<256, 128, 2>,
    ReductionConfiguration_1<256, 64, 4>,
    ReductionConfiguration_1<256, 32, 8>,
    ReductionConfiguration_1<256, 16, 16>,
    ReductionConfiguration_1<256, 8, 32>,
    ReductionConfiguration_1<256, 4, 64>,
    ReductionConfiguration_1<256, 2, 128>,
    ReductionConfiguration_1<256, 1, 256>
    // clang-format on
    >;

#define QUICK_REDUCE_TEST 1

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif
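Each device-level instance is generated from one entry of reduce_configuration_1_instances (block shape: 256 threads split between the M and K directions) crossed with one entry of a method-specific reduce_configuration_2_instances tuple (vector widths and per-thread slice sizes). With the 8 block shapes above and, for example, the 14 non-quick configurations of the multi-block atomic-add list below, one (data type, operation) combination would yield 8 x 14 = 112 instances; since QUICK_REDUCE_TEST is defined in this header, the smaller #ifdef branches are what actually get compiled. A small standalone sketch of the constraint and the cross-product count, using stand-in names rather than the CK structs:

#include <cstdio>
#include <tuple>

// Stand-in for ReductionConfiguration_1: the block size must equal the product of
// the two cluster sizes, otherwise the static_assert fires at compile time.
template <int BlockSize, int MThreadClusterSize, int KThreadClusterSize>
struct Config1
{
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, "Invalid Configuration!");
};

struct Config2 { }; // stand-in for ReductionConfiguration_2

using cfg1_list = std::tuple<Config1<256, 128, 2>, Config1<256, 64, 4>>; // 8 entries in CK
using cfg2_list = std::tuple<Config2, Config2, Config2>;                 // 6 or 14 entries in CK

int main()
{
    std::printf("instances per (type, op): %zu\n",
                std::tuple_size<cfg1_list>::value * std::tuple_size<cfg2_list>::value);
    return 0;
}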
@@ -0,0 +1,192 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_HPP

#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_impl_common.hpp"
#include "device_reduce_multiblock_atomic_add.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

#ifdef QUICK_REDUCE_TEST
using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple<
    // clang-format off
    // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
    ReductionConfiguration_2<0, 2, 2, 2, 1>,
    ReductionConfiguration_2<0, 1, 1, 2, 1>,
    ReductionConfiguration_2<1, 2, 1, 1, 2>,
    ReductionConfiguration_2<1, 2, 2, 1, 2>,
    ReductionConfiguration_2<0, 1, 1, 3, 1>,
    ReductionConfiguration_2<1, 1, 1, 1, 3>
    // clang-format on
    >;
#else
using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple<
    // clang-format off
    // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
    ReductionConfiguration_2<0, 4, 4, 8, 1>,
    ReductionConfiguration_2<0, 4, 4, 4, 1>,
    ReductionConfiguration_2<0, 2, 2, 2, 1>,

    ReductionConfiguration_2<1, 4, 1, 1, 8>,
    ReductionConfiguration_2<1, 4, 1, 1, 4>,
    ReductionConfiguration_2<1, 2, 1, 1, 2>,

    // special instances
    ReductionConfiguration_2<0, 1, 1, 3, 1>,
    ReductionConfiguration_2<0, 1, 1, 5, 1>,
    ReductionConfiguration_2<0, 1, 1, 7, 1>,
    ReductionConfiguration_2<0, 1, 1, 11, 1>,

    ReductionConfiguration_2<1, 1, 1, 1, 3>,
    ReductionConfiguration_2<1, 1, 1, 1, 5>,
    ReductionConfiguration_2<1, 1, 1, 1, 7>,
    ReductionConfiguration_2<1, 1, 1, 1, 11>
    // clang-format on
    >;
#endif

template <typename AccDataType, ReduceTensorOp_t ReduceOperation>
using deviceReduceMultiBlockAtomicAddPtrType =
    DeviceReducePtr<typename reduce_unary_operator<AccDataType, ReduceOperation, true, true>::
                        InElementwiseOperation,
                    typename reduce_unary_operator<AccDataType, ReduceOperation, true, true>::
                        AccElementwiseOperation>;

template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          int Rank,
          typename ReduceDims,
          ReduceTensorOp_t ReduceOpId,
          NanPropagation_t NanOpt,
          ReduceTensorIndices_t IndicesOpt>
void add_device_reduce_instance_multiblock_atomic_add(
    std::vector<deviceReduceMultiBlockAtomicAddPtrType<AccDataType, ReduceOpId>>&
        device_op_instances)
{
    using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
    using InElementwiseOperation =
        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
    using AccElementwiseOperation =
        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
            AccElementwiseOperation;

    constexpr bool Indexable =
        (ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
         ReduceOpId == ReduceTensorOp_t::AMAX);
    constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES);

    constexpr bool PropagateNan = (NanOpt == NanPropagation_t::NOT_PROPAGATE_NAN) ? false : true;

    static_assert(IndicesOpt == ReduceTensorIndices_t::NO_INDICES,
                  "AtomicAdd can only be used with reduction operations without indices!");

    constexpr bool op_acceptable =
        (ReduceOpId == ReduceTensorOp_t::ADD || ReduceOpId == ReduceTensorOp_t::MUL ||
         ReduceOpId == ReduceTensorOp_t::AVG || ReduceOpId == ReduceTensorOp_t::NORM1);

    constexpr bool out_type_acceptable =
        (std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value);

    if constexpr(!op_acceptable || !out_type_acceptable)
        return;
    else
    {
        static_for<0, std::tuple_size<reduce_configuration_1_instances>::value, 1>{}([&](auto i) {
            using cfg1 =
                remove_cvref_t<decltype(std::get<i.value>(reduce_configuration_1_instances{}))>;

            static_for<
                0,
                std::tuple_size<reduce_configuration_2_instances_multiblock_atomic_add>::value,
                1>{}([&](auto j) {
                using cfg2 = remove_cvref_t<decltype(
                    std::get<j.value>(reduce_configuration_2_instances_multiblock_atomic_add{}))>;

                using ReduceOpInstance = DeviceReduceMultiBlockAtomicAdd<InDataType,
                                                                         AccDataType,
                                                                         OutDataType,
                                                                         Rank,
                                                                         ReduceDims,
                                                                         ReduceOperation,
                                                                         InElementwiseOperation,
                                                                         AccElementwiseOperation,
                                                                         PropagateNan,
                                                                         NeedIndices,
                                                                         cfg1::BlockSize_,
                                                                         cfg1::MThreadClusterSize_,
                                                                         cfg1::KThreadClusterSize_,
                                                                         cfg2::MThreadSliceSize_,
                                                                         cfg2::KThreadSliceSize_,
                                                                         cfg2::InSrcVectorDim_,
                                                                         cfg2::InSrcVectorSize_,
                                                                         cfg2::OutDstVectorSize_>;

                device_op_instances.push_back(
                    std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
            });
        });
    }
};

#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE( \
    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
    template void add_device_reduce_instance_multiblock_atomic_add<inT, \
                                                                   compT, \
                                                                   outT, \
                                                                   Rank, \
                                                                   Sequence<__VA_ARGS__>, \
                                                                   ReduceOpId, \
                                                                   NanOpt, \
                                                                   IndicesOpt>( \
        std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> & \
        device_op_instances)

#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \
    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
    ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE(inT, \
                                           compT, \
                                           outT, \
                                           static_cast<ReduceTensorOp_t>(ReduceOpId), \
                                           static_cast<NanPropagation_t>(NanOpt), \
                                           static_cast<ReduceTensorIndices_t>(IndicesOpt), \
                                           Rank, \
                                           __VA_ARGS__)

#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \
    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
    extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \
                                                                          compT, \
                                                                          outT, \
                                                                          Rank, \
                                                                          Sequence<__VA_ARGS__>, \
                                                                          ReduceOpId, \
                                                                          NanOpt, \
                                                                          IndicesOpt>( \
        std::vector<DeviceReducePtr< \
            typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
            typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
                AccElementwiseOperation>> & \
        device_op_instances)

#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \
    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
    ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE(inT, \
                                               compT, \
                                               outT, \
                                               static_cast<ReduceTensorOp_t>(ReduceOpId), \
                                               static_cast<NanPropagation_t>(NanOpt), \
                                               static_cast<ReduceTensorIndices_t>(IndicesOpt), \
                                               Rank, \
                                               __VA_ARGS__)

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif
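The atomic-add path skips the second kernel entirely: every workgroup folds its partial result straight into the output buffer, which is why the function above rejects indexed reductions (there is no deterministic winner to record) and restricts OutDataType to float or double, the types it commits atomically. A conceptual HIP-style sketch of that final step, assuming a pre-zeroed output and a grid that exactly tiles the input; this is an illustration of the idea, not the actual CK kernel:

#include <hip/hip_runtime.h>

// Each block sums its chunk and commits the partial result with atomicAdd, so no
// workspace and no second kernel launch are needed. A real kernel would first
// combine the per-thread partials inside the workgroup (e.g. through LDS) so that
// only one atomic per block is issued; here every thread commits its own partial,
// which is still a correct sum. Assumes gridDim.x * chunk == total element count
// and that *out was zero-initialized before the launch.
__global__ void commit_partial_sums(const float* in, float* out, int chunk)
{
    const int base    = blockIdx.x * chunk;
    float     partial = 0.0f;
    for(int i = threadIdx.x; i < chunk; i += blockDim.x)
        partial += in[base + i];

    atomicAdd(out, partial);
}

This also matches the instance lists that follow: only the add-like operation IDs (0 for ADD, 5 for AVG) are referenced, because anything else is filtered out by op_acceptable at compile time.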
@@ -0,0 +1,29 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F16_F32_F32_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F16_F32_F32_HPP

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_multiblock_atomic_add.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 0);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 0);       //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 2, 1);       //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck

#endif
@@ -0,0 +1,29 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F32_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F32_F32_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_multiblock_atomic_add.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,29 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F64_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F64_F32_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_multiblock_atomic_add.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,175 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
|
||||
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_impl_common.hpp"
|
||||
#include "device_reduce_multiblock_partial_reduce.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
#ifdef QUICK_REDUCE_TEST
|
||||
using reduce_configuration_2_instances_multiblock_partial_reduce = std::tuple<
|
||||
// clang-format off
|
||||
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
|
||||
ReductionConfiguration_2<0, 1, 1, 2, 1>,
|
||||
ReductionConfiguration_2<1, 2, 1, 1, 2>,
|
||||
ReductionConfiguration_2<0, 1, 1, 3, 1>,
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 3>
|
||||
// clang-format on
|
||||
>;
|
||||
#else
|
||||
using reduce_configuration_2_instances_multiblock_partial_reduce = std::tuple<
|
||||
// clang-format off
|
||||
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
|
||||
ReductionConfiguration_2<0, 4, 1, 8, 1>,
|
||||
ReductionConfiguration_2<0, 4, 1, 4, 1>,
|
||||
ReductionConfiguration_2<0, 2, 1, 2, 1>,
|
||||
|
||||
ReductionConfiguration_2<1, 4, 1, 1, 8>,
|
||||
ReductionConfiguration_2<1, 4, 1, 1, 4>,
|
||||
ReductionConfiguration_2<1, 2, 1, 1, 2>,
|
||||
|
||||
// special instances
|
||||
ReductionConfiguration_2<0, 1, 1, 3, 1>,
|
||||
ReductionConfiguration_2<0, 1, 1, 5, 1>,
|
||||
ReductionConfiguration_2<0, 1, 1, 7, 1>,
|
||||
ReductionConfiguration_2<0, 1, 1, 11, 1>,
|
||||
|
||||
ReductionConfiguration_2<0, 1, 1, 1, 3>,
|
||||
ReductionConfiguration_2<0, 1, 1, 1, 5>,
|
||||
ReductionConfiguration_2<0, 1, 1, 1, 7>,
|
||||
ReductionConfiguration_2<0, 1, 1, 1, 11>
|
||||
// clang-format on
|
||||
>;
|
||||
#endif
|
||||
|
||||
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
|
||||
using deviceReduceMultiBlockPartialReducePtrType = DeviceReducePtr<
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, false>::InElementwiseOperation,
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, false>::AccElementwiseOperation>;
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
ReduceTensorOp_t ReduceOpId,
|
||||
NanPropagation_t NanOpt,
|
||||
ReduceTensorIndices_t IndicesOpt>
|
||||
void add_device_reduce_instance_multiblock_partial_reduce(
|
||||
std::vector<deviceReduceMultiBlockPartialReducePtrType<AccDataType, ReduceOpId>>&
|
||||
device_op_instances)
|
||||
{
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, false>::
|
||||
InElementwiseOperation;
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, false>::
|
||||
AccElementwiseOperation;
|
||||
|
||||
constexpr bool Indexable =
|
||||
(ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
|
||||
ReduceOpId == ReduceTensorOp_t::AMAX);
|
||||
constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr bool PropagateNan = (NanOpt == NanPropagation_t::NOT_PROPAGATE_NAN) ? false : true;
|
||||
|
||||
static_for<0, std::tuple_size<reduce_configuration_1_instances>::value, 1>{}([&](auto i) {
|
||||
using cfg1 =
|
||||
remove_cvref_t<decltype(std::get<i.value>(reduce_configuration_1_instances{}))>;
|
||||
|
||||
static_for<
|
||||
0,
|
||||
std::tuple_size<reduce_configuration_2_instances_multiblock_partial_reduce>::value,
|
||||
1>{}([&](auto j) {
|
||||
using cfg2 = remove_cvref_t<decltype(
|
||||
std::get<j.value>(reduce_configuration_2_instances_multiblock_partial_reduce{}))>;
|
||||
|
||||
using ReduceOpInstance = DeviceReduceMultiBlockPartialReduce<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
ReduceDims,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
NeedIndices,
|
||||
cfg1::BlockSize_,
|
||||
cfg1::MThreadClusterSize_,
|
||||
cfg1::KThreadClusterSize_,
|
||||
cfg2::MThreadSliceSize_,
|
||||
cfg2::KThreadSliceSize_,
|
||||
cfg2::InSrcVectorDim_,
|
||||
cfg2::InSrcVectorSize_,
|
||||
cfg2::OutDstVectorSize_>;
|
||||
|
||||
device_op_instances.push_back(std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
template void add_device_reduce_instance_multiblock_partial_reduce<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector<deviceReduceMultiBlockPartialReducePtrType<compT, ReduceOpId>> & \
|
||||
device_op_instances)
|
||||
|
||||
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
|
||||
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
extern template void \
|
||||
add_device_reduce_instance_multiblock_partial_reduce<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector< \
|
||||
DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, true, false>:: \
|
||||
InElementwiseOperation, \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, true, false>:: \
|
||||
AccElementwiseOperation>> & \
|
||||
device_op_instances)
|
||||
|
||||
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
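Unlike the atomic-add variant, MultiBlockPartialReduce needs a workspace: each output element receives one accumulator-typed partial value (plus an index when indices are requested) per workgroup that covers its reduced length, and the blockwise "second call" instances earlier in this diff then fold that workspace down to the final tensor. A back-of-the-envelope sizing sketch; the formula is an assumption for illustration, not the exact CK workspace query:

#include <cstddef>
#include <cstdint>

// Rough workspace estimate for a partial reduction: one AccDataType value per
// (output element, workgroup along the reduced length); indexed reductions also
// keep one int32 index alongside each partial value.
std::size_t partial_reduce_workspace_bytes(std::size_t num_output_elems,
                                           std::size_t blocks_per_reduction,
                                           std::size_t acc_type_bytes,
                                           bool        need_indices)
{
    const std::size_t per_partial = acc_type_bytes + (need_indices ? sizeof(std::int32_t) : 0);
    return num_output_elems * blocks_per_reduction * per_partial;
}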
@@ -0,0 +1,41 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F16_F16_F16_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F16_F16_F16_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,32 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F16_F32_F16_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F16_F32_F16_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,45 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F32_F32_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F32_F32_F32_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
|
||||
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,26 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F32_F64_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F32_F64_F32_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,53 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F64_F64_F64_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F64_F64_F64_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
|
||||
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
|
||||
|
||||
// Will be moved to use MultiBlockAtomicAdd
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
|
||||
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
164
device_operation/include/device_reduce_instance_threadwise.hpp
Normal file
164
device_operation/include/device_reduce_instance_threadwise.hpp
Normal file
@@ -0,0 +1,164 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_HPP
|
||||
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_impl_common.hpp"
|
||||
#include "device_reduce_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
#ifdef QUICK_REDUCE_TEST
|
||||
using reduce_configuration_2_instances_threadwise = std::tuple<
|
||||
// clang-format off
|
||||
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
|
||||
ReductionConfiguration_2<0, 2, 2, 2, 1>,
|
||||
ReductionConfiguration_2<0, 1, 1, 2, 1>,
|
||||
ReductionConfiguration_2<1, 2, 1, 1, 2>,
|
||||
ReductionConfiguration_2<1, 2, 2, 1, 2>,
|
||||
ReductionConfiguration_2<0, 1, 1, 3, 1>,
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 3>
|
||||
// clang-format on
|
||||
>;
|
||||
#else
|
||||
using reduce_configuration_2_instances_threadwise = std::tuple<
|
||||
// clang-format off
|
||||
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
|
||||
ReductionConfiguration_2<0, 4, 4, 8, 1>,
|
||||
ReductionConfiguration_2<0, 4, 4, 4, 1>,
|
||||
ReductionConfiguration_2<0, 2, 2, 2, 1>,
|
||||
|
||||
ReductionConfiguration_2<1, 4, 1, 1, 8>,
|
||||
ReductionConfiguration_2<1, 4, 1, 1, 4>,
|
||||
ReductionConfiguration_2<1, 2, 1, 1, 2>,
|
||||
|
||||
// special instances
|
||||
ReductionConfiguration_2<0, 1, 1, 3, 1>,
|
||||
ReductionConfiguration_2<0, 1, 1, 5, 1>,
|
||||
ReductionConfiguration_2<0, 1, 1, 7, 1>,
|
||||
ReductionConfiguration_2<0, 1, 1, 11, 1>,
|
||||
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 3>,
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 5>,
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 7>,
|
||||
ReductionConfiguration_2<1, 1, 1, 1, 11>
|
||||
// clang-format on
|
||||
>;
|
||||
#endif
|
||||
|
||||
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
|
||||
using deviceReduceThreadWisePtrType = DeviceReducePtr<
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
ReduceTensorOp_t ReduceOpId,
|
||||
NanPropagation_t NanOpt,
|
||||
ReduceTensorIndices_t IndicesOpt>
|
||||
void add_device_reduce_instance_threadwise(
|
||||
std::vector<deviceReduceThreadWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
|
||||
{
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
|
||||
AccElementwiseOperation;
|
||||
|
||||
constexpr bool Indexable =
|
||||
(ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
|
||||
ReduceOpId == ReduceTensorOp_t::AMAX);
|
||||
constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr bool PropagateNan = (NanOpt == NanPropagation_t::NOT_PROPAGATE_NAN) ? false : true;
|
||||
|
||||
using cfg1 = ReductionConfiguration_1<256, 256, 1>;
|
||||
|
||||
static_for<0, std::tuple_size<reduce_configuration_2_instances_threadwise>::value, 1>{}(
|
||||
[&](auto j) {
|
||||
using cfg2 = remove_cvref_t<decltype(
|
||||
std::get<j.value>(reduce_configuration_2_instances_threadwise{}))>;
|
||||
|
||||
using ReduceOpInstance = DeviceReduceThreadWise<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
ReduceDims,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
NeedIndices,
|
||||
cfg1::BlockSize_,
|
||||
cfg1::MThreadClusterSize_,
|
||||
cfg1::KThreadClusterSize_,
|
||||
cfg2::MThreadSliceSize_,
|
||||
cfg2::KThreadSliceSize_,
|
||||
cfg2::InSrcVectorDim_,
|
||||
cfg2::InSrcVectorSize_,
|
||||
cfg2::OutDstVectorSize_>;
|
||||
|
||||
device_op_instances.push_back(std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
|
||||
});
|
||||
};
|
||||
|
||||
#define ADD_THREADWISE_INST_BY_TYPE(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
template void add_device_reduce_instance_threadwise<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances)
|
||||
|
||||
#define ADD_THREADWISE_INST_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
ADD_THREADWISE_INST_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
|
||||
#define ADD_THREADWISE_INST_REF_BY_TYPE( \
|
||||
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
extern template void add_device_reduce_instance_threadwise<inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
Rank, \
|
||||
Sequence<__VA_ARGS__>, \
|
||||
ReduceOpId, \
|
||||
NanOpt, \
|
||||
IndicesOpt>( \
|
||||
std::vector<DeviceReducePtr< \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
|
||||
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
|
||||
AccElementwiseOperation>> & \
|
||||
device_op_instances)
|
||||
|
||||
#define ADD_THREADWISE_INST_REF_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
|
||||
ADD_THREADWISE_INST_REF_BY_TYPE(inT, \
|
||||
compT, \
|
||||
outT, \
|
||||
static_cast<ReduceTensorOp_t>(ReduceOpId), \
|
||||
static_cast<NanPropagation_t>(NanOpt), \
|
||||
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
|
||||
Rank, \
|
||||
__VA_ARGS__)
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,41 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F16_F16_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F16_F16_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,32 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F32_F16_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F32_F16_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,50 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F32_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F32_F32_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,32 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F64_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F64_F32_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,50 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F64_F64_F64_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F64_F64_F64_HPP
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_reduce_instance {
|
||||
|
||||
// clang-format off
|
||||
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
|
||||
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
|
||||
// clang-format on
|
||||
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
418
device_operation/include/device_reduce_multiblock_atomic_add.hpp
Normal file
418
device_operation/include/device_reduce_multiblock_atomic_add.hpp
Normal file
@@ -0,0 +1,418 @@
|
||||
#ifndef DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP
|
||||
#define DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_reduce.hpp"
|
||||
#include "device_reduce_common.hpp"
|
||||
#include "gridwise_2d_reduction_multiblock_atomic_add.hpp"
|
||||
#include "gridwise_set_buffer_value.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
bool NeedIndices,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct DeviceReduceMultiBlockAtomicAdd
|
||||
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
"Invalid thread cluster size assignments!");
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
|
||||
|
||||
static constexpr index_t srcDims = Rank;
|
||||
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
|
||||
static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
|
||||
|
||||
static constexpr bool support_AtomicAdd =
|
||||
std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;
|
||||
|
||||
static_assert(!NeedIndices && support_AtomicAdd,
|
||||
"MultiBlockAtomicAdd method can only be used with non-indiced operation and when "
|
||||
"having float/double output type!");
|
||||
|
||||
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides,
|
||||
int blkGroupSize,
|
||||
int kBlockTileIterations)
|
||||
{
|
||||
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
|
||||
|
||||
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
|
||||
const auto in_grid_desc_m_k = [&]() {
|
||||
if constexpr(reduceAllDims)
|
||||
{
|
||||
const auto one_dim_inDesc = transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(tupleSrcLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return transform_tensor_descriptor(one_dim_inDesc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
1, one_dim_inDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto toReduceDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
|
||||
const auto invariantDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(toReduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
|
||||
const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
|
||||
const auto inPad_K = reduceSizePerBlock * blkGroupSize - innerLen;
|
||||
|
||||
auto in_grid_desc_m_k_padded =
|
||||
transform_tensor_descriptor(in_grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(outerLen, inPad_M),
|
||||
make_right_pad_transform(innerLen, inPad_K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (in_grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides)
|
||||
{
|
||||
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
|
||||
|
||||
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
auto out_grid_desc_m = transform_tensor_descriptor(
|
||||
outDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
|
||||
|
||||
const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
|
||||
|
||||
auto out_grid_desc_m_padded =
|
||||
transform_tensor_descriptor(out_grid_desc_m,
|
||||
make_tuple(make_right_pad_transform(outerLen, outPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return (out_grid_desc_m_padded);
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
float alpha,
|
||||
float beta,
|
||||
const InDataType* in_dev,
|
||||
OutDataType* out_dev,
|
||||
IndexDataType* out_indices_dev,
|
||||
AccDataType* workspace_dev,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op)
|
||||
: in_dev_{in_dev}, out_dev_{out_dev}
|
||||
{
|
||||
(void)out_indices_dev;
|
||||
(void)workspace_dev;
|
||||
|
||||
inLengths_ = inLengths;
|
||||
inStrides_ = inStrides;
|
||||
outLengths_ = outLengths;
|
||||
outStrides_ = outStrides;
|
||||
|
||||
in_elementwise_op_ = in_elementwise_op;
|
||||
acc_elementwise_op_ = acc_elementwise_op;
|
||||
|
||||
alpha_ = static_cast<AccDataType>(alpha);
|
||||
beta_ = static_cast<OutDataType>(beta);
|
||||
|
||||
std::tie(invariant_total_length, reduce_total_length) =
|
||||
get_2d_lengths<Rank, ReduceDims>(inLengths);
|
||||
|
||||
if constexpr(InvariantDims::Size() == 0)
|
||||
invariant_lowest_length = 1;
|
||||
else
|
||||
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
|
||||
|
||||
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
|
||||
|
||||
int iterations = 1;
|
||||
while(true)
|
||||
{
|
||||
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
// we want the blkGroupSize be not more than 128
|
||||
if(testBlkGroupSize <= 128)
|
||||
break;
|
||||
|
||||
iterations++;
|
||||
};
|
||||
|
||||
blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
kBlockTileIterations = iterations;
|
||||
|
||||
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
|
||||
M_BlockTileSize * blkGroupSize;
|
||||
|
||||
gridSize_pre =
|
||||
math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
|
||||
}
|
||||
|
||||
std::vector<int> inLengths_;
|
||||
std::vector<int> inStrides_;
|
||||
std::vector<int> outLengths_;
|
||||
std::vector<int> outStrides_;
|
||||
|
||||
AccDataType alpha_;
|
||||
OutDataType beta_;
|
||||
|
||||
const InDataType* in_dev_;
|
||||
OutDataType* out_dev_;
|
||||
|
||||
InElementwiseOperation in_elementwise_op_;
|
||||
AccElementwiseOperation acc_elementwise_op_;
|
||||
|
||||
int invariant_lowest_length;
|
||||
int reduce_lowest_length;
|
||||
size_t invariant_total_length;
|
||||
size_t reduce_total_length;
|
||||
|
||||
index_t blkGroupSize;
|
||||
index_t kBlockTileIterations;
|
||||
size_t gridSize;
|
||||
|
||||
size_t gridSize_pre;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor(
|
||||
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
|
||||
const auto out_grid_desc_m = DeviceReduceMultiBlockAtomicAdd::MakeDst1dDescriptor(
|
||||
arg.outLengths_, arg.outStrides_);
|
||||
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
|
||||
using OutGridDesc_M = decltype(out_grid_desc_m);
|
||||
|
||||
using GridwiseReduce =
|
||||
GridwiseReduction_mk_to_m_multiblock_atomic_add<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize>;
|
||||
|
||||
float avg_time = 0;
|
||||
|
||||
KernelTimer timer;
|
||||
|
||||
const auto kernel_pre = kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M>;
|
||||
const auto kernel_main = kernel_reduce_multiblock_atocmi_add<GridwiseReduce,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
OutGridDesc_M,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
|
||||
printf("launch_and_time_kernel: grid_dim {%ld, 1, 1}, block_dim {%d, 1, 1} \n",
|
||||
arg.gridSize,
|
||||
BlockSize);
|
||||
printf("Warm up\n");
|
||||
|
||||
for(int i = 0; i < nrepeat + 1; i++)
|
||||
{
|
||||
if(i == 1)
|
||||
timer.Start();
|
||||
|
||||
launch_kernel(kernel_pre,
|
||||
dim3(arg.gridSize_pre),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
out_grid_desc_m,
|
||||
arg.out_dev_,
|
||||
static_cast<OutDataType>(0.0f));
|
||||
|
||||
launch_kernel(kernel_main,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
arg.in_elementwise_op_,
|
||||
arg.acc_elementwise_op_,
|
||||
arg.blkGroupSize,
|
||||
arg.kBlockTileIterations,
|
||||
arg.alpha_,
|
||||
arg.in_dev_,
|
||||
arg.out_dev_);
|
||||
};
|
||||
|
||||
timer.End();
|
||||
|
||||
avg_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if constexpr(InSrcVectorDim == 0)
|
||||
{
|
||||
if constexpr(InvariantDims::Size() == 0)
|
||||
return (false);
|
||||
|
||||
if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
|
||||
return (false);
|
||||
|
||||
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
|
||||
return (false);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
|
||||
return (false);
|
||||
|
||||
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
|
||||
return (false);
|
||||
};
|
||||
|
||||
if(static_cast<float>(pArg->beta_) != 0.0f)
|
||||
return (false);
|
||||
|
||||
// To improve
|
||||
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
|
||||
return (false);
|
||||
|
||||
// cases with small reduce_total_length should be handled by the BlockWise method
|
||||
if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize)
|
||||
return (false);
|
||||
|
||||
// This is very strong restriction, but needed to avoid some failure
|
||||
if(pArg->invariant_lowest_length % M_BlockTileSize != 0)
|
||||
return (false);
|
||||
|
||||
return (true);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
void* out_dev,
|
||||
void* out_indices_dev,
|
||||
void* workspace_dev,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(inLengths,
|
||||
inStrides,
|
||||
outLengths,
|
||||
outStrides,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
static_cast<OutDataType*>(out_dev),
|
||||
static_cast<IndexDataType*>(out_indices_dev),
|
||||
static_cast<AccDataType*>(workspace_dev),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceReduceMultiBlockAtomicAdd<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,419 @@
|
||||
#ifndef DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
|
||||
#define DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_reduce.hpp"
|
||||
#include "device_reduce_common.hpp"
|
||||
#include "gridwise_2d_reduction_multiblock_partial_reduce.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
int Rank,
|
||||
typename ReduceDims,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
bool NeedIndices,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct DeviceReduceMultiBlockPartialReduce
|
||||
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
"Invalid thread cluster size assignments!");
|
||||
|
||||
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
|
||||
|
||||
static constexpr index_t srcDims = Rank;
|
||||
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
|
||||
static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
|
||||
|
||||
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
size_t GetWorkspaceSizeInBytes(const std::vector<int>& inLengths) override
|
||||
{
|
||||
size_t invariant_total_length;
|
||||
size_t reduce_total_length;
|
||||
|
||||
std::tie(invariant_total_length, reduce_total_length) =
|
||||
get_2d_lengths<Rank, ReduceDims>(inLengths);
|
||||
|
||||
int iterations = 1;
|
||||
while(true)
|
||||
{
|
||||
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
// we want the blkGroupSize be not more than 128
|
||||
if(testBlkGroupSize <= 128)
|
||||
break;
|
||||
|
||||
iterations++;
|
||||
};
|
||||
|
||||
int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
size_t workspace_size = invariant_total_length * blkGroupSize;
|
||||
|
||||
size_t wsSizeInBytes =
|
||||
!NeedIndices ? workspace_size * sizeof(AccDataType)
|
||||
: workspace_size * (sizeof(AccDataType) + sizeof(int)) + 64 + sizeof(int);
|
||||
|
||||
return (wsSizeInBytes);
|
||||
};
|
||||
|
||||
bool HasFurtherCall() override { return (true); };
|
||||
|
||||
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides,
|
||||
int blkGroupSize,
|
||||
int kBlockTileIterations)
|
||||
{
|
||||
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
|
||||
|
||||
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
|
||||
const auto in_grid_desc_m_k = [&]() {
|
||||
if constexpr(reduceAllDims)
|
||||
{
|
||||
const auto one_dim_inDesc = transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(tupleSrcLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return transform_tensor_descriptor(one_dim_inDesc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
1, one_dim_inDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto toReduceDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
|
||||
const auto invariantDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(toReduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
|
||||
const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
|
||||
const auto inPad_K = reduceSizePerBlock * blkGroupSize - innerLen;
|
||||
|
||||
auto in_grid_desc_m_k_padded =
|
||||
transform_tensor_descriptor(in_grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(outerLen, inPad_M),
|
||||
make_right_pad_transform(innerLen, inPad_K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (in_grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeWorkspace2dDescriptor(int outerLen, int blkGroupSize)
|
||||
{
|
||||
auto ws_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(outerLen, blkGroupSize));
|
||||
|
||||
const auto wsPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
|
||||
|
||||
auto ws_desc_m_k_padded =
|
||||
transform_tensor_descriptor(ws_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(outerLen, wsPad),
|
||||
make_pass_through_transform(blkGroupSize)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (ws_desc_m_k_padded);
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::vector<index_t>& inLengths,
|
||||
const std::vector<index_t>& inStrides,
|
||||
const std::vector<index_t>& outLengths,
|
||||
const std::vector<index_t>& outStrides,
|
||||
float alpha,
|
||||
float beta,
|
||||
const InDataType* in_dev,
|
||||
OutDataType* out_dev,
|
||||
IndexDataType* out_indices_dev,
|
||||
AccDataType* workspace_dev,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op)
|
||||
: in_dev_{in_dev},
|
||||
out_dev_{out_dev},
|
||||
out_indices_dev_{out_indices_dev},
|
||||
workspace_dev_{workspace_dev}
|
||||
{
|
||||
inLengths_ = inLengths;
|
||||
inStrides_ = inStrides;
|
||||
outLengths_ = outLengths;
|
||||
outStrides_ = outStrides;
|
||||
|
||||
in_elementwise_op_ = in_elementwise_op;
|
||||
acc_elementwise_op_ = acc_elementwise_op;
|
||||
|
||||
alpha_ = static_cast<AccDataType>(alpha);
|
||||
beta_ = static_cast<OutDataType>(beta);
|
||||
|
||||
std::tie(invariant_total_length, reduce_total_length) =
|
||||
get_2d_lengths<Rank, ReduceDims>(inLengths);
|
||||
|
||||
if constexpr(InvariantDims::Size() == 0)
|
||||
invariant_lowest_length = 1;
|
||||
else
|
||||
invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
|
||||
|
||||
reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
|
||||
|
||||
int iterations = 1;
|
||||
while(true)
|
||||
{
|
||||
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
// we want the blkGroupSize be not more than 128
|
||||
if(testBlkGroupSize <= 128)
|
||||
break;
|
||||
|
||||
iterations++;
|
||||
};
|
||||
|
||||
blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
kBlockTileIterations = iterations;
|
||||
|
||||
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
|
||||
M_BlockTileSize * blkGroupSize;
|
||||
|
||||
size_t ws_buf2_bytes_offset = math::integer_least_multiple(
|
||||
invariant_total_length * blkGroupSize * sizeof(AccDataType), 64);
|
||||
|
||||
if constexpr(NeedIndices)
|
||||
workspace_indices_dev_ = reinterpret_cast<int*>(
|
||||
reinterpret_cast<char*>(workspace_dev_) + ws_buf2_bytes_offset);
|
||||
else
|
||||
workspace_indices_dev_ = nullptr;
|
||||
}
|
||||
|
||||
std::vector<int> inLengths_;
|
||||
std::vector<int> inStrides_;
|
||||
std::vector<int> outLengths_;
|
||||
std::vector<int> outStrides_;
|
||||
|
||||
AccDataType alpha_;
|
||||
OutDataType beta_;
|
||||
|
||||
const InDataType* in_dev_;
|
||||
OutDataType* out_dev_;
|
||||
IndexDataType* out_indices_dev_;
|
||||
AccDataType* workspace_dev_;
|
||||
IndexDataType* workspace_indices_dev_;
|
||||
|
||||
InElementwiseOperation in_elementwise_op_;
|
||||
AccElementwiseOperation acc_elementwise_op_;
|
||||
|
||||
int invariant_lowest_length;
|
||||
int reduce_lowest_length;
|
||||
size_t invariant_total_length;
|
||||
size_t reduce_total_length;
|
||||
|
||||
index_t blkGroupSize;
|
||||
index_t kBlockTileIterations;
|
||||
size_t gridSize;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor(
|
||||
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
|
||||
const auto ws_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeWorkspace2dDescriptor(
|
||||
arg.invariant_total_length, arg.blkGroupSize);
|
||||
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
|
||||
using WorkspaceDesc_M_K = decltype(ws_desc_m_k);
|
||||
|
||||
using GridwiseReduce =
|
||||
GridwiseReduction_mk_to_mk_multiblock_partial_reduce<InDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
WorkspaceDesc_M_K,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize>;
|
||||
|
||||
float avg_time = 0;
|
||||
|
||||
const auto kernel = kernel_partial_reduce_multiblock<GridwiseReduce,
|
||||
NeedIndices,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InGridDesc_M_K,
|
||||
WorkspaceDesc_M_K,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
|
||||
avg_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
in_grid_desc_m_k,
|
||||
ws_desc_m_k,
|
||||
arg.in_elementwise_op_,
|
||||
arg.acc_elementwise_op_,
|
||||
arg.blkGroupSize,
|
||||
arg.kBlockTileIterations,
|
||||
arg.in_dev_,
|
||||
arg.workspace_dev_,
|
||||
arg.workspace_indices_dev_);
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if constexpr(OutDstVectorSize != 1)
|
||||
return (false);
|
||||
|
||||
if constexpr(InSrcVectorDim == 0)
|
||||
{
|
||||
if constexpr(InvariantDims::Size() == 0)
|
||||
return (false);
|
||||
|
||||
if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
|
||||
return (false);
|
||||
|
||||
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
|
||||
return (false);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
|
||||
return (false);
|
||||
|
||||
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
|
||||
return (false);
|
||||
};
|
||||
|
||||
// cases with small reduce_total_length should be handled by the BlockWise method
|
||||
if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize)
|
||||
return (false);
|
||||
|
||||
return (true);
|
||||
};
|
||||
|
||||
std::vector<int> GetWorkspace2dLengths(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
return (
|
||||
std::vector<int>{static_cast<int>(pArg->invariant_total_length), pArg->blkGroupSize});
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides,
|
||||
const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides,
|
||||
float alpha,
|
||||
float beta,
|
||||
const void* in_dev,
|
||||
void* out_dev,
|
||||
void* out_indices_dev,
|
||||
void* workspace_dev,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const AccElementwiseOperation& acc_elementwise_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(inLengths,
|
||||
inStrides,
|
||||
outLengths,
|
||||
outStrides,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
static_cast<OutDataType*>(out_dev),
|
||||
static_cast<IndexDataType*>(out_indices_dev),
|
||||
static_cast<AccDataType*>(workspace_dev),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceReduceMultiBlockPartialReduce<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
355
device_operation/include/device_reduce_threadwise.hpp
Normal file
355
device_operation/include/device_reduce_threadwise.hpp
Normal file
@@ -0,0 +1,355 @@
|
||||
#ifndef DEVICE_REDUCE_THREADWISE_HPP
|
||||
#define DEVICE_REDUCE_THREADWISE_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_reduce.hpp"
|
||||
#include "device_reduce_common.hpp"
|
||||
#include "gridwise_2d_reduction_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
index_t Rank,
|
||||
typename ReduceDims,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
bool NeedIndices,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize>
|
||||
struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutElementwiseOperation>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert((BlockSize == MThreadClusterSize) && (KThreadClusterSize == 1),
|
||||
"Threadwise can only be called with KThreadClusterSize be 1 !");
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
static constexpr bool BetaIsZero = NeedIndices;
|
||||
|
||||
using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
|
||||
|
||||
static constexpr index_t srcDims = Rank;
|
||||
static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
|
||||
static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
|
||||
|
||||
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
|
||||
const std::vector<int>& inStrides)
|
||||
{
|
||||
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
|
||||
|
||||
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
|
||||
const auto in_grid_desc_m_k = [&]() {
|
||||
if constexpr(reduceAllDims)
|
||||
{
|
||||
const auto one_dim_inDesc = transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(tupleSrcLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return transform_tensor_descriptor(one_dim_inDesc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
1, one_dim_inDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto toReduceDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
|
||||
const auto invariantDimLengths =
|
||||
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(toReduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
|
||||
const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen;
|
||||
|
||||
auto in_grid_desc_m_k_padded =
|
||||
transform_tensor_descriptor(in_grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(outerLen, inPad_M),
|
||||
make_right_pad_transform(innerLen, inPad_K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (in_grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
|
||||
const std::vector<int>& outStrides)
|
||||
{
|
||||
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
|
||||
|
||||
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
auto out_grid_desc_m = transform_tensor_descriptor(
|
||||
outDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
|
||||
|
||||
const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
|
||||
|
||||
auto out_grid_desc_m_padded =
|
||||
transform_tensor_descriptor(out_grid_desc_m,
|
||||
make_tuple(make_right_pad_transform(outerLen, outPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return (out_grid_desc_m_padded);
|
||||
};
|
||||
|
||||
    struct Argument : public BaseArgument
    {
        Argument(const std::vector<int>& inLengths,
                 const std::vector<int>& inStrides,
                 const std::vector<int>& outLengths,
                 const std::vector<int>& outStrides,
                 float alpha,
                 float beta,
                 const InDataType* in_dev,
                 OutDataType* out_dev,
                 IndexDataType* out_indices_dev,
                 AccDataType* workspace_dev,
                 const InElementwiseOperation& in_elementwise_op,
                 const OutElementwiseOperation& acc_elementwise_op)
            : in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
        {
            (void)workspace_dev;

            inLengths_  = inLengths;
            inStrides_  = inStrides;
            outLengths_ = outLengths;
            outStrides_ = outStrides;

            in_elementwise_op_  = in_elementwise_op;
            acc_elementwise_op_ = acc_elementwise_op;

            alpha_ = static_cast<AccDataType>(alpha);
            beta_  = static_cast<OutDataType>(beta);

            std::tie(invariant_total_length, reduce_total_length) =
                get_2d_lengths<Rank, ReduceDims>(inLengths);

            if constexpr(InvariantDims::Size() == 0)
                invariant_lowest_length = 1;
            else
                invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];

            reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];

            gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                       M_BlockTileSize;
        }

        std::vector<int> inLengths_;
        std::vector<int> inStrides_;
        std::vector<int> outLengths_;
        std::vector<int> outStrides_;

        AccDataType alpha_;
        OutDataType beta_;

        const InDataType* in_dev_;
        OutDataType* out_dev_;
        IndexDataType* out_indices_dev_;

        InElementwiseOperation in_elementwise_op_;
        OutElementwiseOperation acc_elementwise_op_;

        int invariant_lowest_length;
        int reduce_lowest_length;
        size_t invariant_total_length;
        size_t reduce_total_length;

        size_t gridSize;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, int nrepeat = 1)
        {
            const auto in_grid_desc_m_k =
                DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
            const auto out_grid_desc_m =
                DeviceReduceThreadWise::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
            using InGridDesc_M_K = decltype(in_grid_desc_m_k);
            using OutGridDesc_M  = decltype(out_grid_desc_m);

            using GridwiseReduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
                                                                        OutDataType,
                                                                        AccDataType,
                                                                        IndexDataType,
                                                                        InGridDesc_M_K,
                                                                        OutGridDesc_M,
                                                                        ReduceOperation,
                                                                        InElementwiseOperation,
                                                                        OutElementwiseOperation,
                                                                        PropagateNan,
                                                                        BetaIsZero,
                                                                        BlockSize,
                                                                        MThreadClusterSize,
                                                                        KThreadClusterSize,
                                                                        MThreadSliceSize,
                                                                        KThreadSliceSize,
                                                                        InSrcVectorDim,
                                                                        InSrcVectorSize,
                                                                        OutDstVectorSize>;

            float avg_time = 0;

            const auto kernel = kernel_reduce_threadwise<GridwiseReduce,
                                                         NeedIndices,
                                                         InDataType,
                                                         OutDataType,
                                                         AccDataType,
                                                         IndexDataType,
                                                         InGridDesc_M_K,
                                                         OutGridDesc_M,
                                                         InElementwiseOperation,
                                                         OutElementwiseOperation>;

            avg_time = launch_and_time_kernel(kernel,
                                              nrepeat,
                                              dim3(arg.gridSize),
                                              dim3(BlockSize),
                                              0,
                                              in_grid_desc_m_k,
                                              out_grid_desc_m,
                                              arg.in_elementwise_op_,
                                              arg.acc_elementwise_op_,
                                              arg.alpha_,
                                              arg.in_dev_,
                                              arg.beta_,
                                              arg.out_dev_,
                                              arg.out_indices_dev_);

            return (avg_time);
        };

        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
        };
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if constexpr(InSrcVectorDim == 0)
        {
            if constexpr(InvariantDims::Size() == 0)
                return (false);

            if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
                return (false);

            if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
                return (false);
        }
        else
        {
            if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
                return (false);

            if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
                return (false);
        };

        // To improve
        if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
            return (false);

        // TODO: remove this check and simply return true, since this DeviceOp instance does
        // support such cases; for larger reduce_total_length the BlockWise method is expected
        // to give better performance anyway.
        if(pArg->reduce_total_length / KThreadSliceSize >= 32)
            return (false);

        return (true);
    };

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::vector<int>& inLengths,
                        const std::vector<int>& inStrides,
                        const std::vector<int>& outLengths,
                        const std::vector<int>& outStrides,
                        float alpha,
                        float beta,
                        const void* in_dev,
                        void* out_dev,
                        void* out_indices_dev,
                        void* workspace_dev,
                        const InElementwiseOperation& in_elementwise_op,
                        const OutElementwiseOperation& acc_elementwise_op) override
    {
        return std::make_unique<Argument>(inLengths,
                                          inStrides,
                                          outLengths,
                                          outStrides,
                                          alpha,
                                          beta,
                                          static_cast<const InDataType*>(in_dev),
                                          static_cast<OutDataType*>(out_dev),
                                          static_cast<IndexDataType*>(out_indices_dev),
                                          static_cast<AccDataType*>(workspace_dev),
                                          in_elementwise_op,
                                          acc_elementwise_op);
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceReduceThreadWise<" << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
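
Typical host-side use of the device op above follows the MakeArgumentPointer / IsSupportedArgument / MakeInvokerPointer flow. A minimal sketch follows; the lengths, strides, buffer pointers and element-wise operators are illustrative assumptions rather than values from this change, and `reduce_ptr` is assumed to point at an already constructed DeviceReduceThreadWise specialization (template arguments omitted here):

    // Host-side usage sketch (illustrative only; not part of the diff).
    std::vector<int> inLengths{64, 320, 80, 4}; // hypothetical rank-4 input
    std::vector<int> inStrides{102400, 320, 4, 1};
    std::vector<int> outLengths{4};             // dims {0, 1, 2} reduced away (Rank = 4, ReduceDims = {0, 1, 2})
    std::vector<int> outStrides{1};

    auto argument_ptr = reduce_ptr->MakeArgumentPointer(inLengths,
                                                        inStrides,
                                                        outLengths,
                                                        outStrides,
                                                        1.0f,    // alpha
                                                        0.0f,    // beta
                                                        in_dev,  // const void*, device input buffer
                                                        out_dev, // void*, device output buffer
                                                        nullptr, // out_indices_dev, unused when indices are not requested
                                                        nullptr, // workspace_dev, ignored by the ThreadWise method
                                                        in_elementwise_op,
                                                        acc_elementwise_op);

    if(reduce_ptr->IsSupportedArgument(argument_ptr.get()))
    {
        auto invoker_ptr  = reduce_ptr->MakeInvokerPointer();
        float avg_time_ms = invoker_ptr->Run(argument_ptr.get(), /*nrepeat=*/10);
    }
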
device_operation/include/reduction_operator_mapping.hpp
@@ -0,0 +1,169 @@
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2020 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef CK_REDUCTION_OPERATOR_MAPPING_HPP
#define CK_REDUCTION_OPERATOR_MAPPING_HPP

#include "reduction_operator.hpp"
#include "reduction_enums.hpp"
#include "element_wise_operation.hpp"

namespace ck {

// The templated struct reduce_binary_operator maps the enum Ids of binary operators to their
// respective functor classes.
// The boolean member "indexable" is also provided in reduce_binary_operator for
// easier checking by the upper-layer codes in the kernels.
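// For example (illustrative only, not part of the original header), the host layer can
// select the functor type and check index support at compile time:
//
//   using min_op_t = typename reduce_binary_operator<float, ReduceTensorOp_t::MIN>::opType; // reduce::Min<float>
//   static_assert(reduce_binary_operator<float, ReduceTensorOp_t::MIN>::indexable,
//                 "MIN can also return the index of the selected element");
//   static_assert(!reduce_binary_operator<float, ReduceTensorOp_t::ADD>::indexable,
//                 "ADD has no meaningful per-element index");
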
template <typename T, ReduceTensorOp_t Op>
struct reduce_binary_operator;

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
{
    using opType   = reduce::Add<T>;
    using dataType = T;

    static constexpr bool indexable = false;
};

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
{
    using opType   = reduce::Mul<T>;
    using dataType = T;

    static constexpr bool indexable = false;
};

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
{
    using opType   = reduce::Min<T>;
    using dataType = T;

    static constexpr bool indexable = true;
};

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
{
    using opType   = reduce::Max<T>;
    using dataType = T;

    static constexpr bool indexable = true;
};

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
{
    using opType   = reduce::AMax<T>;
    using dataType = T;

    static constexpr bool indexable = true;
};

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
{
    using opType   = reduce::Add<T>;
    using dataType = T;

    static constexpr bool indexable = false;
};

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
{
    using opType   = reduce::Add<T>;
    using dataType = T;

    static constexpr bool indexable = false;
};

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
{
    using opType   = reduce::Add<T>;
    using dataType = T;

    static constexpr bool indexable = false;
};

// The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary
// functor classes.
// The two unary functors are called before and after the Reduction is executed, respectively.
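// For example (illustrative only, not part of the original header), a single-pass NORM2
// reduction (IsFirstReduce == IsLastReduce == true) squares every input element before the
// accumulation and applies sqrt to the accumulated value afterwards:
//
//   using norm2_ops = reduce_unary_operator<float, ReduceTensorOp_t::NORM2, true, true>;
//   // norm2_ops::InElementwiseOperation  is tensor_operation::element_wise::UnarySquare<float, float>
//   // norm2_ops::AccElementwiseOperation is tensor_operation::element_wise::UnarySqrt<float, float>
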
template <typename T, ReduceTensorOp_t Op, bool IsFirstReduce, bool IsLastReduce>
struct reduce_unary_operator
{
    using InElementwiseOperation  = tensor_operation::element_wise::UnaryIdentic<T, T>;
    using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
};

template <typename T, bool IsFirstReduce>
struct reduce_unary_operator<T, ReduceTensorOp_t::AVG, IsFirstReduce, true>
{
    using InElementwiseOperation  = tensor_operation::element_wise::UnaryIdentic<T, T>;
    using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T, true>;
};

template <typename T, bool IsLastReduce>
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM1, true, IsLastReduce>
{
    using InElementwiseOperation  = tensor_operation::element_wise::UnaryAbs<T, T>;
    using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
};

template <typename T, bool IsLastReduce>
struct reduce_unary_operator<T, ReduceTensorOp_t::AMAX, true, IsLastReduce>
{
    using InElementwiseOperation  = tensor_operation::element_wise::UnaryAbs<T, T>;
    using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
};

template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, false>
{
    using InElementwiseOperation  = tensor_operation::element_wise::UnarySquare<T, T>;
    using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
};

template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, true>
{
    using InElementwiseOperation  = tensor_operation::element_wise::UnarySquare<T, T>;
    using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
};

template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, false, true>
{
    using InElementwiseOperation  = tensor_operation::element_wise::UnaryIdentic<T, T>;
    using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
};

} // end of namespace ck

#endif
@@ -0,0 +1,34 @@
#include "device_reduce_instance_blockwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,25 @@
#include "device_reduce_instance_blockwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,43 @@
#include "device_reduce_instance_blockwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,25 @@
#include "device_reduce_instance_blockwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,43 @@
#include "device_reduce_instance_blockwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,34 @@
#include "device_reduce_instance_blockwise_second_call.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,25 @@
#include "device_reduce_instance_blockwise_second_call.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,43 @@
#include "device_reduce_instance_blockwise_second_call.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,25 @@
#include "device_reduce_instance_blockwise_second_call.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,43 @@
#include "device_reduce_instance_blockwise_second_call.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,22 @@
#include "device_reduce_instance_multiblock_atomic_add.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 0);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,22 @@
#include "device_reduce_instance_multiblock_atomic_add.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,22 @@
#include "device_reduce_instance_multiblock_atomic_add.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,34 @@
#include "device_reduce_instance_multiblock_partial_reduce.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,25 @@
#include "device_reduce_instance_multiblock_partial_reduce.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,38 @@
#include "device_reduce_instance_multiblock_partial_reduce.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); //

ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,19 @@
#include "device_reduce_instance_multiblock_partial_reduce.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,46 @@
#include "device_reduce_instance_multiblock_partial_reduce.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); //

ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); //

// Will be moved to use MultiBlockAtomicAdd
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,34 @@
#include "device_reduce_instance_threadwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,25 @@
#include "device_reduce_instance_threadwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,43 @@
#include "device_reduce_instance_threadwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 0);
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,25 @@
#include "device_reduce_instance_threadwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 0);
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
@@ -0,0 +1,43 @@
#include "device_reduce_instance_threadwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 0);
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation

} // namespace ck
Some files were not shown because too many files have changed in this diff.