mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
Merge commit '6635d1bb888e3f51ec1125ff7d6f54a2ec054a10' into develop
This commit is contained in:
@@ -9,8 +9,11 @@ namespace ck {
|
||||
|
||||
__host__ __device__ constexpr index_t get_warp_size()
|
||||
{
|
||||
// warpSize is defined by HIP
|
||||
return warpSize;
|
||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
return 64;
|
||||
#else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
|
||||
|
||||
@@ -50,8 +50,11 @@ enum struct memory_operation_enum : std::uint16_t
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
|
||||
{
|
||||
// warpSize is defined by HIP
|
||||
return warpSize;
|
||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
return 64;
|
||||
#else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }
|
||||
|
||||
Reference in New Issue
Block a user