From dfb80c4e39ec7b304c3ebc88bab2a204bc4906b9 Mon Sep 17 00:00:00 2001
From: Qianfeng
Date: Wed, 29 Sep 2021 23:12:11 +0800
Subject: [PATCH] [Enhancements] Several bugfixes and refactoring of dynamic generic reduction (#1156)

* Squashed 'src/composable_kernel/' content from commit f6edda611

git-subtree-dir: src/composable_kernel
git-subtree-split: f6edda6119ebbb237dfa6270797b34f960d7b190

* add solver ConvIgemmFwdV6r1DlopsNchwKcyxNkhw; rename static ck source files

* Squashed 'src/composable_kernel/' changes from f6edda611..5781adf5c

5781adf5c Update develop (#5) (#6)
97e6d514f Merge pull request #4 from ROCmSoftwarePlatform/separate_online_compile
7b1ec41e5 refactor
49c33aaea refactor
54b3e73d1 rename

git-subtree-dir: src/composable_kernel
git-subtree-split: 5781adf5cf4ac753e2e36da7385791775b744bf7

* fix
* refactor
* remove online compilation from CK
* refactor
* fix
* add ctest
* tidy
* add tidy
* tidy
* tidy
* tidy
* tidy
* tidy
* tidy
* tidy
* tidy
* tidy
* add c-style pointer cast
* vector/scalar pointer cast use c-style pointer cast instead of reinterpret_cast
* fix clang warning suppression
* tidy
* suppress cppcheck
* fix enum issue
* revert changes to hip build
* fix kernel filename
* update CK build script
* rename
* rename
* make inner product compatible on gfx900
* Update src/include/miopen/solver/ck_utility_common.hpp

Co-authored-by: JD

* compiler parameter use stream
* use int instead of index_t in kernel wrapper
* DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element
* refactor
* refactor
* change cmakelist
* change ck common utility
* fix

* Squashed 'src/composable_kernel/' changes from 5781adf5c..31b403526

31b403526 Merge pull request #16 from ROCmSoftwarePlatform/develop
b62bf8c3f Merge pull request #14 from ROCmSoftwarePlatform/miopen_downstream_init_integration
ccc4a1d36 Merge pull request #8 from ROCmSoftwarePlatform/miopen_downstream_init_integration
67ad47e7c refactor
16effa767 refactor
a91b68dfc DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element
2cbabbba5 use int instead of index_t in kernel wrapper
0834bc763 compiler parameter use stream
f2ac7832c make innner product compatiable on gfx900
4e57b30a6 rename
c03045ce2 rename
b2589957f update CK build script
2c48039d0 fix kernel filename
d626dccc9 fix enum issue
643ebd4f3 tidy
ddd49ec9e fix clang warning suppression
4f566c622 vector/scalar pointer cast use c-style pointer cast instead of reinterpret_cast
172036d72 add c-style pointer cast
76f313193 tidy
d18428901 tidy
f885c131d tidy
80120f0a0 tidy
c3efeb5e2 tidy
56fc0842b tidy
54fba515b tidy
e62bae7a4 tidy
24c872894 add tidy
61487e0a0 fix
ae98b52ad remove online compilation from CK
cb9542131 refactor
73ca97015 Merge commit '437cc595c6e206dfebb118985b5171bbc1e29eab' into composable_kernel_init_integration_v3
3b8664611 Merge pull request #7 from ROCmSoftwarePlatform/master
d09ea4f4e Update develop (#5)
3d32ae940 add solver ConvIgemmFwdV6r1DlopsNchwKcyxNkhw; rename static ck source files

git-subtree-dir: src/composable_kernel
git-subtree-split: 31b403526ec54abf13c4bb58dfb6635b4d2aa619

* Tiny fix in using data type template parameters in blockwise and direct_threadwise kernel
* Fix with regard to implementing GetZeroVal() in both kernel and host
* Avoid converting to compType from dstDataType before writing the output value
* Add half_t support to NumericLimits and make constexpr GetZeroVal() of binary operator
* Add CONSTANT decorator for descriptor read buffer
* Use get_thread_local_1d_id() for
thread local Id * Rename GetZeroVal() to GetReductionZeroVal() in the kernels * Remove constexpr from initialized zeroVal and tiny fix in reduction_operator.hpp * Occasional tiny simplification and update in the kernel files * Update in src/reducetensor.cpp for consistent IDs passing to the kernel * Update to re-order tensor dimensions on the host, split second_call kernel wrapper files and simplify reduce_all kernel wrappers * Update to remove OpenCL tidy checking failures * Small updates in src/reducetensor.cpp * Update for better readability * Remove unused codes and not-needed template parameters in the kernel wrappers Co-authored-by: Chao Liu Co-authored-by: JD --- ...ridwise_generic_2d_reduction_blockwise.hpp | 38 ++- ...generic_2d_reduction_direct_threadwise.hpp | 40 ++-- ...e_generic_2d_reduction_direct_warpwise.hpp | 36 ++- ...idwise_generic_2d_reduction_multiblock.hpp | 4 +- .../reduction_functions_blockwise.hpp | 4 +- .../reduction_functions_warpwise.hpp | 12 +- .../include/utility/data_type.hpp | 27 ++- .../include/utility/reduction_common.hpp | 59 +---- .../include/utility/reduction_enums.hpp | 66 ++++++ .../include/utility/reduction_operator.hpp | 65 +++-- ...n_first_call_blockwise_reduce_all_dims.cpp | 88 ++----- ...rst_call_blockwise_reduce_partial_dims.cpp | 39 +-- ..._first_call_multiblock_reduce_all_dims.cpp | 89 ++----- ...st_call_multiblock_reduce_partial_dims.cpp | 41 ++-- ..._first_call_threadwise_reduce_all_dims.cpp | 90 ++----- ...st_call_threadwise_reduce_partial_dims.cpp | 41 ++-- ...on_first_call_warpwise_reduce_all_dims.cpp | 91 ++----- ...irst_call_warpwise_reduce_partial_dims.cpp | 41 ++-- ..._second_call_blockwise_reduce_all_dims.cpp | 205 ++++++++++++++++ ...nd_call_blockwise_reduce_partial_dims.cpp} | 43 +--- ...second_call_threadwise_reduce_all_dims.cpp | 222 ++++++++++++++++++ ...d_call_threadwise_reduce_partial_dims.cpp} | 45 +--- ...n_second_call_warpwise_reduce_all_dims.cpp | 221 +++++++++++++++++ ...ond_call_warpwise_reduce_partial_dims.cpp} | 45 +--- 24 files changed, 1031 insertions(+), 621 deletions(-) create mode 100644 composable_kernel/include/utility/reduction_enums.hpp create mode 100644 composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_all_dims.cpp rename composable_kernel/src/kernel_wrapper/{gridwise_generic_reduction_second_call_blockwise.cpp => gridwise_generic_reduction_second_call_blockwise_reduce_partial_dims.cpp} (87%) create mode 100644 composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp rename composable_kernel/src/kernel_wrapper/{gridwise_generic_reduction_second_call_threadwise.cpp => gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp} (87%) create mode 100644 composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp rename composable_kernel/src/kernel_wrapper/{gridwise_generic_reduction_second_call_warpwise.cpp => gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp} (87%) diff --git a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp index 20075526b2..c635da57f4 100644 --- a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp @@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_blockwise // LDS 
__shared__ compType p_in_block_buffer[BlockBufferSize]; - auto zeroVal = opReduce::GetZeroVal(); + const auto zeroVal = opReduce::GetReductionZeroVal(); const auto src_global_buf = make_dynamic_buffer( p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); @@ -180,6 +180,10 @@ struct GridwiseReduction_xy_to_x_blockwise if(!float_equal_one{}(alpha)) accuValue_buf(I0) *= type_convert{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = @@ -200,11 +204,11 @@ struct GridwiseReduction_xy_to_x_blockwise threadwise_dst_load.Run( dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); - accuValue_buf(I0) += type_convert{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3( p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); @@ -281,7 +285,7 @@ struct GridwiseReduction_xy_to_x_blockwise ThreadClusterLengths, Sequence<0, 1>, srcDataType, - dstDataType, + compType, src2dDescType, decltype(in_block_desc), Sequence<0, 1>, @@ -345,6 +349,10 @@ struct GridwiseReduction_xy_to_x_blockwise if(!float_equal_one{}(alpha)) accuValue_buf(I0) *= type_convert{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = @@ -368,11 +376,11 @@ struct GridwiseReduction_xy_to_x_blockwise make_tuple(I0), priorDstValue_buf); - accuValue_buf(I0) += type_convert{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3(ws_values_global, @@ -547,6 +555,10 @@ struct GridwiseReduction_xy_to_x_blockwise if(!float_equal_one{}(alpha)) accuValue_buf(I0) *= type_convert{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = @@ -570,11 +582,11 @@ struct GridwiseReduction_xy_to_x_blockwise make_tuple(I0), priorDstValue_buf); - accuValue_buf(I0) += type_convert{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3( p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); @@ -147,6 +147,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise if(!float_equal_one{}(alpha)) accuValue_buf(I0) *= type_convert{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3 @@ -200,7 +204,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise { (void)ws_indices_global; - const auto zeroVal = opReduce::GetZeroVal(); + const auto zeroVal = opReduce::GetReductionZeroVal(); const auto src_global_buf = make_dynamic_buffer( p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); @@ -232,7 +236,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = 
type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3(ws_values_global, @@ -377,7 +385,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3( p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); @@ -156,6 +156,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise if(!float_equal_one{}(alpha)) accuValue_buf(I0) *= type_convert{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = @@ -176,11 +180,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise threadwise_dst_load.Run( dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); - accuValue_buf(I0) += type_convert{}(priorDstValue_buf(I0) * beta); + dstValue_buf(I0) += priorDstValue_buf(I0) * beta; } auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3( p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); @@ -291,6 +295,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise if(!float_equal_one{}(alpha)) accuValue_buf(I0) *= type_convert{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = @@ -314,11 +322,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise make_tuple(I0), priorDstValue_buf); - accuValue_buf(I0) += type_convert{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3(ws_values_global, @@ -466,6 +474,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise if(!float_equal_one{}(alpha)) accuValue_buf(I0) *= type_convert{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = @@ -489,11 +501,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise make_tuple(I0), priorDstValue_buf); - accuValue_buf(I0) += type_convert{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3{}( [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); }); @@ -84,7 +84,7 @@ struct WarpReduce // since for fp16, built-in shuffling functions is not provided by HIP __device__ static void ReduceImpl2(const BufferType& thread_buffer, compType& accuData) { - compType lAccuData = opReduce::GetZeroVal(); + compType lAccuData = opReduce::GetReductionZeroVal(); static_for<0, ThreadBufferLen, 1>{}( [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); }); @@ -138,7 +138,7 @@ struct WarpReduce int& accuIndex, int indexStart) { - compType lAccuData = opReduce::GetZeroVal(); + compType lAccuData = opReduce::GetReductionZeroVal(); int lAccuIndex = 0; index_t 
thread_inwarp_id = get_thread_local_1d_id() % warpSize; @@ -170,7 +170,7 @@ struct WarpReduce int& accuIndex, int indexStart) { - compType lAccuData = opReduce::GetZeroVal(); + compType lAccuData = opReduce::GetReductionZeroVal(); int lAccuIndex = 0; index_t thread_id = get_thread_local_1d_id(); index_t warpId = thread_id / warpSize; @@ -278,7 +278,7 @@ struct WarpReduceWithIndicesInput compType& accuData, int& accuIndex) { - compType lAccuData = opReduce::GetZeroVal(); + compType lAccuData = opReduce::GetReductionZeroVal(); int lAccuIndex = 0; static_for<0, ThreadBufferLen, 1>{}([&](auto I) { @@ -307,7 +307,7 @@ struct WarpReduceWithIndicesInput compType& accuData, int& accuIndex) { - compType lAccuData = opReduce::GetZeroVal(); + compType lAccuData = opReduce::GetReductionZeroVal(); int lAccuIndex = 0; index_t thread_id = get_thread_local_1d_id(); index_t warpId = thread_id / warpSize; diff --git a/composable_kernel/include/utility/data_type.hpp b/composable_kernel/include/utility/data_type.hpp index bfaac8a939..07eceb84cf 100644 --- a/composable_kernel/include/utility/data_type.hpp +++ b/composable_kernel/include/utility/data_type.hpp @@ -1008,20 +1008,27 @@ struct inner_product_with_conversion }; template -struct NumericLimits; +struct NumericLimits +{ + __host__ __device__ static constexpr T Min() { return std::numeric_limits::min(); } + + __host__ __device__ static constexpr T Max() { return std::numeric_limits::max(); } + + __host__ __device__ static constexpr T Lowest() { return std::numeric_limits::lowest(); } +}; template <> -struct NumericLimits +struct NumericLimits { - __host__ __device__ static constexpr int32_t Min() - { - return std::numeric_limits::min(); - } + static constexpr unsigned short binary_min = 0x0400; + static constexpr unsigned short binary_max = 0x7BFF; + static constexpr unsigned short binary_lowest = 0xFBFF; - __host__ __device__ static constexpr int32_t Max() - { - return std::numeric_limits::max(); - } + __host__ __device__ static constexpr half_t Min() { return as_type(binary_min); } + + __host__ __device__ static constexpr half_t Max() { return as_type(binary_max); } + + __host__ __device__ static constexpr half_t Lowest() { return as_type(binary_lowest); } }; } // namespace ck diff --git a/composable_kernel/include/utility/reduction_common.hpp b/composable_kernel/include/utility/reduction_common.hpp index 139a18c2a4..ff574c315c 100644 --- a/composable_kernel/include/utility/reduction_common.hpp +++ b/composable_kernel/include/utility/reduction_common.hpp @@ -26,76 +26,25 @@ #ifndef CK_REDUCTION_COMMON_HPP #define CK_REDUCTION_COMMON_HPP -// this enumerate should be synchronized with include/miopen/reduce_common.hpp +#include "reduction_enums.hpp" + namespace ck { -enum class ReductionMethod_t -{ - DirectThreadWise = 1, - DirectWarpWise = 2, - BlockWise = 3, - MultiBlock = 4 -}; // end of namespace ck - -enum class ReduceTensorOp_t -{ - ADD = 0, - MUL = 1, - MIN = 2, - MAX = 3, - AMAX = 4, - AVG = 5, - NORM1 = 6, - NORM2 = 7, - // MUL_NO_ZEROS = 8, -}; - -enum class NanPropagation_t -{ - NOT_PROPAGATE_NAN = 0, - PROPAGATE_NAN = 1, -}; - -enum class ReduceTensorIndices_t -{ - NO_INDICES = 0, - FLATTENED_INDICES = 1, -}; - -enum class IndicesType_t -{ - INDICES_32BIT = 0, - INDICES_64BIT = 1, - INDICES_16BIT = 2, - INDICES_8BIT = 3, -}; struct float_equal_one { - template - __device__ static inline bool apply(T x) - { - return x <= type_convert{}(1.0f) and x >= type_convert{}(1.0f); - } - template __device__ inline bool operator()(T x) { - return 
(float_equal_one::apply(x)); + return x <= static_cast(1.0f) and x >= static_cast(1.0f); }; }; struct float_equal_zero { - template - __device__ static inline bool apply(T x) - { - return x <= type_convert{}(0.0f) and x >= type_convert{}(0.0f); - } - template __device__ inline bool operator()(T x) { - return (float_equal_zero::apply(x)); + return x <= static_cast(0.0f) and x >= static_cast(0.0f); }; }; diff --git a/composable_kernel/include/utility/reduction_enums.hpp b/composable_kernel/include/utility/reduction_enums.hpp new file mode 100644 index 0000000000..e97108179e --- /dev/null +++ b/composable_kernel/include/utility/reduction_enums.hpp @@ -0,0 +1,66 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_REDUCTION_ENUMS_HPP +#define CK_REDUCTION_ENUMS_HPP + +namespace ck { + +enum class ReduceTensorOp_t +{ + ADD = 0, + MUL = 1, + MIN = 2, + MAX = 3, + AMAX = 4, + AVG = 5, + NORM1 = 6, + NORM2 = 7, + // MUL_NO_ZEROS = 8, +}; + +enum class NanPropagation_t +{ + NOT_PROPAGATE_NAN = 0, + PROPAGATE_NAN = 1, +}; + +enum class ReduceTensorIndices_t +{ + NO_INDICES = 0, + FLATTENED_INDICES = 1, +}; + +enum class IndicesType_t +{ + INDICES_32BIT = 0, + INDICES_64BIT = 1, + INDICES_16BIT = 2, + INDICES_8BIT = 3, +}; + +}; // end of namespace ck + +#endif diff --git a/composable_kernel/include/utility/reduction_operator.hpp b/composable_kernel/include/utility/reduction_operator.hpp index 269671a400..c0afbec869 100644 --- a/composable_kernel/include/utility/reduction_operator.hpp +++ b/composable_kernel/include/utility/reduction_operator.hpp @@ -35,10 +35,12 @@ namespace reduce { // Every binary operator used in reduction is represented by a templated functor class. Each functor // class must provide at least // three members: -// 1) GetZeroVal() -- the interface to return the "identity element" for the binary operator, -// "identity element" is the unique +// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary +// operator, "identity element" is the unique // element in the algebraic space that doesn't affect the value of other elements -// when operated with any of them. 
+// when operated against them, and the concept is similar to zero vector in +// vector space +// (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf). // 2) indexable -- boolean value indicating whether indices of the operated elements could be // recorded. Usually, Min/Max operator could // need to record the indices of elements. For operator like Add/Mul, no need to @@ -58,7 +60,7 @@ struct Add { using dataType = T; - __device__ static T GetZeroVal() { return type_convert{}(0.0f); }; + __device__ static constexpr T GetReductionZeroVal() { return static_cast(0.0f); }; __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; } @@ -70,7 +72,7 @@ struct Mul { using dataType = T; - __device__ static T GetZeroVal() { return type_convert{}(1.0f); }; + __device__ static constexpr T GetReductionZeroVal() { return static_cast(1.0f); }; __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; } @@ -82,7 +84,7 @@ struct Max { using dataType = T; - __device__ static T GetZeroVal() { return std::numeric_limits::min(); }; + __device__ static constexpr T GetReductionZeroVal() { return NumericLimits::Lowest(); }; __device__ inline constexpr void operator()(T& a, T b) const { @@ -107,7 +109,7 @@ struct Min { using dataType = T; - __device__ static T GetZeroVal() { return std::numeric_limits::max(); }; + __device__ static constexpr T GetReductionZeroVal() { return NumericLimits::Max(); }; __device__ inline constexpr void operator()(T& a, T b) const { @@ -127,16 +129,29 @@ struct Min static constexpr bool indexable = true; }; -template <> -__device__ half_t Max::GetZeroVal() +template +struct AMax { - return type_convert{}(std::numeric_limits::min()); -}; + using dataType = T; -template <> -__device__ half_t Min::GetZeroVal() -{ - return type_convert{}(std::numeric_limits::max()); + __device__ static constexpr T GetReductionZeroVal() { return static_cast(0.0f); }; + + __device__ inline constexpr void operator()(T& a, T b) const + { + if(a < b) + a = b; + } + + __device__ inline constexpr void operator()(T& a, T b, bool& changed) const + { + if(a < b) + { + a = b; + changed = true; + } + } + + static constexpr bool indexable = true; }; // Unary operators are usually called element-wisely before the reduction is executed on the @@ -268,7 +283,7 @@ struct unary_sqrt // The templated struct reduce_binary_operator maps the enum Ids of binary operators to their // respective functor classes. -// The "GetZeroVal()" interface and boolean member "indexable" are also provided in +// The "GetReductionZeroVal()" interface and boolean member "indexable" are also provided in // reduce_binary_operactor for // easier checking by the upper-layer codes in the kernels. 
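To make the functor contract described in the comment above concrete, here is a minimal host-only sketch under stated assumptions: MaxOp and fold are illustrative names rather than CK code, the __device__ qualifiers of the real functors are dropped so the snippet compiles as plain C++, and the identity comes from std::numeric_limits. It also illustrates why the hunks here replace std::numeric_limits::min() with NumericLimits::Lowest() as the identity of Max: for floating-point types, min() is the smallest positive normal value rather than the most negative finite one, which would silently break the reduction for all-negative inputs.

    #include <iostream>
    #include <limits>
    #include <vector>

    // Illustrative stand-in for a CK reduction functor (hypothetical name).
    template <typename T>
    struct MaxOp
    {
        // Identity element: folding it into the accumulator changes nothing.
        static constexpr T GetReductionZeroVal() { return std::numeric_limits<T>::lowest(); }

        // Fold one element into the accumulator.
        void operator()(T& a, T b) const
        {
            if(a < b)
                a = b;
        }

        // Min/Max can meaningfully report which element won, so indices may be recorded.
        static constexpr bool indexable = true;
    };

    // How an upper layer uses the contract: seed with the identity, then fold.
    template <typename T, typename Op>
    T fold(const std::vector<T>& data, Op op)
    {
        T acc = Op::GetReductionZeroVal();
        for(T v : data)
            op(acc, v);
        return acc;
    }

    int main()
    {
        const std::vector<float> v{-3.0f, -7.5f, -1.2f};
        // Prints -1.2; seeding with std::numeric_limits<float>::min() (a small
        // positive value) instead of lowest() would wrongly print that seed.
        std::cout << fold(v, MaxOp<float>{}) << '\n';
    }
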
@@ -281,8 +296,6 @@ struct reduce_binary_operator using opType = reduce::Add; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Add::GetZeroVal(); }; - static constexpr bool indexable = reduce::Add::indexable; }; @@ -292,8 +305,6 @@ struct reduce_binary_operator using opType = reduce::Mul; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Mul::GetZeroVal(); }; - static constexpr bool indexable = reduce::Mul::indexable; }; @@ -303,8 +314,6 @@ struct reduce_binary_operator using opType = reduce::Min; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Min::GetZeroVal(); }; - static constexpr bool indexable = reduce::Min::indexable; }; @@ -314,19 +323,15 @@ struct reduce_binary_operator using opType = reduce::Max; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Max::GetZeroVal(); }; - static constexpr bool indexable = reduce::Max::indexable; }; template struct reduce_binary_operator { - using opType = reduce::Max; + using opType = reduce::AMax; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Max::GetZeroVal(); }; - static constexpr bool indexable = reduce::Max::indexable; }; @@ -336,8 +341,6 @@ struct reduce_binary_operator using opType = reduce::Add; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Add::GetZeroVal(); }; - static constexpr bool indexable = reduce::Add::indexable; }; @@ -347,8 +350,6 @@ struct reduce_binary_operator using opType = reduce::Add; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Add::GetZeroVal(); }; - static constexpr bool indexable = reduce::Add::indexable; }; @@ -358,8 +359,6 @@ struct reduce_binary_operator using opType = reduce::Add; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Add::GetZeroVal(); }; - static constexpr bool indexable = reduce::Add::indexable; }; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_all_dims.cpp index e16010dee1..ca6b415910 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_all_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_all_dims.cpp @@ -43,9 +43,6 @@ using compType = constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -using toReduceDims = Sequence; constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -58,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge, toReduceDims>::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); - constexpr bool indexable = reduce_binary_operator::indexable; constexpr 
bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -110,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, void* __restrict__ ws_global) { (void)GridSize; @@ -132,18 +107,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); + const auto tupleDstLengths = make_tuple(1); + const auto tupleDstStrides = make_tuple(1); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); const auto one_dim_srcDesc = transform_tensor_descriptor( srcDesc, @@ -157,14 +128,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); + constexpr int invariantLen = 1; + const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; @@ -179,30 +144,28 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } - if(hipThreadIdx_x == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dstDesc; }; -template +template struct get_ref_desc_types { static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; // don't have to use accurate strides to get an expected referrence type static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); + static constexpr auto ref_dstDesc = 
make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( ref_srcDesc, @@ -217,12 +180,6 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); @@ -235,25 +192,22 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}, Sequence<1>{}))); using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, + decltype(transform_tensor_descriptor(ref_dstDesc, make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}))); using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); + using refType_dst1dDesc = decltype(ref_dstDesc); }; -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; using refType_src2dDesc_padded_34 = - typename get_ref_desc_types::refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = - typename get_ref_desc_types::refType_dst1dDesc_padded; + typename get_ref_desc_types::refType_src2dDesc_padded_34; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; -template +template static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) { if constexpr(need_padding) @@ -277,15 +231,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)BlkGroupSize; (void)ws_buf2_bytes_offset; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_partial_dims.cpp index cba7ffe295..a3daeaf163 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_partial_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_partial_dims.cpp @@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS; -using toReduceDims = Sequence; -using invariantDims = Sequence; +constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; +constexpr index_t num_invariantDims = srcDims - num_toReduceDims; + +using 
invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; +using toReduceDims = typename arithmetic_sequence_gen::type; constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(invariantDims::Size() > 0 || dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); +static_assert(num_invariantDims > 0, "Not all dimensins are reduced for this kernel !!"); constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, int outStride0, int outStride1, int outStride2, @@ -133,14 +122,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; const int dstStrides[6] = { outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); + const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); @@ -179,16 +166,16 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc; }; @@ -278,15 +265,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)BlkGroupSize; (void)ws_buf2_bytes_offset; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; + const void* p_src2dDesc = 
cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_all_dims.cpp index 34b877027c..81899dfb02 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_all_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_all_dims.cpp @@ -43,10 +43,6 @@ using compType = constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -using toReduceDims = Sequence; -using invariantDims = Sequence; // this could be empty constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge, toReduceDims>::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); - constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -111,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, void* __restrict__ ws_global) { (void)GridSize; @@ -132,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); + const auto tupleDstLengths = make_tuple(1); + const auto tupleDstStrides = make_tuple(1); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + auto dstDesc = 
make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); const auto one_dim_srcDesc = transform_tensor_descriptor( srcDesc, @@ -157,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); + constexpr int invariantLen = 1; + const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; const index_t reduceSizePerBlock = @@ -181,30 +145,28 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } - if(hipThreadIdx_x == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dstDesc; }; -template +template struct get_ref_desc_types { static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; // don't have to use accurate strides to get an expected referrence type static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); + static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( ref_srcDesc, @@ -219,12 +181,6 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); @@ -237,23 +193,20 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}, Sequence<1>{}))); using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, + decltype(transform_tensor_descriptor(ref_dstDesc, make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}))); using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); + using refType_dst1dDesc = decltype(ref_dstDesc); }; -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; using refType_src2dDesc_padded_34 = - typename 
get_ref_desc_types::refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = - typename get_ref_desc_types::refType_dst1dDesc_padded; + typename get_ref_desc_types::refType_src2dDesc_padded_34; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; template static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) @@ -279,16 +232,16 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)p_dst_global; (void)indices_global; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; - void* ws_buf1_global = static_cast(ws_global) + 4096; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_partial_dims.cpp index 9c7318dc15..0e578f4d1d 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_partial_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_partial_dims.cpp @@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS; -using toReduceDims = Sequence; -using invariantDims = Sequence; +constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; +constexpr index_t num_invariantDims = srcDims - num_toReduceDims; + +using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; +using toReduceDims = typename arithmetic_sequence_gen::type; constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(invariantDims::Size() > 0 || dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); +static_assert(num_invariantDims > 0, "Not all dimensins are reduced for this kernel !!"); constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int 
outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, int outStride0, int outStride1, int outStride2, @@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; const int dstStrides[6] = { outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); + const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); @@ -180,16 +167,16 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc; }; @@ -279,16 +266,16 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)p_dst_global; (void)indices_global; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; - void* ws_buf1_global = static_cast(ws_global) + 4096; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_all_dims.cpp index 8e67d1faa1..e63a1254e4 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_all_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_all_dims.cpp @@ -43,9 +43,6 @@ using compType = constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -using toReduceDims = Sequence; constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -58,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = 
static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge, toReduceDims>::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); - constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -110,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, void* __restrict__ ws_global) { (void)BlkGroupSize; @@ -131,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); + const auto tupleDstLengths = make_tuple(1); + const auto tupleDstStrides = make_tuple(1); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); const auto one_dim_srcDesc = transform_tensor_descriptor( srcDesc, @@ -156,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); + constexpr int invariantLen = 1; + const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); constexpr auto copySliceLen = GredThreadBufferLength; @@ -178,12 +143,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad2)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } @@ -191,31 +156,29 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, { const auto dstPad = GridSize * BlockSize - invariantLen; auto 
dst1dDesc_2 = - transform_tensor_descriptor(dst1dDesc, + transform_tensor_descriptor(dstDesc, make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc_2; } else { - if(hipThreadIdx_x == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dstDesc; } }; -template +template struct get_ref_desc_types { static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; // don't have to use accurate strides to get an expected referrence type static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); + static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( ref_srcDesc, @@ -230,12 +193,6 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); @@ -248,23 +205,20 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}, Sequence<1>{}))); using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, + decltype(transform_tensor_descriptor(ref_dstDesc, make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}))); using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); + using refType_dst1dDesc = decltype(ref_dstDesc); }; -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; using refType_src2dDesc_padded_12 = - typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = - typename get_ref_desc_types::refType_dst1dDesc_padded; + typename get_ref_desc_types::refType_src2dDesc_padded_12; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; template static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) { if constexpr(need_padding) @@ -290,15 +244,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)BlkGroupSize; (void)ws_buf2_bytes_offset; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) +
2048; const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_partial_dims.cpp index fdbcda64ba..698f740058 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_partial_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_partial_dims.cpp @@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS; -using toReduceDims = Sequence; -using invariantDims = Sequence; +constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; +constexpr index_t num_invariantDims = srcDims - num_toReduceDims; + +using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; +using toReduceDims = typename arithmetic_sequence_gen::type; constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(invariantDims::Size() > 0 || dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); +static_assert(num_invariantDims > 0, "Not all dimensions are reduced for this kernel !!"); constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, int outStride0, int outStride1, int outStride2, @@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; const int dstStrides[6] = { outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); + const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); @@
-178,12 +165,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad2)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } @@ -195,12 +182,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc; } }; @@ -291,15 +278,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)BlkGroupSize; (void)ws_buf2_bytes_offset; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_all_dims.cpp index 8aa1376c3a..4a607372e9 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_all_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_all_dims.cpp @@ -43,9 +43,6 @@ using compType = constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -using toReduceDims = Sequence; constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -58,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge, toReduceDims>::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); - constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -110,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int 
outLength4, - int outLength5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, void* __restrict__ ws_global) { (void)BlkGroupSize; @@ -131,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); + const auto tupleDstLengths = make_tuple(1); + const auto tupleDstStrides = make_tuple(1); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); const auto one_dim_srcDesc = transform_tensor_descriptor( srcDesc, @@ -156,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); + constexpr int invariantLen = 1; + const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp; @@ -179,12 +144,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad2)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } @@ -192,31 +157,29 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, { const auto dstPad = GridSize * BlockSize / warpSize - invariantLen; auto dst1dDesc_2 = - transform_tensor_descriptor(dst1dDesc, + transform_tensor_descriptor(dstDesc, make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc_2; } else { - if(hipThreadIdx_x == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dstDesc; } }; -template +template struct get_ref_desc_types { static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; // don't have to use accurate strides to get an expected reference type static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
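// [editorial note, not part of the patch] What get_ref_desc_types is for: the _prepare
// kernel serializes fully-built descriptor objects into ws_global as raw bytes, and the
// reduction kernel must reinterpret those bytes as concrete descriptor *types*. These
// ref_ objects exist only so decltype can name the expected types, which is why dummy
// lengths (and lengths reused as strides) are sufficient here, as the comment above says.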
- static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); + static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( ref_srcDesc, @@ -231,12 +194,6 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); @@ -249,23 +206,19 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}, Sequence<1>{}))); using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, + decltype(transform_tensor_descriptor(ref_dstDesc, make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}))); using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); + using refType_dst1dDesc = decltype(ref_dstDesc); }; -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = - typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = - typename get_ref_desc_types::refType_dst1dDesc_padded; +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; template static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) @@ -291,15 +244,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)BlkGroupSize; (void)ws_buf2_bytes_offset; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_partial_dims.cpp index e18d623fe5..a641527900 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_partial_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_partial_dims.cpp @@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS; -using toReduceDims = Sequence; -using invariantDims = Sequence; +constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; +constexpr index_t num_invariantDims = srcDims - num_toReduceDims; + +using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; +using toReduceDims = typename arithmetic_sequence_gen::type;
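// [editorial note, not part of the patch] Both dimension sequences are now derived from a
// single CK_PARAM_NUM_TOREDUCE_DIMS compile-time constant: invariantDims covers
// [0, num_invariantDims) and toReduceDims covers [num_invariantDims, srcDims), which
// presumes the caller has already re-ordered the tensor dimensions so the invariant ones
// come first. The flattened text above has lost some template arguments; the toReduceDims
// generator is presumably arithmetic_sequence_gen<num_invariantDims, srcDims, 1>.
// Hypothetical example: srcDims == 4 with num_toReduceDims == 2 would give
// invariantDims = Sequence<0, 1> and toReduceDims = Sequence<2, 3>.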
constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(invariantDims::Size() > 0 || dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); +static_assert(num_invariantDims > 0, "Not all dimensions are reduced for this kernel !!"); constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, int outStride0, int outStride1, int outStride2, @@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; const int dstStrides[6] = { outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); + const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); @@ -179,12 +166,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad2)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } @@ -196,12 +183,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc;
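// [editorial note, not part of the patch] Every descriptor store in these _prepare kernels
// is guarded by get_thread_local_1d_id() == 0, the portable replacement this patch
// introduces for hipThreadIdx_x, so only thread 0 writes each descriptor blob into
// ws_global before the follow-up reduction kernel reads it back.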
} }; @@ -292,15 +279,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)BlkGroupSize; (void)ws_buf2_bytes_offset; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_all_dims.cpp new file mode 100644 index 0000000000..7e9d46612e --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_all_dims.cpp @@ -0,0 +1,205 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include "config.hpp" +#include "number.hpp" +#include "sequence.hpp" +#include "tensor_descriptor_helper.hpp" +#include "data_type_enum_helper.hpp" +#include "reduction_common.hpp" +#include "gridwise_generic_2d_reduction_blockwise.hpp" + +using namespace ck; + +using srcDataType = + typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; +using dstDataType = + typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; +using compType = + typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; + +constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable + +constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); +constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 + ? NanPropagation_t::NOT_PROPAGATE_NAN + : NanPropagation_t::PROPAGATE_NAN; +constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 + ? 
ReduceTensorIndices_t::NO_INDICES + : ReduceTensorIndices_t::FLATTENED_INDICES; + +constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); +constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); + +constexpr bool indexable = reduce_binary_operator::indexable; +constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); + +constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable + +extern "C" __global__ void +gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global) +{ + (void)GridSize; + + void* p_src2dDesc = ws_global; + void* p_dst1dDesc = static_cast(ws_global) + 2048; + + const auto tupleDstLengths = make_tuple(1); + const auto tupleDstStrides = make_tuple(1); + + auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + const index_t invariantLen = dstDesc.GetLength(Number<0>{}); + const index_t toReduceLen = BlkGroupSize; + + auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); + + constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; + + if constexpr(src2d_need_padding) + { + const auto srcPad = + ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; + + auto src2dDesc_2 = + transform_tensor_descriptor(src2dDesc, + make_tuple(make_pass_through_transform(invariantLen), + make_pad_transform(toReduceLen, 0, srcPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + if(get_thread_local_1d_id() == 0) + *static_cast(p_src2dDesc) = src2dDesc_2; + } + else + { + if(get_thread_local_1d_id() == 0) + *static_cast(p_src2dDesc) = src2dDesc; + } + + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dstDesc; +}; + +struct get_ref_desc_types +{ + static constexpr auto ref_tupleDstLengths = make_tuple(8); + static constexpr auto ref_dstDesc = + make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); + + static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{}); + static constexpr index_t ref_toReduceLen = 8; + + static constexpr auto ref_src2dDesc = + make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); + + using refType_src2dDesc = decltype(ref_src2dDesc); + using refType_dst1dDesc = decltype(ref_dstDesc); + + // used by the BlockWise and MultiBlock method + using refType_src2dDesc_padded_34 = decltype( + transform_tensor_descriptor(ref_src2dDesc, + make_tuple(make_pass_through_transform(ref_invariantLen), + make_pad_transform(ref_toReduceLen, 0, 2)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}))); + + using refType_dst1dDesc_padded = + decltype(transform_tensor_descriptor(ref_dstDesc, + make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{}))); +}; + +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc_padded_34 = typename get_ref_desc_types::refType_src2dDesc_padded_34; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; + +template +static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) +{ + if constexpr(need_padding) + return (*reinterpret_cast(p_src2dDesc)); + else + return (*reinterpret_cast(p_src2dDesc)); +}; + +template +static 
__device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) +{ + if constexpr(need_padding) + return (*reinterpret_cast(p_dst1dDesc)); + else + return (*reinterpret_cast(p_dst1dDesc)); +}; + +extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, + float alpha, + const void* __restrict__ p_src_global, + float beta, + void* __restrict__ p_dst_global, + const void CONSTANT* ws_global, + long ws_buf2_bytes_offset, + void* __restrict__ indices_global) +{ + (void)p_src_global; + + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); + + const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); + const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); + + using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise; + + void* const ws_buf2_global = + ws_buf2_bytes_offset > 0 + ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) + : nullptr; + + constexpr int RunId = need_indices ? 3 : 1; + gridwise_2d_reduce::template Run( + src2dDesc, + dst1dDesc, + origReduceLen, + alpha, + static_cast(ws_buf1_global), + beta, + static_cast(p_dst_global), + static_cast(ws_buf2_global), + static_cast(indices_global)); +}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_partial_dims.cpp similarity index 87% rename from composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise.cpp rename to composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_partial_dims.cpp index b7b58cbb90..3f37d01e21 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_partial_dims.cpp @@ -42,12 +42,8 @@ using compType = constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable -constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS; -using toReduceDims = Sequence; -using invariantDims = Sequence; // this could be empty - constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 ? 
NanPropagation_t::NOT_PROPAGATE_NAN @@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(invariantDims::Size() > 0 || dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); - constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -152,20 +138,20 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc; }; -template +template struct get_ref_desc_types { static constexpr auto ref_tupleDstLengths = @@ -203,16 +189,11 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}))); }; -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; using refType_src2dDesc_padded_34 = - typename get_ref_desc_types:: - refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = - typename get_ref_desc_types:: - refType_dst1dDesc_padded; + typename get_ref_desc_types::refType_src2dDesc_padded_34; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; template static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) @@ -237,15 +218,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)p_src_global; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; - void* ws_buf1_global = static_cast(ws_global) + 4096; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp new file mode 100644 index 0000000000..77841d1312 --- /dev/null +++ 
b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp @@ -0,0 +1,222 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include "config.hpp" +#include "number.hpp" +#include "sequence.hpp" +#include "tensor_descriptor_helper.hpp" +#include "data_type_enum_helper.hpp" +#include "reduction_common.hpp" +#include "gridwise_generic_2d_reduction_direct_threadwise.hpp" + +using namespace ck; + +using srcDataType = + typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; +using dstDataType = + typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; +using compType = + typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; + +constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable + +using toReduceDims = Sequence; +using invariantDims = Sequence; // this could be empty + +constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); +constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 + ? NanPropagation_t::NOT_PROPAGATE_NAN + : NanPropagation_t::PROPAGATE_NAN; +constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 + ? 
ReduceTensorIndices_t::NO_INDICES + : ReduceTensorIndices_t::FLATTENED_INDICES; + +constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); +constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); + +constexpr bool indexable = reduce_binary_operator::indexable; +constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); + +constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable + +extern "C" __global__ void +gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global) +{ + (void)BlkGroupSize; + + void* p_src2dDesc = ws_global; + void* p_dst1dDesc = static_cast(ws_global) + 2048; + + const auto tupleDstLengths = make_tuple(1); + const auto tupleDstStrides = make_tuple(1); + + auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + const index_t invariantLen = dstDesc.GetLength(Number<0>{}); + const index_t toReduceLen = BlkGroupSize; + + auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); + + constexpr auto copySliceLen = GredThreadBufferLength; + + if constexpr(src2d_need_padding) + { + const auto srcPad1 = GridSize * BlockSize - invariantLen; + const auto srcPad2 = + ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; + auto src2dDesc_2 = + transform_tensor_descriptor(src2dDesc, + make_tuple(make_pad_transform(invariantLen, 0, srcPad1), + make_pad_transform(toReduceLen, 0, srcPad2)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + if(get_thread_local_1d_id() == 0) + *static_cast(p_src2dDesc) = src2dDesc_2; + } + else + { + if(get_thread_local_1d_id() == 0) + *static_cast(p_src2dDesc) = src2dDesc; + } + + if constexpr(dst1d_need_padding) + { + const auto dstPad = GridSize * BlockSize - invariantLen; + auto dst1dDesc_2 = + transform_tensor_descriptor(dstDesc, + make_tuple(make_pad_transform(invariantLen, 0, dstPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dst1dDesc_2; + } + else + { + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dstDesc; + } +}; + +struct get_ref_desc_types +{ + static constexpr auto ref_tupleDstLengths = make_tuple(8); + static constexpr auto ref_dstDesc = + make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); + + static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{}); + static constexpr index_t ref_toReduceLen = 8; + + static constexpr auto ref_src2dDesc = + make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); + + using refType_src2dDesc = decltype(ref_src2dDesc); + using refType_dst1dDesc = decltype(ref_dstDesc); + + // used by the DirectThreadWise and DirectWarpWise method + using refType_src2dDesc_padded_12 = + decltype(transform_tensor_descriptor(ref_src2dDesc, + make_tuple(make_pad_transform(ref_invariantLen, 0, 2), + make_pad_transform(ref_toReduceLen, 0, 2)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}))); + + using refType_dst1dDesc_padded = + decltype(transform_tensor_descriptor(ref_dstDesc, + make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{}))); +}; + +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; +using 
refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; + +template +static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) +{ + if constexpr(need_padding) + return (*reinterpret_cast(p_src2dDesc)); + else + return (*reinterpret_cast(p_src2dDesc)); +}; + +template +static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) +{ + if constexpr(need_padding) + return (*reinterpret_cast(p_dst1dDesc)); + else + return (*reinterpret_cast(p_dst1dDesc)); +}; + +extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, + float alpha, + const void* __restrict__ p_src_global, + float beta, + void* __restrict__ p_dst_global, + const void CONSTANT* ws_global, + long ws_buf2_bytes_offset, + void* __restrict__ indices_global) +{ + (void)p_src_global; + + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); + + const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); + const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); + + using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise; + + void* const ws_buf2_global = + ws_buf2_bytes_offset > 0 + ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) + : nullptr; + + constexpr int RunId = need_indices ? 3 : 1; + gridwise_2d_reduce::template Run( + src2dDesc, + dst1dDesc, + origReduceLen, + alpha, + static_cast(ws_buf1_global), + beta, + static_cast(p_dst_global), + static_cast(ws_buf2_global), + static_cast(indices_global)); +}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp similarity index 87% rename from composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise.cpp rename to composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp index ef88547028..2de461ad0f 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp @@ -42,12 +42,8 @@ using compType = constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable -constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS; -using toReduceDims = Sequence; -using invariantDims = Sequence; // this could be empty - constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 ? 
NanPropagation_t::NOT_PROPAGATE_NAN @@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(invariantDims::Size() > 0 || dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); - constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -152,12 +138,12 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad2)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } @@ -169,17 +155,17 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc; } }; -template +template struct get_ref_desc_types { static constexpr auto ref_tupleDstLengths = @@ -217,16 +203,11 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}))); }; -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; using refType_src2dDesc_padded_12 = - typename get_ref_desc_types:: - refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = - typename get_ref_desc_types:: - refType_dst1dDesc_padded; + typename get_ref_desc_types::refType_src2dDesc_padded_12; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; template static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) @@ -251,15 +232,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)p_src_global; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; - void* ws_buf1_global = static_cast(ws_global) + 4096; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff 
--git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp new file mode 100644 index 0000000000..1ba5e49657 --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp @@ -0,0 +1,221 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include "config.hpp" +#include "number.hpp" +#include "sequence.hpp" +#include "tensor_descriptor_helper.hpp" +#include "data_type_enum_helper.hpp" +#include "reduction_common.hpp" +#include "gridwise_generic_2d_reduction_direct_warpwise.hpp" + +using namespace ck; + +using srcDataType = + typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; +using dstDataType = + typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; +using compType = + typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; + +constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable + +constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); +constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 + ? NanPropagation_t::NOT_PROPAGATE_NAN + : NanPropagation_t::PROPAGATE_NAN; +constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 + ? 
ReduceTensorIndices_t::NO_INDICES + : ReduceTensorIndices_t::FLATTENED_INDICES; + +constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); +constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); + +constexpr bool indexable = reduce_binary_operator::indexable; +constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); + +constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable + +extern "C" __global__ void +gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global) +{ + (void)BlkGroupSize; + + void* p_src2dDesc = ws_global; + void* p_dst1dDesc = static_cast(ws_global) + 2048; + + const auto tupleDstLengths = make_tuple(1); + const auto tupleDstStrides = make_tuple(1); + + auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + const index_t invariantLen = dstDesc.GetLength(Number<0>{}); + const index_t toReduceLen = BlkGroupSize; + + auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); + + constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp; + + if constexpr(src2d_need_padding) + { + const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen; + const auto srcPad2 = + ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; + + auto src2dDesc_2 = + transform_tensor_descriptor(src2dDesc, + make_tuple(make_pad_transform(invariantLen, 0, srcPad1), + make_pad_transform(toReduceLen, 0, srcPad2)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + if(get_thread_local_1d_id() == 0) + *static_cast(p_src2dDesc) = src2dDesc_2; + } + else + { + if(get_thread_local_1d_id() == 0) + *static_cast(p_src2dDesc) = src2dDesc; + } + + if constexpr(dst1d_need_padding) + { + const auto dstPad = GridSize * BlockSize / warpSize - invariantLen; + auto dst1dDesc_2 = + transform_tensor_descriptor(dstDesc, + make_tuple(make_pad_transform(invariantLen, 0, dstPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dst1dDesc_2; + } + else + { + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dstDesc; + } +}; + +struct get_ref_desc_types +{ + static constexpr auto ref_tupleDstLengths = make_tuple(8); + static constexpr auto ref_dstDesc = + make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); + + static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{}); + static constexpr index_t ref_toReduceLen = 8; + + static constexpr auto ref_src2dDesc = + make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); + + using refType_src2dDesc = decltype(ref_src2dDesc); + using refType_dst1dDesc = decltype(ref_dstDesc); + + // used by the DirectThreadWise and DirectWarpWise method + using refType_src2dDesc_padded_12 = + decltype(transform_tensor_descriptor(ref_src2dDesc, + make_tuple(make_pad_transform(ref_invariantLen, 0, 2), + make_pad_transform(ref_toReduceLen, 0, 2)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}))); + + using refType_dst1dDesc_padded = + decltype(transform_tensor_descriptor(ref_dstDesc, + make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{}))); +}; + +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename 
get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; + +template +static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) +{ + if constexpr(need_padding) + return (*reinterpret_cast(p_src2dDesc)); + else + return (*reinterpret_cast(p_src2dDesc)); +}; + +template +static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) +{ + if constexpr(need_padding) + return (*reinterpret_cast(p_dst1dDesc)); + else + return (*reinterpret_cast(p_dst1dDesc)); +}; + +extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, + float alpha, + const void* __restrict__ p_src_global, + float beta, + void* __restrict__ p_dst_global, + const void CONSTANT* ws_global, + long ws_buf2_bytes_offset, + void* __restrict__ indices_global) +{ + (void)p_src_global; + + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); + + const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); + const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); + + using gridwise_2d_reduce = + GridwiseReduction_xy_to_x_direct_warpwise; + + void* const ws_buf2_global = + ws_buf2_bytes_offset > 0 + ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) + : nullptr; + + constexpr int RunId = need_indices ? 3 : 1; + gridwise_2d_reduce::template Run( + src2dDesc, + dst1dDesc, + origReduceLen, + alpha, + static_cast(ws_buf1_global), + beta, + static_cast(p_dst_global), + static_cast(ws_buf2_global), + static_cast(indices_global)); +}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp similarity index 87% rename from composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise.cpp rename to composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp index 53b0e1e759..aef1545f11 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp @@ -42,12 +42,8 @@ using compType = constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable -constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS; -using toReduceDims = Sequence; -using invariantDims = Sequence; // this could be empty - constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 ? 
NanPropagation_t::NOT_PROPAGATE_NAN @@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(invariantDims::Size() > 0 || dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); - constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -153,12 +139,12 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad2)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } @@ -170,17 +156,17 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc; } }; -template +template struct get_ref_desc_types { static constexpr auto ref_tupleDstLengths = @@ -218,16 +204,11 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}))); }; -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; using refType_src2dDesc_padded_12 = - typename get_ref_desc_types:: - refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = - typename get_ref_desc_types:: - refType_dst1dDesc_padded; + typename get_ref_desc_types::refType_src2dDesc_padded_12; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; template static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) @@ -252,15 +233,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)p_src_global; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; - void* ws_buf1_global = static_cast(ws_global) + 4096; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc);
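[Editorial note, not part of the patch] All of the gridwise_generic_reduce_2 wrappers above share one workspace convention: ws_global is received as a CONSTANT-decorated pointer and converted back to a generic address-space pointer via cast_pointer_to_generic_address_space; the type-erased src2dDesc sits at byte offset 0, the dst1dDesc at byte offset 2048, and the first call's partial results (ws_buf1_global) start at byte offset 4096, with an optional indices staging buffer ws_buf2_bytes_offset bytes further on. A minimal C++ sketch of that layout follows. WorkspaceLayout and get_ws_buf2 are hypothetical names introduced here for illustration; only the byte offsets and the null-check come from the kernels themselves.

#include <cstddef>

// Hypothetical mirror of the offsets hard-coded in the second-call kernel wrappers.
struct WorkspaceLayout
{
    static constexpr std::size_t src2d_desc_offset  = 0;    // serialized src2dDesc
    static constexpr std::size_t dst1d_desc_offset  = 2048; // serialized dst1dDesc
    static constexpr std::size_t reduce_data_offset = 4096; // ws_buf1_global: first-call partial results
};

// Mirrors the ws_buf2_global ternary in gridwise_generic_reduce_2: the indices staging
// buffer lives ws_buf2_bytes_offset bytes past ws_buf1_global, or is absent entirely.
inline void* get_ws_buf2(void* ws_buf1_global, long ws_buf2_bytes_offset)
{
    return ws_buf2_bytes_offset > 0
               ? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
               : nullptr;
}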