mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Grouped Conv Bwd Weight Direct Load (#3648)
* Grouped Conv Bwd Weight Direct Load
* Update gridwise_gemm_xdl_cshuffle_conv_v3.hpp
* Implement group merging for bwd_weight and add instances
* Link direct load instances
* builder fixes
* fix
* fixes
* fix
---------
Co-authored-by: Graner, Johannes <johannes.graner@amd.com>
This commit is contained in:
@@ -35,7 +35,7 @@ template <typename T>
|
||||
concept BwdXdlV3AlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
|
||||
SpecifiesBlockGemm<T>;
|
||||
SpecifiesBlockGemm<T> && SpecifiesNumGroupsToMerge<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdWmmaAlgorithmBase =
|
||||
|
||||
@@ -53,7 +53,9 @@ template <ck::index_t NDimSpatial,
|
||||
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename ComputeTypeA,
|
||||
typename ComputeTypeB>
|
||||
typename ComputeTypeB,
|
||||
bool DirectLoad,
|
||||
index_t NumGroupsToMerge>
|
||||
struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3;
|
||||
|
||||
} // namespace ck::tensor_operation::device
|
||||
@@ -109,7 +111,9 @@ template <ck::index_t NDimSpatial,
|
||||
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename ComputeTypeA_,
|
||||
typename ComputeTypeB_>
|
||||
typename ComputeTypeB_,
|
||||
bool DirectLoad,
|
||||
index_t NumGroupsToMerge>
|
||||
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<
|
||||
NDimSpatial,
|
||||
InLayout_,
|
||||
@@ -153,7 +157,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA_,
|
||||
ComputeTypeB_>>
|
||||
ComputeTypeB_,
|
||||
DirectLoad,
|
||||
NumGroupsToMerge>>
|
||||
{
|
||||
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
@@ -241,6 +247,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
using ComputeTypeA = ComputeTypeA_;
|
||||
using ComputeTypeB = ComputeTypeB_;
|
||||
|
||||
static constexpr bool kDirectLoad = DirectLoad;
|
||||
static constexpr index_t kNumGroupsToMerge = NumGroupsToMerge;
|
||||
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
@@ -302,6 +311,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
|
||||
oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 41.
|
||||
oss << "," << detail::type_name<ComputeTypeA>(); // 42.
|
||||
oss << "," << detail::type_name<ComputeTypeB>(); // 43.
|
||||
oss << "," << kDirectLoad; // 44.
|
||||
oss << "," << kNumGroupsToMerge; // 45.
|
||||
oss << ">";
|
||||
|
||||
return oss.str();
|
||||
|
||||
@@ -32,7 +32,8 @@ constexpr auto ALGORITHM =
|
||||
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave);
|
||||
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
@@ -632,7 +632,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 =
|
||||
BwdXdlGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
BlockGemm_>;
|
||||
BlockGemm_,
|
||||
GemmBatchOptions_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
|
||||
@@ -69,6 +69,8 @@ std::string expected_str =
|
||||
",v1" // BlkGemmPipelineVer
|
||||
",fp16" // ComputeTypeA
|
||||
",fp16" // ComputeTypeB
|
||||
",0" // DirectLoad
|
||||
",1" // NumGroupsToMerge
|
||||
">";
|
||||
|
||||
// Test describe() through base class pointer for XDL V3 variant
|
||||
|
||||
Reference in New Issue
Block a user