mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Grouped Conv Bwd Weight Direct Load (#3648)
* Grouped Conv Bwd Weight Direct Load * Update gridwise_gemm_xdl_cshuffle_conv_v3.hpp * Implement group merging for bwd_weight and add instances * Link direct load instances * builder fixes * fix * fixes * fix --------- Co-authored-by: Graner, Johannes <johannes.graner@amd.com>
This commit is contained in:
@@ -82,23 +82,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())];
|
||||
|
||||
DispatchSplitKHack<GridwiseGemm,
|
||||
AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_hack);
|
||||
if constexpr(GridwiseGemm::DirectLoadEnabled)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
DispatchSplitKHack<GridwiseGemm,
|
||||
AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_hack);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
DispatchSplitKHack<GridwiseGemm,
|
||||
AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_hack);
|
||||
}
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -236,7 +261,9 @@ template <ck::index_t NDimSpatial,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = InDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool DirectLoad = false,
|
||||
index_t NumGroupsToMerge = 1>
|
||||
struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
: public DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
InLayout,
|
||||
@@ -287,7 +314,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
NPerBlock,
|
||||
K1Number,
|
||||
K0PerBlock / K1Number,
|
||||
1 /*NumGroupsToMerge*/,
|
||||
NumGroupsToMerge,
|
||||
ConvBackwardWeightSpecialization>{};
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
@@ -371,6 +398,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
// Disable vector load = 4. It is not supported for Direct Load. Align to 2 in such case.
|
||||
static constexpr index_t ABlockTransferSrcScalarPerVectorAligned =
|
||||
ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8
|
||||
? 4 / sizeof(ADataType)
|
||||
: ABlockTransferSrcScalarPerVector;
|
||||
static constexpr index_t BBlockTransferSrcScalarPerVectorAligned =
|
||||
BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8
|
||||
? 4 / sizeof(BDataType)
|
||||
: BBlockTransferSrcScalarPerVector;
|
||||
|
||||
template <index_t NXdlPerWave_>
|
||||
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3<
|
||||
tensor_layout::gemm::RowMajor,
|
||||
@@ -399,7 +436,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
DirectLoad ? ABlockTransferSrcScalarPerVectorAligned : ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
ABlockLdsAddExtraM,
|
||||
@@ -407,7 +444,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
DirectLoad ? BBlockTransferSrcScalarPerVectorAligned : BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
BBlockLdsAddExtraN,
|
||||
@@ -418,7 +455,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
ComputeTypeB,
|
||||
DirectLoad>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
|
||||
@@ -653,15 +691,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
if(split_k_offset_hack_)
|
||||
split_k_stride_b_ /= k_batch_;
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
|
||||
// A/B/C Batch Stride (multiply by NumGroupsToMerge for group merging)
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ =
|
||||
Conv_K_ * Conv_C_ *
|
||||
std::accumulate(begin(filter_spatial_lengths_),
|
||||
end(filter_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
std::multiplies<>{}) *
|
||||
NumGroupsToMerge;
|
||||
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
|
||||
@@ -743,7 +782,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
|
||||
gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_);
|
||||
gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
@@ -1367,6 +1406,30 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
}
|
||||
#endif
|
||||
|
||||
// check device
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
if(get_device_name() != "gfx950")
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check that NumGroupsToMerge divides Conv_G evenly
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
if(arg.Conv_G_ % NumGroupsToMerge != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Unsupported! Conv_G_ % NumGroupsToMerge != 0: Conv_G_="
|
||||
<< arg.Conv_G_ << ", NumGroupsToMerge=" << NumGroupsToMerge
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmK =
|
||||
@@ -1617,8 +1680,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"
|
||||
<< "<"
|
||||
str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
|
||||
|
||||
if constexpr(DirectLoad) {
|
||||
str << "_DirectLoad";
|
||||
}
|
||||
|
||||
str << "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
|
||||
@@ -567,6 +567,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
using DsGridDesc_M_N =
|
||||
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
|
||||
|
||||
// Disable vector load = 4. It is not supported for Direct Load. Align to 2 in such case.
|
||||
static constexpr index_t ABlockTransferSrcScalarPerVectorAligned =
|
||||
ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8
|
||||
? 4 / sizeof(ADataType)
|
||||
|
||||
Reference in New Issue
Block a user