mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-10 18:39:59 +00:00
54 lines
1.2 KiB
C++
54 lines
1.2 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include "ck/ck.hpp"
|
|
|
|
namespace ck {
|
|
|
|
__device__ constexpr index_t get_warp_size()
|
|
{
|
|
#if defined(__HIP_DEVICE_COMPILE__)
|
|
#if defined(__GFX9__)
|
|
return 64;
|
|
#else
|
|
return 32;
|
|
#endif
|
|
#else
|
|
return 64;
|
|
#endif
|
|
}
|
|
|
|
inline __host__ index_t get_warp_size()
|
|
{
|
|
#if !(defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC))
|
|
int device = 0;
|
|
int result = 0;
|
|
auto status = hipGetDevice(&device);
|
|
if(status == hipSuccess)
|
|
{
|
|
status = hipDeviceGetAttribute(&result, hipDeviceAttributeWarpSize, device);
|
|
if(status == hipSuccess)
|
|
{
|
|
return result;
|
|
}
|
|
}
|
|
#endif
|
|
return 64;
|
|
}
|
|
|
|
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
|
|
|
|
__device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }
|
|
|
|
__device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); }
|
|
|
|
__device__ index_t get_block_1d_id() { return blockIdx.x; }
|
|
|
|
__device__ index_t get_grid_size() { return gridDim.x; }
|
|
|
|
__device__ index_t get_block_size() { return blockDim.x; }
|
|
|
|
} // namespace ck
|