Merge commit 'cafaeb6b7bac4e18b0a5341cd14f54224292a0c9' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-29 15:12:59 +00:00
parent 83b2a1d876
commit 26e9ec020f
29 changed files with 1970 additions and 282 deletions

View File

@@ -26,7 +26,8 @@ struct GroupedConvBwdWeightKernelArgs
GroupedConvTraitsType_::ConvSpecialization,
GroupedConvTraitsType_::VectorSizeA,
GroupedConvTraitsType_::VectorSizeB,
GroupedConvTraitsType_::VectorSizeC>;
GroupedConvTraitsType_::VectorSizeC,
GroupedConvTraitsType_::NumGroupsToMerge>;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
template <
@@ -84,9 +85,11 @@ struct GroupedConvBwdWeightKernelArgs
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
group_stride_a = args.K_; // A: Out NWGK
group_stride_b = args.C_; // B: In NWGC
group_stride_c = args.K_ * args.C_ * // C: Wei GKXC
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NWGK
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NWGC
group_stride_c = args.K_ * args.C_ // C: Wei GKXC
* NumGroupsPerBatch *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
@@ -95,7 +98,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = args.G_;
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
}
}
template <
@@ -160,9 +170,11 @@ struct GroupedConvBwdWeightKernelArgs
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
group_stride_a = args.K_; // A: Out NHWGK
group_stride_b = args.C_; // B: In NHWGC
group_stride_c = args.K_ * args.C_ * // C: Wei GKYXC
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NHWGK
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NHWGC
group_stride_c = args.K_ * args.C_ // C: Wei GKYXC
* NumGroupsPerBatch *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
@@ -171,7 +183,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = args.G_;
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
}
}
template <
@@ -243,9 +262,11 @@ struct GroupedConvBwdWeightKernelArgs
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
group_stride_a = args.K_; // A: Out NDHWGK
group_stride_b = args.C_; // B: In NDHWGC
group_stride_c = args.K_ * args.C_ * // C: wEI GKZYXC
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NDHWGK
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NDHWGC
group_stride_c = args.K_ * args.C_ // C: Wei GKZYXC
* NumGroupsPerBatch *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
@@ -254,7 +275,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = args.G_;
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
}
}
using ABCGridDescs = remove_cvref_t<
@@ -279,6 +307,7 @@ struct GroupedConvBwdWeightKernelArgs
index_t GemmN;
index_t GemmK;
index_t GemmBatch;
index_t NumGroupsPerBatch;
const void* out_ptr;
const void* in_ptr;
@@ -317,10 +346,9 @@ struct GroupedConvBwdWeightKernelArgs
/// the policy is responsible for definition of all necessary data layouts and thread's
/// work distribution.
///
/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution.
/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution.
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into
/// the
/// output data tile to be calculated. It determines the
/// the output data tile to be calculated. It determines the
/// workgroup to data relationship (or in other words - which
/// data would be processed and calculated by which workgroup).
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
@@ -382,8 +410,12 @@ struct GroupedConvolutionBackwardWeightKernel
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
if (NumGroupsToMerge > 1)
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName(), "merge", NumGroupsToMerge);
else
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
// clang-format on
}
@@ -402,6 +434,12 @@ struct GroupedConvolutionBackwardWeightKernel
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "MPerBlock: " << number<TilePartitioner::MPerBlock>{} << std::endl;
std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
}
return GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
}
@@ -442,11 +480,14 @@ struct GroupedConvolutionBackwardWeightKernel
{
return [&]() {
if(kargs.k_batch > 1)
hipGetErrorString(hipMemsetAsync(kargs.wei_ptr,
0,
kargs.GemmBatch * kargs.GemmM * kargs.GemmN *
sizeof(WeiDataType),
s.stream_id_));
{
// Total number of convolution groups (ConvG) = GemmBatch * NumGroupsPerBatch
// since we require that ConvG % NumGroupsPerBatch == 0.
const auto wei_size =
kargs.GemmBatch * kargs.GemmM * kargs.GemmN * kargs.NumGroupsPerBatch;
hipGetErrorString(
hipMemsetAsync(kargs.wei_ptr, 0, wei_size * sizeof(WeiDataType), s.stream_id_));
}
};
}
@@ -527,7 +568,8 @@ struct GroupedConvolutionBackwardWeightKernel
// Check access per C
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
{
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
CK_TILE_ERROR("Conv C is not a multiple of vector load size for "
"input image!");
return false;
}
}
@@ -559,7 +601,8 @@ struct GroupedConvolutionBackwardWeightKernel
{
if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
{
CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
CK_TILE_ERROR("Conv K is not a multiple of vector store size "
"for output image!");
return false;
}
}
@@ -569,6 +612,18 @@ struct GroupedConvolutionBackwardWeightKernel
return false;
}
if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
{
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
{
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
return false;
}
// TODO: Should we also check that GemmM <= MPerBlock and GemmN <= NPerBlock?
}
return true;
}
@@ -654,6 +709,16 @@ struct GroupedConvolutionBackwardWeightKernel
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
}
/**
* @brief Create views to the data that each workgroup will process.
*
* @param views padded views of A, B, D and C tensors
* @param i_m block m-index
* @param i_n block n-index
* @param i_k block k-index
*
* @return tuple of tile windows for A, B, D and C tensors
*/
template <typename PadView>
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
const index_t i_m,
@@ -818,7 +883,6 @@ struct GroupedConvolutionBackwardWeightKernel
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)

View File

@@ -29,6 +29,7 @@ struct GroupedConvFwdKernelArgs
GroupedConvTraitsType_::VectorSizeA,
GroupedConvTraitsType_::VectorSizeB,
GroupedConvTraitsType_::VectorSizeC,
GroupedConvTraitsType_::NumGroupsToMerge,
true>; // Split N enabled
using CDElementwise = typename GroupedConvTraitsType_::CDElementwise;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;