mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4421 (commit 5bb5769)
[CK] Unify the grouped convolution gridwise Run() functions (#4421) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation There are currently three different grouped convolution related Run() function overloads that exist in `gridwise_gemm_wmma_cshuffle_v3.hpp`. These are used for the different types of grouped convolution: Forward, Backward weights, and Backward data. The functions are very similar and should be unified to a single `Run()` function for all types of grouped convolution. ## Technical Details The three old `Run<>()` functions were replaced with a single unified function. The new `Run<>()` function is run from device implementations: - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 - DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 - DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 The DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor implementation uses a different `Run<>()` overload and was therefore not modified. ## Test Plan Run the following grouped convolution tests on `gfx1201`, as this architecture is WMMA-capable: - `test_grouped_convnd_fwd` - `test_grouped_convnd_bwd_weight` - `test_grouped_convnd_bwd_data` Compilation and testing were also executed on `gfx1100` to avoid CI problems. ## Test Result First part (unification of `Run<>()` function): All tests successful. Second part (integration of single `Run<>()` function as a direct call): All tests successful. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
6f0ecf361e
commit
d8ee107a47
@@ -7,6 +7,7 @@
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
@@ -100,17 +101,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
|
||||
{
|
||||
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
|
||||
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_DATA,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
decltype(gemm_kernel_args[group_id].block_2_ctile_map_),
|
||||
ComputePtrOffsetOfBatch,
|
||||
ComputePtrOffsetOfN,
|
||||
0,
|
||||
HasMainKBlockLoopInAllGemm,
|
||||
EGlobalMemoryDataOperation,
|
||||
CTranspose,
|
||||
TailNum>(
|
||||
TailNum,
|
||||
decltype(epilogue_args)>(
|
||||
p_shared,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
@@ -127,17 +131,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
{
|
||||
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
|
||||
|
||||
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_DATA,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
decltype(gemm_kernel_args[group_id].block_2_ctile_map_),
|
||||
ComputePtrOffsetOfBatch,
|
||||
ComputePtrOffsetOfN,
|
||||
0,
|
||||
true,
|
||||
EGlobalMemoryDataOperation,
|
||||
CTranspose,
|
||||
TailNum>(
|
||||
TailNum,
|
||||
decltype(epilogue_args)>(
|
||||
p_shared,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
@@ -152,17 +160,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_AK1,
|
||||
|
||||
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_DATA,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
decltype(gemm_kernel_args[group_id].block_2_ctile_map_),
|
||||
ComputePtrOffsetOfBatch,
|
||||
ComputePtrOffsetOfN,
|
||||
0,
|
||||
false,
|
||||
EGlobalMemoryDataOperation,
|
||||
CTranspose,
|
||||
TailNum>(
|
||||
TailNum,
|
||||
decltype(epilogue_args)>(
|
||||
p_shared,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
@@ -28,6 +29,7 @@
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/description.hpp"
|
||||
@@ -71,23 +73,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_WEIGHT,
|
||||
AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
ck::Tuple<>, // Empty tuple
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
decltype(block_2_ctile_map_),
|
||||
ComputePtrOffsetOfBatch,
|
||||
ComputePtrOffsetOfBatch, // placeholder
|
||||
1,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(p_shared,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
compute_ptr_offset_of_batch,
|
||||
num_k_per_block,
|
||||
karg,
|
||||
epilogue_args);
|
||||
false,
|
||||
TailNum,
|
||||
decltype(epilogue_args)>(
|
||||
p_shared,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ck::Tuple<>(), // placeholder
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_ctile_map_,
|
||||
compute_ptr_offset_of_batch,
|
||||
ComputePtrOffsetOfBatch{}, // placeholder
|
||||
num_k_per_block,
|
||||
karg,
|
||||
epilogue_args);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
@@ -29,6 +30,7 @@
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/description.hpp"
|
||||
@@ -71,23 +73,35 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_WEIGHT,
|
||||
AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
ck::Tuple<>, // Empty tuple
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
decltype(block_2_ctile_map_),
|
||||
ComputePtrOffsetOfBatch,
|
||||
ComputePtrOffsetOfBatch, // placeholder
|
||||
NumGroupsToMerge,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(p_shared,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
compute_ptr_offset_of_batch,
|
||||
num_k_per_block,
|
||||
karg,
|
||||
epilogue_args);
|
||||
false,
|
||||
TailNum,
|
||||
decltype(epilogue_args)>(
|
||||
p_shared,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ck::Tuple<>(), // placeholder
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_ctile_map_,
|
||||
compute_ptr_offset_of_batch,
|
||||
ComputePtrOffsetOfBatch{}, // placeholder
|
||||
num_k_per_block,
|
||||
karg,
|
||||
epilogue_args);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
@@ -71,23 +72,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_WEIGHT,
|
||||
AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
ck::Tuple<>, // Empty tuple
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
decltype(block_2_ctile_map_),
|
||||
ComputePtrOffsetOfBatch,
|
||||
ComputePtrOffsetOfBatch, // placeholder
|
||||
1,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(p_shared,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
compute_ptr_offset_of_batch,
|
||||
num_k_per_block,
|
||||
karg,
|
||||
epilogue_args);
|
||||
false,
|
||||
TailNum,
|
||||
decltype(epilogue_args)>(
|
||||
p_shared,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ck::Tuple<>(), // placeholder
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_ctile_map_,
|
||||
compute_ptr_offset_of_batch,
|
||||
ComputePtrOffsetOfBatch{}, // placeholder
|
||||
num_k_per_block,
|
||||
karg,
|
||||
epilogue_args);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
|
||||
@@ -105,24 +105,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
|
||||
|
||||
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1),
|
||||
const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
|
||||
|
||||
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::FORWARD,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
decltype(block_2_ctile_map_),
|
||||
ComputePtrOffset,
|
||||
ComputePtrOffset,
|
||||
0,
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_shared,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
compute_ptr_offset_of_batch,
|
||||
compute_ptr_offset_of_n,
|
||||
num_k_per_block,
|
||||
karg,
|
||||
epilogue_args);
|
||||
false,
|
||||
TailNum,
|
||||
decltype(epilogue_args)>(
|
||||
p_shared,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_ctile_map_,
|
||||
compute_ptr_offset_of_batch,
|
||||
compute_ptr_offset_of_n,
|
||||
num_k_per_block,
|
||||
karg,
|
||||
epilogue_args);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user