From 44f52fd081c523363b512e2b034e80e3afe44e78 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 23 Jun 2023 12:45:09 +0000 Subject: [PATCH] Add utility/get_shift.hpp --- .../gpu/block/blockwise_welford.hpp | 2 +- .../block/reduction_functions_blockwise.hpp | 2 +- include/ck/utility/get_id.hpp | 24 +++++++++++++++++++ include/ck/utility/get_shift.hpp | 20 ++++++++++++++++ include/ck/utility/reduction_common.hpp | 12 ---------- 5 files changed, 46 insertions(+), 14 deletions(-) create mode 100644 include/ck/utility/get_shift.hpp diff --git a/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp b/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp index 7506f7072b..08bc409d28 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck/tensor_description/cluster_descriptor.hpp" -#include "ck/utility/reduction_common.hpp" +#include "ck/utility/get_shift.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp b/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp index 6c13513cfb..82667e2352 100644 --- a/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp +++ b/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck/tensor_description/cluster_descriptor.hpp" -#include "ck/utility/reduction_common.hpp" +#include "ck/utility/get_shift.hpp" #include "ck/utility/reduction_functions_accumulate.hpp" namespace ck { diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp index 77564c6130..c872a1a0e5 100644 --- a/include/ck/utility/get_id.hpp +++ b/include/ck/utility/get_id.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/ck.hpp" +#include "ck/utility/get_shift.hpp" namespace ck { @@ -19,6 +20,29 @@ __device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + __device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); } +// get_wave_id() does the same thing as get_warp_local_1d_id(), except that +// it tries to save the result in sgpr +#if defined(__gfx90a__) +static __device__ inline index_t get_wave_id() +{ + int thread_id = threadIdx.x; + int tmp_int; + int wave_id; + constexpr index_t shift = get_shift(); + + // clang-format off + __asm__ volatile("v_lshrrev_b32 %1, %3, %2 \n\ + v_readfirstlane_b32 %0, %1" + : "=s"(wave_id), "=v"(tmp_int) + : "v"(thread_id), "i"(shift)); + // clang-format on + + return wave_id; +}; +#else +static __device__ inline index_t get_wave_id() { return get_warp_local_1d_id(); }; +#endif + __device__ index_t get_block_1d_id() { return blockIdx.x; } __device__ index_t get_grid_size() { return gridDim.x; } diff --git a/include/ck/utility/get_shift.hpp b/include/ck/utility/get_shift.hpp new file mode 100644 index 0000000000..0a93081cfd --- /dev/null +++ b/include/ck/utility/get_shift.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +template +static constexpr __device__ index_t get_shift() +{ + return (get_shift() + 1); +}; + +template <> +constexpr __device__ index_t get_shift<1>() +{ + return (0); +} + +} // namespace ck diff --git a/include/ck/utility/reduction_common.hpp b/include/ck/utility/reduction_common.hpp index 3777d297c8..75fdd85825 100644 --- a/include/ck/utility/reduction_common.hpp +++ b/include/ck/utility/reduction_common.hpp @@ -25,16 +25,4 @@ struct float_equal_zero }; }; -template -static constexpr __device__ index_t get_shift() -{ - return (get_shift() + 1); -}; - -template <> -constexpr __device__ index_t get_shift<1>() -{ - return (0); -} - } // namespace ck