mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
[CK_Tile] Merge multiple convolution groups into a single GEMM batch (#2986)
* Fix compilation of the grouped conv examples. * Fix grouped conv bwd weight example output in CK Tile. * Add number of groups to merge to ck tile grouped gemm example. * Initial set of tests for TransformConvBwdWeightToGemm. * Added unit tests for TransformConvBwdWeightToGemm conv groups are merged. * WIP: Tensor transformations. * Add unit tests for coordinate transforms. * Fully working conv group merging for TransformConvBwdWeightToGemm. * WIP: Merged conv groups offset calculation. * Added unit tests for tensor view. * WIP: Merged conv groups epilogue. * Enable running multiple conv groups per batch. * Add tests for tile_distribution_encoding. * Change example to match optimally depthwise convolution with merged groups. * Add more tests for tensor view. * Integration test for reading diagonal blocks from grouped distributed tensor. * Improved integration test. * Improve test for accessing diagonal blocks. * Added integration test for cshuffle epilogue LDS tile distribution. * Add more logging. * Increase the max number of reported errors. * WIP: merged conv groups GEMM epilogue changes. * LDS to global memory copy. * Fix tile window size for c block. * Integration test for CShuffle epilogue. * Improved CShuffle test. * WIP: Separate epilogue for merged conv groups. * Tile example parameters changes to match depthwise conv. * Offset fixes. * Epilogue fixes. * Working baseline for depthwise convolution with merged conv groups. * Fix build. * Initial unit tests for tensor descriptor. * Add one more unit test for tensor view. * WIP: LDS to global mem transfer using CK tile tensor descriptor and tile distribution encoding. * Fully functional LDS to global mem transfer using tensor descriptor and tile distribution encoding. * Add more comments, disable debug code. * Remove debug and other dead code. * Code clean-up for bwd tensor transformations. * Enable running multiple GEMM batches of merged conv groups. * Add compile check for assumed row-major layout. 
* Fix strides in 1D conv to gemm transformation. * WIP: Simplify conv to gemm transformations and handle K > 1 and C > 1 cases. * Fix case k > 1 and c=1. * Remove debug code. * Make MPerGroup and NPerGroup template parameters. * Add additional check for non-supported c > 1 case. * WIP: Put back the generic tensor descriptors for convolutions. * Fix tensor descriptors. * Remove the obsolete template parameters. * Add more instances. * Fix bugs in merged conv groups tensor descriptors. * Fix tensor descriptors for merged conv groups when K > 1. * Remove debug output. * Remove dead code. * Fix merge conflicts. * Code clean-up. * Remove unused code. * Run clang-formatting. * Remove debug prints and obsolete tests. * Check that number of convolution groups is multiple of merged groups. * Fix build after removing obsolete functionality. * Remove obsolete enumeration. * Fix new unit projects. * Remove unnecessary includes. * Fix passing the number of merged groups. * Remove unrelated tests. * Fix IsSupportedArgument for bwd weight conv kernel. * Fix clang formatting. * Fix the bwd weight conv to gemm mapping for num merged groups > 1. * GEMM config for conv group merging. * Fix clang-formatting. * Remove obsolete comment. * Fix typos in comment strings. * Increase the max number of reported errors when testing against reference implementation. * Rename gemm_config to conv_config. * Rename GemmConfig to ConvConfig and move NumGroupsToMerge into ConvConfig. * Change num_groups_to_merge to a boolean flag in the ck tile grouped conv example. * Run clang-format. * Add number of merged groups into kernel name string. * Remove group merging flag from CK Tile grouped conv example.
This commit is contained in:
@@ -26,7 +26,8 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GroupedConvTraitsType_::ConvSpecialization,
|
||||
GroupedConvTraitsType_::VectorSizeA,
|
||||
GroupedConvTraitsType_::VectorSizeB,
|
||||
GroupedConvTraitsType_::VectorSizeC>;
|
||||
GroupedConvTraitsType_::VectorSizeC,
|
||||
GroupedConvTraitsType_::NumGroupsToMerge>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
template <
|
||||
@@ -84,9 +85,11 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
b_grid_desc_k_n = grid_descs.at(number<1>{});
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
group_stride_a = args.K_; // A: Out NWGK
|
||||
group_stride_b = args.C_; // B: In NWGC
|
||||
group_stride_c = args.K_ * args.C_ * // C: Wei GKXC
|
||||
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NWGK
|
||||
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NWGC
|
||||
group_stride_c = args.K_ * args.C_ // C: Wei GKXC
|
||||
* NumGroupsPerBatch *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
@@ -95,7 +98,14 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmM = a_grid_desc_k_m.get_length(number<1>{});
|
||||
GemmN = b_grid_desc_k_n.get_length(number<1>{});
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = args.G_;
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -160,9 +170,11 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
b_grid_desc_k_n = grid_descs.at(number<1>{});
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
group_stride_a = args.K_; // A: Out NHWGK
|
||||
group_stride_b = args.C_; // B: In NHWGC
|
||||
group_stride_c = args.K_ * args.C_ * // C: Wei GKYXC
|
||||
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NHWGK
|
||||
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NHWGC
|
||||
group_stride_c = args.K_ * args.C_ // C: Wei GKYXC
|
||||
* NumGroupsPerBatch *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
@@ -171,7 +183,14 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmM = a_grid_desc_k_m.get_length(number<1>{});
|
||||
GemmN = b_grid_desc_k_n.get_length(number<1>{});
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = args.G_;
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -243,9 +262,11 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
b_grid_desc_k_n = grid_descs.at(number<1>{});
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
group_stride_a = args.K_; // A: Out NDHWGK
|
||||
group_stride_b = args.C_; // B: In NDHWGC
|
||||
group_stride_c = args.K_ * args.C_ * // C: wEI GKZYXC
|
||||
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NDHWGK
|
||||
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NDHWGC
|
||||
group_stride_c = args.K_ * args.C_ // C: Wei GKZYXC
|
||||
* NumGroupsPerBatch *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
@@ -254,7 +275,14 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmM = a_grid_desc_k_m.get_length(number<1>{});
|
||||
GemmN = b_grid_desc_k_n.get_length(number<1>{});
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = args.G_;
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
using ABCGridDescs = remove_cvref_t<
|
||||
@@ -279,6 +307,7 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
index_t GemmN;
|
||||
index_t GemmK;
|
||||
index_t GemmBatch;
|
||||
index_t NumGroupsPerBatch;
|
||||
|
||||
const void* out_ptr;
|
||||
const void* in_ptr;
|
||||
@@ -317,10 +346,9 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
/// the policy is responsible for definition of all necessary data layouts and thread's
|
||||
/// work distribution.
|
||||
///
|
||||
/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution.
|
||||
/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution.
|
||||
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into
|
||||
/// the
|
||||
/// output data tile to be calculated. It determines the
|
||||
/// the output data tile to be calculated. It determines the
|
||||
/// workgroup to data relationship (or in other words - which
|
||||
/// data would be processed and calculated by which workgroup).
|
||||
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
|
||||
@@ -382,8 +410,12 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// clang-format off
|
||||
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
|
||||
if (NumGroupsToMerge > 1)
|
||||
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName(), "merge", NumGroupsToMerge);
|
||||
else
|
||||
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -402,6 +434,12 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "MPerBlock: " << number<TilePartitioner::MPerBlock>{} << std::endl;
|
||||
std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
|
||||
std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
|
||||
}
|
||||
return GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
|
||||
}
|
||||
|
||||
@@ -442,11 +480,14 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
{
|
||||
return [&]() {
|
||||
if(kargs.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(kargs.wei_ptr,
|
||||
0,
|
||||
kargs.GemmBatch * kargs.GemmM * kargs.GemmN *
|
||||
sizeof(WeiDataType),
|
||||
s.stream_id_));
|
||||
{
|
||||
// Total number of convolution groups (ConvG) = GemmBatch * NumGroupsPerBatch
|
||||
// since we require that ConvG % NumGroupsPerBatch == 0.
|
||||
const auto wei_size =
|
||||
kargs.GemmBatch * kargs.GemmM * kargs.GemmN * kargs.NumGroupsPerBatch;
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(kargs.wei_ptr, 0, wei_size * sizeof(WeiDataType), s.stream_id_));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -527,7 +568,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
// Check access per C
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for "
|
||||
"input image!");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -559,7 +601,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
{
|
||||
if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
|
||||
CK_TILE_ERROR("Conv K is not a multiple of vector store size "
|
||||
"for output image!");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -569,6 +612,18 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
|
||||
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
|
||||
{
|
||||
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: Should we also check that GemmM <= MPerBlock and GemmN <= NPerBlock?
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -654,6 +709,16 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create views to the data that each workgroup will process.
|
||||
*
|
||||
* @param views padded views of A, B, D and C tensors
|
||||
* @param i_m block m-index
|
||||
* @param i_n block n-index
|
||||
* @param i_k block k-index
|
||||
*
|
||||
* @return tuple of tile windows for A, B, D and C tensors
|
||||
*/
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
|
||||
const index_t i_m,
|
||||
@@ -818,7 +883,6 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
|
||||
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
|
||||
@@ -29,6 +29,7 @@ struct GroupedConvFwdKernelArgs
|
||||
GroupedConvTraitsType_::VectorSizeA,
|
||||
GroupedConvTraitsType_::VectorSizeB,
|
||||
GroupedConvTraitsType_::VectorSizeC,
|
||||
GroupedConvTraitsType_::NumGroupsToMerge,
|
||||
true>; // Split N enabled
|
||||
using CDElementwise = typename GroupedConvTraitsType_::CDElementwise;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
Reference in New Issue
Block a user