Grouped Conv Bwd Weight Direct Load (#3648)

* Grouped Conv Bwd Weight Direct Load

* Update gridwise_gemm_xdl_cshuffle_conv_v3.hpp

* Implement group merging for bwd_weight and add instances

* Link direct load instances

* builder fixes

* fix

* fixes

* fix

---------

Co-authored-by: Graner, Johannes <johannes.graner@amd.com>
This commit is contained in:
Bartłomiej Kocot
2026-01-28 22:31:54 +01:00
committed by GitHub
parent 654bec3362
commit 83b58bb0c3
18 changed files with 578 additions and 194 deletions

View File

@@ -82,23 +82,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())];
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
if constexpr(GridwiseGemm::DirectLoadEnabled)
{
#if defined(__gfx950__)
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
#endif
}
else
{
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
}
#else
ignore = karg;
@@ -236,7 +261,9 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
typename ComputeTypeB = ComputeTypeA,
bool DirectLoad = false,
index_t NumGroupsToMerge = 1>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
: public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
@@ -287,7 +314,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
NPerBlock,
K1Number,
K0PerBlock / K1Number,
1 /*NumGroupsToMerge*/,
NumGroupsToMerge,
ConvBackwardWeightSpecialization>{};
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
@@ -371,6 +398,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// Direct Load does not support 8-byte-wide source vector loads (e.g. a vector of 4
// elements of a 2-byte type). When the requested vector width would be 8 bytes,
// halve it to 4 bytes (2 elements for 2-byte types); otherwise keep the requested width.
static constexpr index_t ABlockTransferSrcScalarPerVectorAligned =
ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8
? 4 / sizeof(ADataType)
: ABlockTransferSrcScalarPerVector;
static constexpr index_t BBlockTransferSrcScalarPerVectorAligned =
BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8
? 4 / sizeof(BDataType)
: BBlockTransferSrcScalarPerVector;
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3<
tensor_layout::gemm::RowMajor,
@@ -399,7 +436,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
DirectLoad ? ABlockTransferSrcScalarPerVectorAligned : ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
@@ -407,7 +444,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
DirectLoad ? BBlockTransferSrcScalarPerVectorAligned : BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
@@ -418,7 +455,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
ComputeTypeB,
DirectLoad>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
@@ -653,15 +691,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
if(split_k_offset_hack_)
split_k_stride_b_ /= k_batch_;
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
// A/B/C Batch Stride (multiply by NumGroupsToMerge for group merging)
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
std::multiplies<>{}) *
NumGroupsToMerge;
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
@@ -743,7 +782,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_);
gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge);
float ave_time = 0;
@@ -1367,6 +1406,30 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
#endif
// check device
if constexpr(DirectLoad)
{
if(get_device_name() != "gfx950")
{
return false;
}
}
// Check that NumGroupsToMerge divides Conv_G evenly
if constexpr(NumGroupsToMerge > 1)
{
if(arg.Conv_G_ % NumGroupsToMerge != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Unsupported! Conv_G_ % NumGroupsToMerge != 0: Conv_G_="
<< arg.Conv_G_ << ", NumGroupsToMerge=" << NumGroupsToMerge
<< std::endl;
}
return false;
}
}
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
const index_t GemmK =
@@ -1617,8 +1680,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"
<< "<"
str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
if constexpr(DirectLoad) {
str << "_DirectLoad";
}
str << "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "

View File

@@ -567,6 +567,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
// Direct Load does not support 8-byte-wide source vector loads (e.g. a vector of 4
// elements of a 2-byte type). When the requested vector width would be 8 bytes,
// halve it to 4 bytes (2 elements for 2-byte types); otherwise keep the requested width.
static constexpr index_t ABlockTransferSrcScalarPerVectorAligned =
ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8
? 4 / sizeof(ADataType)