From 52fb9c990fa0f7bf0fe543e9999ee45ee9dc395a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Mon, 26 May 2025 16:51:09 +0200 Subject: [PATCH] Fix grid size calc for bwd wei (#2226) [ROCm/composable_kernel commit: 037764bbc62a11e9fddcfb959950c98b346b2901] --- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 20 ++++++++++++------- ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 6 +++--- .../transform_conv_bwd_weight_to_gemm.hpp | 14 ++++++------- .../transform_conv_bwd_weight_to_gemm_v2.hpp | 12 +++++------ 4 files changed, 29 insertions(+), 23 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 57c4b1a5cf..33b6d7c585 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -393,8 +393,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle { const index_t GemmM = K; const index_t GemmN = C * X; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = + GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = + GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; return transform_tensor_descriptor( wei_grid_desc, @@ -432,8 +434,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle { const index_t GemmM = K; const index_t GemmN = C * X * Y; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = + GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = + GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; return transform_tensor_descriptor( wei_grid_desc, @@ -472,8 +476,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle { const index_t GemmM = K; const index_t GemmN = C * X * Y * Z; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = + GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = + GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; return transform_tensor_descriptor( wei_grid_desc, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index 0831b754c8..e9e02eae81 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -208,8 +208,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle const index_t GemmM = K; const index_t GemmN = C * Z * X * Y; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock; diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp index c11bf845d0..bd3ab10802 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -166,8 +166,8 @@ struct TransformConvBwdWeightToGemm const index_t GemmM = K; const index_t GemmN = C * X; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; const index_t GemmKBatch = batch_k; const index_t GemmK0 = @@ -365,8 +365,8 @@ struct TransformConvBwdWeightToGemm const index_t GemmM = K; const index_t GemmN = C * X * Y; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; const index_t GemmKBatch = batch_k; const index_t GemmK0 = @@ -558,8 +558,8 @@ struct TransformConvBwdWeightToGemm const index_t GemmM = K; const index_t GemmN = C * Z * X * Y; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; const index_t GemmKBatch = batch_k; const index_t GemmK0 = diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index f34e0e59b3..b72ddb8243 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -346,8 +346,8 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmM = K * NumGroupsToMerge; const index_t GemmN = C * X * NumGroupsToMerge; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; const index_t GemmKBatch = batch_k; const index_t GemmK0 = @@ -534,8 +534,8 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmM = K * NumGroupsToMerge; const index_t GemmN = C * X * Y * NumGroupsToMerge; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; const index_t GemmKBatch = batch_k; const index_t GemmK0 = @@ -737,8 +737,8 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmM = K * NumGroupsToMerge; const index_t GemmN = C * Z * X * Y * NumGroupsToMerge; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; const index_t GemmKBatch = batch_k; const index_t GemmK0 =