mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Add utility/get_shift.hpp
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/utility/reduction_common.hpp"
|
||||
#include "ck/utility/get_shift.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/utility/reduction_common.hpp"
|
||||
#include "ck/utility/get_shift.hpp"
|
||||
#include "ck/utility/reduction_functions_accumulate.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/get_shift.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -19,6 +20,29 @@ __device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x +
|
||||
|
||||
__device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); }
|
||||
|
||||
// get_wave_id() does the same thing as get_warp_local_1d_id(), except that
|
||||
// it tries to save the result in sgpr
|
||||
#if defined(__gfx90a__)
|
||||
static __device__ inline index_t get_wave_id()
|
||||
{
|
||||
int thread_id = threadIdx.x;
|
||||
int tmp_int;
|
||||
int wave_id;
|
||||
constexpr index_t shift = get_shift<warpSize>();
|
||||
|
||||
// clang-format off
|
||||
__asm__ volatile("v_lshrrev_b32 %1, %3, %2 \n\
|
||||
v_readfirstlane_b32 %0, %1"
|
||||
: "=s"(wave_id), "=v"(tmp_int)
|
||||
: "v"(thread_id), "i"(shift));
|
||||
// clang-format on
|
||||
|
||||
return wave_id;
|
||||
};
|
||||
#else
|
||||
static __device__ inline index_t get_wave_id() { return get_warp_local_1d_id(); };
|
||||
#endif
|
||||
|
||||
__device__ index_t get_block_1d_id() { return blockIdx.x; }
|
||||
|
||||
__device__ index_t get_grid_size() { return gridDim.x; }
|
||||
|
||||
20
include/ck/utility/get_shift.hpp
Normal file
20
include/ck/utility/get_shift.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t N>
|
||||
static constexpr __device__ index_t get_shift()
|
||||
{
|
||||
return (get_shift<N / 2>() + 1);
|
||||
};
|
||||
|
||||
template <>
|
||||
constexpr __device__ index_t get_shift<1>()
|
||||
{
|
||||
return (0);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -25,16 +25,4 @@ struct float_equal_zero
|
||||
};
|
||||
};
|
||||
|
||||
template <index_t N>
|
||||
static constexpr __device__ index_t get_shift()
|
||||
{
|
||||
return (get_shift<N / 2>() + 1);
|
||||
};
|
||||
|
||||
template <>
|
||||
constexpr __device__ index_t get_shift<1>()
|
||||
{
|
||||
return (0);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user