Remove usage of 'warpSize' variable as it has been deprecated (#2295)

* SWDEV-535598 - remove usage of 'warpSize' variable as it has been deprecated. Ideally get_warp_size() should not be constexpr but this is just a workaround

* SWDEV-535598 - remove comment from get_warp_size as constexpr is required for this repo

---------

Co-authored-by: Gerardo Hernandez <gerardo.hernandez@amd.com>

[ROCm/composable_kernel commit: 6635d1bb88]
This commit is contained in:
John Afaganis
2025-06-10 08:34:54 -06:00
committed by GitHub
parent 66767bf11b
commit 42ea095d98
2 changed files with 10 additions and 4 deletions

View File

@@ -9,8 +9,11 @@ namespace ck {
__host__ __device__ constexpr index_t get_warp_size()
{
// warpSize is defined by HIP
return warpSize;
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
return 64;
#else
return 32;
#endif
}
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

View File

@@ -50,8 +50,11 @@ enum struct memory_operation_enum : std::uint16_t
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
{
// warpSize is defined by HIP
return warpSize;
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
return 64;
#else
return 32;
#endif
}
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }