Fix grid size calc for bwd wei (#2226)

[ROCm/composable_kernel commit: 037764bbc6]
This commit is contained in:
Bartłomiej Kocot
2025-05-26 16:51:09 +02:00
committed by GitHub
parent 6538aae676
commit 52fb9c990f
4 changed files with 29 additions and 23 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -393,8 +393,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
{
const index_t GemmM = K;
const index_t GemmN = C * X;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM =
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN =
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
return transform_tensor_descriptor(
wei_grid_desc,
@@ -432,8 +434,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
{
const index_t GemmM = K;
const index_t GemmN = C * X * Y;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM =
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN =
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
return transform_tensor_descriptor(
wei_grid_desc,
@@ -472,8 +476,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
{
const index_t GemmM = K;
const index_t GemmN = C * X * Y * Z;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM =
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN =
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
return transform_tensor_descriptor(
wei_grid_desc,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -208,8 +208,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock;

View File

@@ -1,6 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -166,8 +166,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmM = K;
const index_t GemmN = C * X;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
@@ -365,8 +365,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmM = K;
const index_t GemmN = C * X * Y;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
@@ -558,8 +558,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =

View File

@@ -346,8 +346,8 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmM = K * NumGroupsToMerge;
const index_t GemmN = C * X * NumGroupsToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
@@ -534,8 +534,8 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmM = K * NumGroupsToMerge;
const index_t GemmN = C * X * Y * NumGroupsToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
@@ -737,8 +737,8 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmM = K * NumGroupsToMerge;
const index_t GemmN = C * Z * X * Y * NumGroupsToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =