[rocm-libraries] ROCm/rocm-libraries#4421 (commit 5bb5769)

[CK] Unify the grouped convolution gridwise Run() functions
 (#4421)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

There are currently three different grouped convolution related Run()
function overloads that exist in `gridwise_gemm_wmma_cshuffle_v3.hpp`.
These are used for the different types of grouped convolution: Forward,
Backward weights, and Backward data.
The functions are very similar and should be unified to a single `Run()`
function for all types of grouped convolution.

## Technical Details

The three old `Run<>()` functions were replaced with a single unified
function.
The new `Run<>()` function is run from device implementations:

-  DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3

-  DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3

-  DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3

-  DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3

-  DeviceGroupedConvBwdWeight_Wmma_CShuffleV3

The DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
implementation uses a different `Run<>()` overload and was therefore not
modified.

## Test Plan

Run the following grouped convolution tests on `gfx1201`, as this
architecture is WMMA-capable:

- `test_grouped_convnd_fwd`

- `test_grouped_convnd_bwd_weight`

- `test_grouped_convnd_bwd_data`

Compilation and testing were also executed on `gfx1100` to avoid CI
problems.

## Test Result

First part (unification of `Run<>()` function): All tests successful.

Second part (integration of single `Run<>()` function as a direct call):
All tests successful.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
JP-Fernando
2026-03-11 16:40:12 +00:00
committed by assistant-librarian[bot]
parent 6f0ecf361e
commit d8ee107a47
6 changed files with 255 additions and 343 deletions

View File

@@ -7,6 +7,7 @@
#include <sstream>
#include "ck/ck.hpp"
#include "ck/library/utility/numeric.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/env.hpp"
@@ -100,17 +101,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
{
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_DATA,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(gemm_kernel_args[group_id].block_2_ctile_map_),
ComputePtrOffsetOfBatch,
ComputePtrOffsetOfN,
0,
HasMainKBlockLoopInAllGemm,
EGlobalMemoryDataOperation,
CTranspose,
TailNum>(
TailNum,
decltype(epilogue_args)>(
p_shared,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
@@ -127,17 +131,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
{
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
{
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_DATA,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(gemm_kernel_args[group_id].block_2_ctile_map_),
ComputePtrOffsetOfBatch,
ComputePtrOffsetOfN,
0,
true,
EGlobalMemoryDataOperation,
CTranspose,
TailNum>(
TailNum,
decltype(epilogue_args)>(
p_shared,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
@@ -152,17 +160,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
}
else
{
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_DATA,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(gemm_kernel_args[group_id].block_2_ctile_map_),
ComputePtrOffsetOfBatch,
ComputePtrOffsetOfN,
0,
false,
EGlobalMemoryDataOperation,
CTranspose,
TailNum>(
TailNum,
decltype(epilogue_args)>(
p_shared,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,

View File

@@ -7,6 +7,7 @@
#include <numeric>
#include <sstream>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
@@ -28,6 +29,7 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/utility/tuple.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
@@ -71,23 +73,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
typename GridwiseGemm::EpilogueCShuffle>();
__shared__ char p_shared[LDS_size];
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_WEIGHT,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
ck::Tuple<>, // Empty tuple
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(block_2_ctile_map_),
ComputePtrOffsetOfBatch,
ComputePtrOffsetOfBatch, // placeholder
1,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
compute_ptr_offset_of_batch,
num_k_per_block,
karg,
epilogue_args);
false,
TailNum,
decltype(epilogue_args)>(
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ck::Tuple<>(), // placeholder
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map_,
compute_ptr_offset_of_batch,
ComputePtrOffsetOfBatch{}, // placeholder
num_k_per_block,
karg,
epilogue_args);
#if defined(__gfx11__)
}

View File

@@ -7,6 +7,7 @@
#include <numeric>
#include <sstream>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/env.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
@@ -29,6 +30,7 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/utility/tuple.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
@@ -71,23 +73,35 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
typename GridwiseGemm::EpilogueCShuffle>();
__shared__ char p_shared[LDS_size];
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_WEIGHT,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
ck::Tuple<>, // Empty tuple
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(block_2_ctile_map_),
ComputePtrOffsetOfBatch,
ComputePtrOffsetOfBatch, // placeholder
NumGroupsToMerge,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
compute_ptr_offset_of_batch,
num_k_per_block,
karg,
epilogue_args);
false,
TailNum,
decltype(epilogue_args)>(
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ck::Tuple<>(), // placeholder
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map_,
compute_ptr_offset_of_batch,
ComputePtrOffsetOfBatch{}, // placeholder
num_k_per_block,
karg,
epilogue_args);
#if defined(__gfx11__)
}
#endif

View File

@@ -7,6 +7,7 @@
#include <numeric>
#include <sstream>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
@@ -71,23 +72,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
typename GridwiseGemm::EpilogueCShuffle>();
__shared__ char p_shared[LDS_size];
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_WEIGHT,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
ck::Tuple<>, // Empty tuple
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(block_2_ctile_map_),
ComputePtrOffsetOfBatch,
ComputePtrOffsetOfBatch, // placeholder
1,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
compute_ptr_offset_of_batch,
num_k_per_block,
karg,
epilogue_args);
false,
TailNum,
decltype(epilogue_args)>(
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ck::Tuple<>(), // placeholder
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map_,
compute_ptr_offset_of_batch,
ComputePtrOffsetOfBatch{}, // placeholder
num_k_per_block,
karg,
epilogue_args);
#if defined(__gfx11__)
}

View File

@@ -105,24 +105,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1),
const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::FORWARD,
decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(block_2_ctile_map_),
ComputePtrOffset,
ComputePtrOffset,
0,
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
compute_ptr_offset_of_batch,
compute_ptr_offset_of_n,
num_k_per_block,
karg,
epilogue_args);
false,
TailNum,
decltype(epilogue_args)>(
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map_,
compute_ptr_offset_of_batch,
compute_ptr_offset_of_n,
num_k_per_block,
karg,
epilogue_args);
#if defined(__gfx11__)
}
#endif