mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Merge multiple fwd convolution groups into a single GEMM batch. (#3136)
* Merge fwd conv groups in CK Tile. * Fix building CK fwd convs. * Add number of merged groups to conv fwd kernel name. * Get number of merged groups from conv config. * Rename GemmConfig to ConvConfig. * Clean up TODOs. * Check that the number of conv groups is divisible by the number of merged groups. * Improve error handling in the conv fwd example. * Fix clang-format. * Fix group offsets. * Fix merge problem. * Address feedback from code review. * Fix clang-formatting.
This commit is contained in:
@@ -643,8 +643,6 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: Should we also check that GemmM <= MPerBlock and GemmN <= NPerBlock?
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
@@ -28,7 +28,6 @@ namespace ck_tile {
|
||||
template <typename GroupedConvTraitsType_, typename CDElementwise_>
|
||||
struct GroupedConvFwdKernelArgs
|
||||
{
|
||||
|
||||
using ConvToGemmFwdTransformer =
|
||||
TransformConvFwdToGemm<GroupedConvTraitsType_::NDimSpatial,
|
||||
GroupedConvTraitsType_::ConvSpecialization,
|
||||
@@ -40,6 +39,10 @@ struct GroupedConvFwdKernelArgs
|
||||
using CDElementwise = CDElementwise_;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
static_assert(!GroupedConvTraitsType_::ExplicitGemm ||
|
||||
GroupedConvTraitsType_::NumGroupsToMerge == 1,
|
||||
"Explicit GEMM does not support merging convolution groups!");
|
||||
|
||||
template <
|
||||
typename InLay = typename GroupedConvTraitsType_::InLayout,
|
||||
typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
|
||||
@@ -71,11 +74,6 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
// GemmM will be set after Split-N calculation
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0];
|
||||
GemmBatch = args.G_;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
@@ -100,13 +98,14 @@ struct GroupedConvFwdKernelArgs
|
||||
c_grid_desc_m_n =
|
||||
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
group_stride_a = args.C_ * NumGroupsToMerge;
|
||||
group_stride_b = args.K_ * args.C_ * NumGroupsToMerge *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
group_stride_c = args.K_;
|
||||
group_stride_c = args.K_ * NumGroupsToMerge;
|
||||
|
||||
// Initialize Split-N support fields for 1D convolution (NWGC layout)
|
||||
// Get the actual split N from transformer
|
||||
@@ -121,8 +120,20 @@ struct GroupedConvFwdKernelArgs
|
||||
input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0];
|
||||
output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0];
|
||||
|
||||
// Update GemmM to use split N (not original N)
|
||||
GemmM = n_per_split * args.output_spatial_lengths_[0];
|
||||
GemmM = a_grid_desc_m_k.get_length(number<0>{});
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge);
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split
|
||||
<< ", number of N splits: " << n_splits
|
||||
<< ", input_batch_stride: " << input_batch_stride
|
||||
<< ", output_batch_stride: " << output_batch_stride
|
||||
<< ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -163,11 +174,6 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
// Note: GemmM will be set after Split-N calculation
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
|
||||
GemmBatch = args.G_;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
@@ -192,13 +198,14 @@ struct GroupedConvFwdKernelArgs
|
||||
c_grid_desc_m_n =
|
||||
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
group_stride_a = args.C_ * NumGroupsToMerge;
|
||||
group_stride_b = args.K_ * args.C_ * NumGroupsToMerge *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
group_stride_c = args.K_;
|
||||
group_stride_c = args.K_ * NumGroupsToMerge;
|
||||
|
||||
// Initialize Split-N support fields for 2D convolution (NHWGC layout)
|
||||
// Get the actual split N from transformer
|
||||
@@ -213,8 +220,20 @@ struct GroupedConvFwdKernelArgs
|
||||
output_batch_stride =
|
||||
args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
|
||||
// Update GemmM to use split N (not original N)
|
||||
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
GemmM = a_grid_desc_m_k.get_length(number<0>{});
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge);
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split
|
||||
<< ", number of N splits: " << n_splits
|
||||
<< ", input_batch_stride: " << input_batch_stride
|
||||
<< ", output_batch_stride: " << output_batch_stride
|
||||
<< ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -262,12 +281,6 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
// Note: GemmM will be set after Split-N calculation
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] *
|
||||
args.filter_spatial_lengths_[2];
|
||||
GemmBatch = args.G_;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
@@ -292,13 +305,14 @@ struct GroupedConvFwdKernelArgs
|
||||
c_grid_desc_m_n =
|
||||
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
group_stride_a = args.C_ * NumGroupsToMerge;
|
||||
group_stride_b = args.K_ * args.C_ * NumGroupsToMerge *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
group_stride_c = args.K_;
|
||||
group_stride_c = args.K_ * NumGroupsToMerge;
|
||||
|
||||
// Initialize Split-N support fields for 3D convolution (NDHWGC layout)
|
||||
// Get the actual split N from transformer
|
||||
@@ -313,11 +327,21 @@ struct GroupedConvFwdKernelArgs
|
||||
output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] *
|
||||
args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];
|
||||
|
||||
// Update GemmM to use split N (not original N)
|
||||
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
|
||||
args.output_spatial_lengths_[2];
|
||||
}
|
||||
GemmM = a_grid_desc_m_k.get_length(number<0>{});
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge);
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split
|
||||
<< ", number of N splits: " << n_splits
|
||||
<< ", input_batch_stride: " << input_batch_stride
|
||||
<< ", output_batch_stride: " << output_batch_stride
|
||||
<< ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl;
|
||||
}
|
||||
}
|
||||
using AGridDescMK = remove_cvref_t<
|
||||
decltype(ConvToGemmFwdTransformer{}
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
|
||||
@@ -343,6 +367,7 @@ struct GroupedConvFwdKernelArgs
|
||||
index_t GemmN;
|
||||
index_t GemmK;
|
||||
index_t GemmBatch;
|
||||
index_t NumGroupsToMerge;
|
||||
|
||||
const void* in_ptr;
|
||||
const void* wei_ptr;
|
||||
@@ -567,13 +592,25 @@ struct GroupedConvolutionForwardKernel
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// clang-format off
|
||||
return concat('_', "grouped_convolution_forward",
|
||||
gemm_prec_str<InDataType, WeiDataType>(),
|
||||
"gemm",
|
||||
GemmPipeline::GetName(),
|
||||
"epilogue",
|
||||
EpiloguePipeline::GetName());
|
||||
if (NumGroupsToMerge > 1) {
|
||||
return concat('_', "grouped_convolution_forward",
|
||||
gemm_prec_str<InDataType, WeiDataType>(),
|
||||
"gemm",
|
||||
GemmPipeline::GetName(),
|
||||
"epilogue",
|
||||
EpiloguePipeline::GetName(),
|
||||
"merge",
|
||||
NumGroupsToMerge);
|
||||
} else {
|
||||
return concat('_', "grouped_convolution_forward",
|
||||
gemm_prec_str<InDataType, WeiDataType>(),
|
||||
"gemm",
|
||||
GemmPipeline::GetName(),
|
||||
"epilogue",
|
||||
EpiloguePipeline::GetName());
|
||||
}
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -742,6 +779,16 @@ struct GroupedConvolutionForwardKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
|
||||
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
|
||||
{
|
||||
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -470,10 +470,10 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeADescriptor_M_K() const
|
||||
{
|
||||
IndexType NStrideTensorA_ = Wi_ * G_ * C_;
|
||||
IndexType WiStride_ = G_ * C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
|
||||
IndexType GStrideTensorA_ = C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
@@ -701,11 +701,11 @@ struct TransformConvFwdToGemm
|
||||
CK_TILE_HOST auto MakeADescriptor_M_K() const
|
||||
|
||||
{
|
||||
IndexType NStrideTensorA_ = Hi_ * Wi_ * G_ * C_;
|
||||
IndexType HiStride_ = Wi_ * G_ * C_;
|
||||
IndexType WiStride_ = G_ * C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
|
||||
IndexType GStrideTensorA_ = C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
@@ -960,12 +960,12 @@ struct TransformConvFwdToGemm
|
||||
CK_TILE_HOST auto MakeADescriptor_M_K() const
|
||||
|
||||
{
|
||||
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
|
||||
IndexType DiStride_ = Hi_ * Wi_ * G_ * C_;
|
||||
IndexType HiStride_ = Wi_ * G_ * C_;
|
||||
IndexType WiStride_ = G_ * C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
|
||||
IndexType GStrideTensorA_ = C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
@@ -1289,9 +1289,9 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeBDescriptor_N_K() const
|
||||
{
|
||||
IndexType CStrideTensorB_ = 1;
|
||||
IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_;
|
||||
IndexType GStrideTensorB_ = K_ * Z_ * Y_ * X_ * C_;
|
||||
IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_;
|
||||
IndexType CStrideTensorB_ = 1;
|
||||
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
|
||||
{
|
||||
@@ -1356,10 +1356,10 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
IndexType NStrideTensorC_ = Wo_ * G_ * K_;
|
||||
IndexType WoStride_ = G_ * K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
|
||||
IndexType GStrideTensorC_ = K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
|
||||
const IndexType NDoHoWo = N_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
@@ -1417,11 +1417,11 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
IndexType NStrideTensorC_ = Ho_ * Wo_ * G_ * K_;
|
||||
IndexType HoStride_ = Wo_ * G_ * K_;
|
||||
IndexType WoStride_ = G_ * K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
|
||||
IndexType GStrideTensorC_ = K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
|
||||
const IndexType NDoHoWo = N_ * Ho_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
@@ -1482,12 +1482,12 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
|
||||
IndexType DoStride_ = Ho_ * Wo_ * G_ * K_;
|
||||
IndexType HoStride_ = Wo_ * G_ * K_;
|
||||
IndexType WoStride_ = G_ * K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
|
||||
IndexType GStrideTensorC_ = K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
|
||||
const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
|
||||
Reference in New Issue
Block a user