From 1b042bb70dde3518e8d43839514aa255d89d6f40 Mon Sep 17 00:00:00 2001 From: Satyanvesh Dittakavi Date: Tue, 10 Jun 2025 04:11:14 +0000 Subject: [PATCH] Do not use warpSize as compile time constant as it is removed --- include/ck/ck.hpp | 6 ++++++ .../gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp | 4 ++-- include/ck/utility/workgroup_synchronization.hpp | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 55f5620616..ebb51506b1 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -245,6 +245,12 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) namespace ck { +#if defined(__GFX9__) + __device__ static constexpr int WarpSize = 64; +#else + __device__ static constexpr int WarpSize = 32; +#endif + enum struct InMemoryDataOperationEnum { Set, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp index 9acfd00858..c9cc9c107a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -141,7 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_v2= 1 ? 4 * warpSize / BlockSize : 1; + (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); @@ -634,7 +634,7 @@ struct BlockwiseGemmXdlops_pipeline_v2= 1 ? 4 * warpSize / BlockSize : 1; + (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); diff --git a/include/ck/utility/workgroup_synchronization.hpp b/include/ck/utility/workgroup_synchronization.hpp index 24858fdbdc..af5b0808fb 100644 --- a/include/ck/utility/workgroup_synchronization.hpp +++ b/include/ck/utility/workgroup_synchronization.hpp @@ -32,7 +32,7 @@ static __device__ void gms_init(int NumWarps, int* p_control_bits) // all the workgroups in the synchronization group is supposed to call this function static __device__ void gms_barrier(int* p_control_bits) { - constexpr int mask = warpSize - 1; + constexpr int mask = WarpSize - 1; if((threadIdx.x & mask) == 0) {