Fix grid size calc for bwd wei (#2226)

This commit is contained in:
Bartłomiej Kocot
2025-05-26 16:51:09 +02:00
committed by GitHub
parent ece38b9d7a
commit 037764bbc6
4 changed files with 29 additions and 23 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -393,8 +393,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
{
const index_t GemmM = K;
const index_t GemmN = C * X;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM =
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN =
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
return transform_tensor_descriptor(
wei_grid_desc,
@@ -432,8 +434,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
{
const index_t GemmM = K;
const index_t GemmN = C * X * Y;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM =
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN =
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
return transform_tensor_descriptor(
wei_grid_desc,
@@ -472,8 +476,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
{
const index_t GemmM = K;
const index_t GemmN = C * X * Y * Z;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM =
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN =
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
return transform_tensor_descriptor(
wei_grid_desc,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -208,8 +208,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock;