mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Fix grid size calculation for grouped convolution backward weight (#2226)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -393,8 +393,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
{
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X;
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM =
|
||||
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN =
|
||||
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
wei_grid_desc,
|
||||
@@ -432,8 +434,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
{
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X * Y;
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM =
|
||||
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN =
|
||||
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
wei_grid_desc,
|
||||
@@ -472,8 +476,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
{
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X * Y * Z;
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM =
|
||||
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN =
|
||||
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
wei_grid_desc,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -208,8 +208,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * Z * X * Y;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock;
|
||||
|
||||
Reference in New Issue
Block a user