Grouped Conv Bwd Weight Direct Load (#3648)

* Grouped Conv Bwd Weight Direct Load

* Update gridwise_gemm_xdl_cshuffle_conv_v3.hpp

* Implement group merging for bwd_weight and add instances

* Link direct load instances

* builder fixes

* fix

* fixes

* fix

---------

Co-authored-by: Graner, Johannes <johannes.graner@amd.com>
This commit is contained in:
Bartłomiej Kocot
2026-01-28 22:31:54 +01:00
committed by GitHub
parent 654bec3362
commit 83b58bb0c3
18 changed files with 578 additions and 194 deletions

View File

@@ -35,7 +35,7 @@ template <typename T>
concept BwdXdlV3AlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
-SpecifiesBlockGemm<T>;
+SpecifiesBlockGemm<T> && SpecifiesNumGroupsToMerge<T>;
template <typename T>
concept BwdWmmaAlgorithmBase =

View File

@@ -53,7 +53,9 @@ template <ck::index_t NDimSpatial,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
-typename ComputeTypeB>
+typename ComputeTypeB,
+bool DirectLoad,
+index_t NumGroupsToMerge>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3;
} // namespace ck::tensor_operation::device
@@ -109,7 +111,9 @@ template <ck::index_t NDimSpatial,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA_,
-typename ComputeTypeB_>
+typename ComputeTypeB_,
+bool DirectLoad,
+index_t NumGroupsToMerge>
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<
NDimSpatial,
InLayout_,
@@ -153,7 +157,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA_,
-ComputeTypeB_>>
+ComputeTypeB_,
+DirectLoad,
+NumGroupsToMerge>>
{
/// @brief Tag type identifying this device kernel variant
@@ -241,6 +247,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
using ComputeTypeA = ComputeTypeA_;
using ComputeTypeB = ComputeTypeB_;
static constexpr bool kDirectLoad = DirectLoad;
static constexpr index_t kNumGroupsToMerge = NumGroupsToMerge;
// Static member function to generate instance string
static std::string instance_string()
{
@@ -302,6 +311,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 41.
oss << "," << detail::type_name<ComputeTypeA>(); // 42.
oss << "," << detail::type_name<ComputeTypeB>(); // 43.
oss << "," << kDirectLoad; // 44.
oss << "," << kNumGroupsToMerge; // 45.
oss << ">";
return oss.str();

View File

@@ -32,7 +32,8 @@ constexpr auto ALGORITHM =
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
-.with_block_gemm(cku::BlockGemmDesc_v2_intrawave);
+.with_block_gemm(cku::BlockGemmDesc_v2_intrawave)
+.with_num_conv_groups_to_merge(1);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;

View File

@@ -632,7 +632,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 =
BwdXdlGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
-BlockGemm_>;
+BlockGemm_,
+GemmBatchOptions_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl =
ConvAlgorithmTemplate<ThreadBlock_,

View File

@@ -69,6 +69,8 @@ std::string expected_str =
",v1" // BlkGemmPipelineVer
",fp16" // ComputeTypeA
",fp16" // ComputeTypeB
",0" // DirectLoad
",1" // NumGroupsToMerge
">";
// Test describe() through base class pointer for XDL V3 variant