Ding, Yi
2026-03-11 23:03:20 -04:00
commit e6cd3f1e3f
6330 changed files with 1132789 additions and 0 deletions


@@ -0,0 +1,981 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
#endif
namespace ck_tile {
template <typename... Args>
CK_TILE_HOST void LogInfo(Args&&... args) noexcept
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_INFO(std::forward<Args>(args)...);
}
}
/// @brief The Grouped Convolution Backward Weight kernel device arguments.
template <typename GroupedConvTraitsType_>
struct GroupedConvBwdWeightKernelArgs
{
using ConvToGemmTransformer =
TransformConvBwdWeightToGemm<GroupedConvTraitsType_::NDimSpatial,
GroupedConvTraitsType_::ConvSpecialization,
GroupedConvTraitsType_::VectorSizeA,
GroupedConvTraitsType_::VectorSizeB,
GroupedConvTraitsType_::VectorSizeC,
GroupedConvTraitsType_::NumGroupsToMerge>;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
template <
typename InLay = typename GroupedConvTraitsType_::InLayout,
typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
typename OutLay = typename GroupedConvTraitsType_::OutLayout,
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
bool>::type = false>
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.output_spatial_lengths_[0])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
{
ds_ptr[d] = args.ds_ptr[d];
}
out_ptr = args.out_ptr;
ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// Tuple of (A, B, C) grid descriptors.
auto grid_descs =
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
GroupedConvTraitsType_::NDimSpatial>();
a_grid_desc_k_m = grid_descs.at(number<0>{});
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NWGK
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NWGC
group_stride_c = args.K_ * args.C_ // C: Wei GKXC
* NumGroupsPerBatch *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<index_t>());
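// Implicit GEMM mapping for backward weight (stated loosely, per merged-group batch):
// C[GemmM, GemmN] is the weight gradient [K, C * prod(filter spatial)],
// A[GemmK, GemmM] is the output gradient viewed as [N * prod(output spatial), K], and
// B[GemmK, GemmN] is the im2col'd input [N * prod(output spatial), C * prod(filter spatial)].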
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
k_batch = args.k_batch;
LogInfo("GemmM: ",
GemmM,
", GemmN: ",
GemmN,
", GemmK: ",
GemmK,
", GemmBatch: ",
GemmBatch,
", NumGroupsPerBatch: ",
NumGroupsPerBatch,
", k_batch: ",
k_batch);
}
template <
typename InLay = typename GroupedConvTraitsType_::InLayout,
typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
typename OutLay = typename GroupedConvTraitsType_::OutLayout,
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
bool>::type = false>
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.output_spatial_lengths_[0]),
static_cast<index_t>(args.output_spatial_lengths_[1])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
static_cast<index_t>(args.conv_filter_strides_[1])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
static_cast<index_t>(args.conv_filter_dilations_[1])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
static_cast<index_t>(args.input_left_pads_[1])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1])};
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
{
ds_ptr[d] = args.ds_ptr[d];
}
out_ptr = args.out_ptr;
ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// Tuple of (A, B, C) grid descriptors.
auto grid_descs =
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
GroupedConvTraitsType_::NDimSpatial>();
a_grid_desc_k_m = grid_descs.at(number<0>{});
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NHWGK
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NHWGC
group_stride_c = args.K_ * args.C_ // C: Wei GKYXC
* NumGroupsPerBatch *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<index_t>());
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
k_batch = args.k_batch;
LogInfo("GemmM: ",
GemmM,
", GemmN: ",
GemmN,
", GemmK: ",
GemmK,
", GemmBatch: ",
GemmBatch,
", NumGroupsPerBatch: ",
NumGroupsPerBatch,
", k_batch: ",
k_batch);
}
template <
typename InLay = typename GroupedConvTraitsType_::InLayout,
typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
typename OutLay = typename GroupedConvTraitsType_::OutLayout,
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
bool>::type = false>
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1]),
static_cast<index_t>(args.input_spatial_lengths_[2])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1]),
static_cast<index_t>(args.filter_spatial_lengths_[2])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.output_spatial_lengths_[0]),
static_cast<index_t>(args.output_spatial_lengths_[1]),
static_cast<index_t>(args.output_spatial_lengths_[2])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
static_cast<index_t>(args.conv_filter_strides_[1]),
static_cast<index_t>(args.conv_filter_strides_[2])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
static_cast<index_t>(args.conv_filter_dilations_[1]),
static_cast<index_t>(args.conv_filter_dilations_[2])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
static_cast<index_t>(args.input_left_pads_[1]),
static_cast<index_t>(args.input_left_pads_[2])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
{
ds_ptr[d] = args.ds_ptr[d];
}
out_ptr = args.out_ptr;
ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// Tuple of (A, B, C) grid descriptors.
auto grid_descs =
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
GroupedConvTraitsType_::NDimSpatial>();
a_grid_desc_k_m = grid_descs.at(number<0>{});
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NDHWGK
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NDHWGC
group_stride_c = args.K_ * args.C_ // C: Wei GKZYXC
* NumGroupsPerBatch *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<index_t>());
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
k_batch = args.k_batch;
LogInfo("GemmM: ",
GemmM,
", GemmN: ",
GemmN,
", GemmK: ",
GemmK,
", GemmBatch: ",
GemmBatch,
", NumGroupsPerBatch: ",
NumGroupsPerBatch,
", k_batch: ",
k_batch);
}
using ABCGridDescs = remove_cvref_t<
decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>;
using AGridDescKM = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
using BGridDescKN = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;
using CGridDescMN = remove_cvref_t<decltype(ABCGridDescs{}[number<2>{}])>;
static constexpr index_t NonSpatialDims = 3;
array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> in_g_n_c_wis_lengths;
array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> wei_g_k_c_xs_lengths;
array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> out_g_n_k_wos_lengths;
array<index_t, GroupedConvTraitsType_::NDimSpatial> conv_filter_strides;
array<index_t, GroupedConvTraitsType_::NDimSpatial> conv_filter_dilations;
array<index_t, GroupedConvTraitsType_::NDimSpatial> input_left_pads;
array<index_t, GroupedConvTraitsType_::NDimSpatial> input_right_pads;
index_t k_batch;
index_t GemmM;
index_t GemmN;
index_t GemmK;
index_t GemmBatch;
index_t NumGroupsPerBatch;
const void* out_ptr;
const void* in_ptr;
std::array<const void*, NumDTensor> ds_ptr;
void* wei_ptr;
AGridDescKM a_grid_desc_k_m;
BGridDescKN b_grid_desc_k_n;
CGridDescMN c_grid_desc_m_n;
long_index_t group_stride_a;
long_index_t group_stride_b;
long_index_t group_stride_c;
};
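// A minimal host-side usage sketch (names are hypothetical; ConvParam setup elided):
//   using Kargs = GroupedConvBwdWeightKernelArgs<MyGroupedConvTraits>;
//   GroupedConvBwdWeightHostArgs host_args(
//       conv_param, in_dev, wei_dev, {}, out_dev, /*k_batch=*/-1); // -1: autodeduce split-K
//   Kargs kargs(host_args); // resolves to the ctor matching the traits' layouts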
/// @brief The Grouped Convolution Backward Weight kernel template.
///
/// @paragraph Overview Overview
/// This class provides the grouped convolution backward weight kernel template. By
/// semantically dividing the Implicit GEMM algorithm into the following parts, we achieve
/// a flexible, versatile and robust kernel implementation.
///
/// @li @b Prolog - The start of the GEMM kernel implementation in the @ref operator()
/// "function call operator", which determines the work scope of each workgroup.
/// @li @b GemmPipeline - The core part, the @a "heart", of the matrix multiplication
/// algorithm. This is where each workgroup loads data from global memory and
/// carries out dot products.
/// @li @b Epilogue - The @a "final" part of the matrix multiplication implementation,
/// responsible for storing results to global memory. This is also the place where
/// any additional operator fusion may take place.
///
/// Additionally, both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_
/// "EpiloguePipeline" are parameterized with a so-called @a Policy, which determines all
/// internal details of those functional parts. You can think of the gemm and epilogue
/// pipelines as providing the control-flow logic, steered by their policies. Moreover,
/// the policy is responsible for defining all necessary data layouts and the threads'
/// work distribution.
///
/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution.
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into
/// the output data tile to be calculated. It determines the
/// workgroup to data relationship (or in other words - which
/// data would be processed and calculated by which workgroup).
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
/// multiplication. This class should provide implementation of
/// data loading from global memory and performing block-wise
/// matrix multiplication. You can think of it as the work done
/// from a single workgroup's point of view.
/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix
/// multiplication implementation. It is responsible for storing
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
/// the output C tensor in global memory.
template <typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct GroupedConvolutionBackwardWeightKernel
{
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
static constexpr ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using GemmALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using GemmBLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using GemmCLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using InLayout = remove_cvref_t<typename GroupedConvTraitsType_::InLayout>;
using WeiLayout = remove_cvref_t<typename GroupedConvTraitsType_::WeiLayout>;
using OutLayout = remove_cvref_t<typename GroupedConvTraitsType_::OutLayout>;
using DsLayout = remove_cvref_t<typename GroupedConvTraitsType_::DsLayout>;
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
using OutDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using InDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
using WeiDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using GroupedConvBwdWeightKernelArgsSpecialized =
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>;
static constexpr bool IsSplitKSupported = true;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
"Not supported!");
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
GroupedConvTraitsType_::NumGroupsToMerge == 1,
"Not supported!");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
return concat('_', "grouped_convolution_backward_weight",
gemm_prec_str<InDataType, WeiDataType>(),
InLayout::name,
WeiLayout::name,
OutLayout::name,
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(),
getConvSpecializationString(ConvSpecialization),
"MergedGroups",
NumGroupsToMerge,
"SplitImage",
EnableSplitImage,
"ExplicitGemm",
GroupedConvTraitsType_::ExplicitGemm
);
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
#ifdef CK_EXPERIMENTAL_BUILDER
CK_TILE_HOST std::string GetInstanceString() const
{
static_assert(ck_tile::reflect::HasInstanceTraits<GroupedConvolutionBackwardWeightKernel>,
"Specialization of instance_traits not found. Please check that a "
"specialization exists in file "
"ck_tile/builder/reflect/"
"instance_traits_tile_grouped_convolution_backward_weight.hpp "
"for the given template parameters.");
return ck_tile::reflect::instance_string<GroupedConvolutionBackwardWeightKernel>();
}
#endif
CK_TILE_HOST static constexpr auto
GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
return dim3(
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
}
CK_TILE_HOST static constexpr auto BlockSize()
{
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
}
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
{
LogInfo("MPerBlock: ",
number<TilePartitioner::MPerBlock>{},
", NPerBlock: ",
number<TilePartitioner::NPerBlock>{},
", KPerBlock: ",
number<TilePartitioner::KPerBlock>{});
auto kernel_args = GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
using KernelImpl = GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_>;
// Negative k_batch value: split-K autodeduction.
if(kernel_args.k_batch < 0)
{
const auto optimal_split_k =
calculate_optimal_k_batch<GemmPipeline_::BlockSize, KernelImpl, TilePartitioner_>(
kernel_args);
kernel_args.k_batch = optimal_split_k;
}
return kernel_args;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
if constexpr(GemmPipeline_::Async)
{
if(get_device_name() != "gfx950")
{
return false;
}
}
if(kargs.k_batch < 1)
{
LogInfo("k_batch must be at least one. Ensure argument is created via MakeKernelArgs.");
return false;
}
if constexpr(!std::is_same_v<typename EpiloguePipeline::ODataType, float> &&
!std::is_same_v<typename EpiloguePipeline::ODataType, double>)
{
// The epilogue performs atomic add related to split-K using the ODataType.
// If the type is less accurate than float, large split-K values may lead to
// accuracy issues. Hence, we limit the maximum split-K value to 128 in such cases.
if(kargs.k_batch > 128)
{
LogInfo("For epilogue output data type that is not float/double, we must have "
"k_batch <= 128.");
return false;
}
}
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
if(kargs.k_batch != 1)
{
LogInfo("Conditions not met for K_batch > 1: VectorSizeC must be a multiple of 2 "
"for fp16/bf16 when K_batch > 1.",
"Now k_batch is ",
kargs.k_batch,
", VectorSizeC is ",
GroupedConvTraitsType_::VectorSizeC);
return false;
}
}
if(kargs.GemmK < TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
{
LogInfo("KBatch is too large, part of GPU wouldn't be utilized! GemmK: ",
kargs.GemmK,
", BlockGemmShape K: ",
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}),
", k_batch: ",
kargs.k_batch);
return false;
}
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
// check ConvSpecialization
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
const index_t ConvStride = kargs.conv_filter_strides[i];
const index_t LeftPad = kargs.input_left_pads[i];
const index_t RightPad = kargs.input_right_pads[i];
if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
LogInfo("For Filter1x1Stride1Pad0 specialization, all spatial dimensions must "
"be 1, stride must be 1, and padding must be 0. Now for dimension ",
i,
": SpatialDim is ",
SpatialDim,
", ConvStride is ",
ConvStride,
", LeftPad is ",
LeftPad,
", RightPad is ",
RightPad);
return false;
}
}
}
else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
const index_t LeftPad = kargs.input_left_pads[i];
const index_t RightPad = kargs.input_right_pads[i];
if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
{
LogInfo("For Filter1x1Pad0 specialization, all spatial dimensions must be 1 "
"and padding must be 0. Now for dimension ",
i,
": SpatialDim is ",
SpatialDim,
", LeftPad is ",
LeftPad,
", RightPad is ",
RightPad);
return false;
}
}
}
else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
{
if(ConvC != 1)
{
LogInfo("For Filter3x3 specialization, ConvC must be 1. Now ConvC is ", ConvC);
return false;
}
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
if(filter_spatial_dim != I3)
{
LogInfo("For Filter3x3 specialization, all spatial dimensions of the filter "
"must be 3. Now for dimension ",
i,
", filter_spatial_dim is ",
filter_spatial_dim);
return false;
}
}
}
if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
LogInfo("ExplicitGemm is only supported for Filter1x1Stride1Pad0 specialization.");
return false;
}
namespace ctc = tensor_layout::convolution;
if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
std::is_same_v<InLayout, ctc::NDHWGC>)
{
// Check access per C
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
{
LogInfo("Conv C is not a multiple of vector load size for input! ConvC: ",
ConvC,
", VectorSizeB: ",
GroupedConvTraitsType_::VectorSizeB);
return false;
}
}
else
{
LogInfo("Not supported input layout! Now InLayout is ", InLayout::name);
return false;
}
if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
std::is_same_v<WeiLayout, ctc::GKYXC> ||
std::is_same_v<WeiLayout, ctc::GKZYXC>)
{
if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
{
LogInfo("Conv C is not a multiple of vector load size for weight! ConvC: ",
ConvC,
", VectorSizeC: ",
GroupedConvTraitsType_::VectorSizeC);
return false;
}
}
else
{
LogInfo("Not supported weight layout! Now WeiLayout is ", WeiLayout::name);
return false;
}
if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
std::is_same_v<OutLayout, ctc::NHWGK> ||
std::is_same_v<OutLayout, ctc::NDHWGK>)
{
if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
{
LogInfo("Conv K is not a multiple of vector load size for output! ConvK: ",
ConvK,
", VectorSizeA: ",
GroupedConvTraitsType_::VectorSizeA);
return false;
}
}
else
{
LogInfo("Not supported output layout! Now OutLayout is ", OutLayout::name);
return false;
}
if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
{
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
{
LogInfo("Number of groups must be divisible by NumGroupsToMerge! ConvG: ",
ConvG,
", NumGroupsToMerge: ",
GroupedConvTraitsType_::NumGroupsToMerge);
return false;
}
}
return true;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto
MakeCBlockWindow(WeiDataType* c_ptr,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const index_t block_idx_m,
const index_t block_idx_n)
{
const auto& c_tensor_view =
make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr, kargs.c_grid_desc_m_n);
const auto& c_pad_view = pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
return make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{block_idx_m, block_idx_n});
}
CK_TILE_DEVICE static auto
MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const index_t block_idx_m,
const index_t block_idx_n)
{
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
"Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
"Not supported!");
static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, WeiDataType>,
"Not supported!");
return make_tensor_view<address_space_enum::global>(
static_cast<const WeiDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
},
number<NumDTensor>{});
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
},
number<NumDTensor>{});
return generate_tuple(
[&](auto i) {
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{block_idx_m, block_idx_n});
},
number<NumDTensor>{});
}
CK_TILE_DEVICE static auto
MakeBBlockWindow(const InDataType* b_ptr,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const index_t block_idx_n,
const index_t block_idx_k)
{
static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
const auto& b_tensor_view =
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_k_n);
const auto& b_pad_view =
pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{} * kargs.k_batch,
number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
return make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{block_idx_k, block_idx_n});
}
CK_TILE_DEVICE static auto
MakeABlockWindow(const OutDataType* a_ptr,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const index_t block_idx_m,
const index_t block_idx_k)
{
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
const auto& a_tensor_view =
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_k_m);
const auto& a_pad_view =
pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{} * kargs.k_batch,
number<TilePartitioner::MPerBlock>{}),
sequence<true, true>{});
return make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
{block_idx_k, block_idx_m});
}
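// Note: the A and B windows above pad their K extent to KPerBlock * k_batch, so each
// split-K slice spans a whole number of K tiles and the tail slice may safely read
// past GemmK through the padded view.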
/**
 * @brief Runs a single GEMM problem cooperatively by the whole workgroup.
 *
 * @param a_ptr input A (output gradient) pointer
 * @param b_ptr input B (input image) pointer
 * @param ds_ptr array of input D tensor pointers
 * @param c_ptr output C (weight gradient) pointer
 * @param smem_ptr_0 The start memory pointer of the shared memory block.
 * @param kargs Grouped Convolution Backward Weight kernel arguments
 * @param num_loop The number of K-tile iterations to run for this split.
 * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
 * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
 * @param block_idx_k The K dimension tile index of this workgroup's split-K slice.
 */
CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
const InDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
WeiDataType* c_ptr,
void* smem_ptr_0,
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
const index_t num_loop,
const index_t block_idx_m,
const index_t block_idx_n,
const index_t block_idx_k)
{
// Create block windows using helper methods
const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k);
const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k);
const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0);
// Run Epilogue Pipeline with k_batch dispatching
if(kargs.k_batch == 1)
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
else
{
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
}
}
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
{
static_assert(NumDTensor == 0, "Not supported!");
using ExplicitBatchedGemmKernel =
BatchedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{
{{kargs.out_ptr},
{kargs.in_ptr},
{},
kargs.wei_ptr,
kargs.GemmM,
kargs.GemmN,
kargs.GemmK,
{kargs.GemmM * kargs.GemmBatch},
{kargs.GemmN * kargs.GemmBatch},
{},
kargs.GemmN,
kargs.k_batch},
kargs.GemmM,
kargs.GemmN,
kargs.GemmM * kargs.GemmN,
kargs.GemmBatch};
ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
}
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
{
if constexpr(GroupedConvTraitsType_::ExplicitGemm)
{
CallExplicitGemm(kargs);
}
else
{
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
const auto [iM, iN] =
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
const index_t num_loop = amd_wave_read_first_lane(ck_tile::integer_divide_ceil(
kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
const index_t i_k =
amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
// conv_bwd_weight: A = Out, B = In, C = Wei
const OutDataType* a_ptr =
static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
__shared__ char smem_ptr[GetSmemSize()];
// The Async path is compiled only for gfx950 (other archs are rejected in IsSupportedArgument).
if constexpr(GemmPipeline_::Async)
{
#if defined(__gfx950__)
RunGemm(
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
#endif
}
else
{
RunGemm(
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
}
}
}
};
} // namespace ck_tile
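A minimal host-side sketch of the split-K index math used in operator() above; GemmK, KPerBlock and k_batch are illustrative stand-ins, not values taken from this kernel:

#include <cassert>
#include <cstdio>

int main()
{
    const int GemmK = 1000, KPerBlock = 64, k_batch = 3;
    // num_loop = ceil(GemmK / (k_batch * KPerBlock)), as computed in operator().
    const int num_loop = (GemmK + k_batch * KPerBlock - 1) / (k_batch * KPerBlock);
    for(int z = 0; z < k_batch; ++z) // z plays the role of blockIdx.z
    {
        const int i_k = z * num_loop * KPerBlock; // this split's K window start
        std::printf("z=%d covers K [%d, %d)\n", z, i_k, i_k + num_loop * KPerBlock);
    }
    // Together the splits cover at least GemmK; the padded tensor views absorb the overhang.
    assert(k_batch * num_loop * KPerBlock >= GemmK);
    return 0;
}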


@@ -0,0 +1,208 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
// UniversalGemm Policy
struct GroupedConvUniversalPipelineAgBgCrPolicy
: public UniversalGemmBasePolicy<GroupedConvUniversalPipelineAgBgCrPolicy>
{
template <typename Problem,
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using ADataType = OverrideADataType;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
if constexpr(is_a_load_tr<Problem>)
{
// TODO: better lds descriptor for performance
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
make_tuple(number<KPerBlock>{}, number<MPerBlock>{}),
make_tuple(number<MPerBlock>{}, number<1>{}),
number<MPerBlock>{},
number<1>{});
return a_lds_block_desc_0;
}
else
{
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr uint64_t MinLdsLayer = 1ULL;
constexpr auto MLdsLayer =
max(MinLdsLayer,
get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);
constexpr index_t NBanks = get_n_lds_banks();
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
constexpr index_t RowMul = (NBanks == 64) ? 2 : 1;
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
number<MPerBlock / MLdsLayer>{},
number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer * RowMul>{},
number<KPerBlock / KPack * MLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
}
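// The xor_transform above implements a bank-conflict swizzle: XOR-ing the M coordinate
// into the K0 coordinate sends a logical column of the tile to distinct LDS banks, while
// MLdsLayer folds several K slices into one LDS row when KPerBlock alone would leave
// the 128B lines underfilled.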
/**
* @brief Create LDS block descriptor for B tensor.
*
* @tparam Problem Gemm pipeline problem.
* @return B tensor LDS block descriptor.
*/
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(is_b_load_tr<Problem>)
{
// TODO: better lds descriptor for performance
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( //
make_tuple(number<KPerBlock>{}, number<NPerBlock>{}),
make_tuple(number<NPerBlock>{}, number<1>{}),
number<NPerBlock>{},
number<1>{});
return b_lds_block_desc_0;
}
else
{
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto BK0 = number<KPerBlock / KPack>{};
constexpr auto DataTypeSize = sizeof(BDataType);
constexpr uint64_t MinLdsLayer = 1ULL;
constexpr auto NLdsLayer =
max(MinLdsLayer,
get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);
constexpr index_t NBanks = get_n_lds_banks();
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
constexpr index_t RowMul = (NBanks == 64) ? 2 : 1;
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(
BK0 * number<NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer * RowMul>{},
BK0 * number<NLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(number<NLdsLayer>{}, BK0)),
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_bk0_nldslayer_n_bk1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform_v3_division_mod(make_tuple(BK0, number<KPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t vector_size =
DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
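// As read from the dispatch below: without transposed A/B loads a single access
// suffices; with transposed loads each lane's thread_elements move through LDS in
// DS_READ_TR-sized vectors, so the elements-per-lane to vector-size ratio selects
// Single/Double/Quad access.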
constexpr auto wg_attr_num_access =
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_t> ||
sizeof(BDataType) < sizeof(ADataType),
ADataType,
BDataType>;
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
Problem::UseStructuredSparsity,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile
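The xor_transform used by both LDS descriptors above is the classic XOR swizzle. A tiny host-side sketch (8x8 tile of KPack-wide slots, sizes illustrative) shows how XOR-ing the row into the column index makes a logical-column access land on a different physical column, and hence a different bank group, per row:

#include <cstdio>

int main()
{
    const int rows = 8;
    const int logical_col = 0; // every lane reads logical column 0 of the tile
    for(int r = 0; r < rows; ++r)
    {
        const int physical_col = logical_col ^ (r % 8); // analogue of make_xor_transform
        std::printf("row %d -> physical col %d\n", r, physical_col);
    }
    return 0; // each row maps to a distinct column, so no two lanes collide on a bank
}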


@@ -0,0 +1,30 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp" // for CK_TILE_HOST
namespace ck_tile {
enum struct ConvolutionSpecialization
{
Default,
Filter1x1Pad0,
Filter1x1Stride1Pad0,
Filter3x3,
};
CK_TILE_HOST std::string getConvSpecializationString(const ConvolutionSpecialization& s)
{
switch(s)
{
case ConvolutionSpecialization::Default: return "Default";
case ConvolutionSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
case ConvolutionSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
case ConvolutionSpecialization::Filter3x3: return "Filter3x3";
default: return "Unrecognized specialization!";
}
}
} // namespace ck_tile


@@ -0,0 +1,261 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
namespace ck_tile {
enum class GroupedConvDirection
{
FORWARD,
BACKWARD_DATA,
BACKWARD_WEIGHT
};
/// @brief The Grouped Conv kernel host arguments.
///
/// @par Overview
/// This structure is passed to Grouped Convolution kernels when creating the kernel
/// arguments object. It contains all information required to build proper kernel
/// arguments and launch the kernel on the GPU.
template <typename InPtr, typename WeiPtr, typename OutPtr, typename CDElementwise>
struct GroupedConvHostArgs : public conv::ConvParam
{
CK_TILE_HOST GroupedConvHostArgs() = delete;
CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param,
InPtr in_ptr_,
WeiPtr wei_ptr_,
const std::vector<const void*> ds_ptr_,
OutPtr out_ptr_,
index_t k_batch_,
CDElementwise elfunc_ = CDElementwise{})
: conv::ConvParam(conv_param),
in_ptr(in_ptr_),
wei_ptr(wei_ptr_),
ds_ptr(ds_ptr_),
out_ptr(out_ptr_),
k_batch(k_batch_),
elfunc(elfunc_)
{
}
InPtr in_ptr;
WeiPtr wei_ptr;
const std::vector<const void*> ds_ptr;
OutPtr out_ptr;
index_t k_batch;
const CDElementwise elfunc;
};
using PassThrough = ck_tile::element_wise::PassThrough;
template <typename CDElementwise = PassThrough>
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*, CDElementwise>;
using GroupedConvBwdWeightHostArgs =
GroupedConvHostArgs<const void*, void*, const void*, PassThrough>;
using GroupedConvBwdDataHostArgs =
GroupedConvHostArgs<void*, const void*, const void*, PassThrough>;
template <index_t NDimSpatial_,
ConvolutionSpecialization ConvSpecialization_,
typename InLayout_,
typename WeiLayout_,
typename DsLayout_,
typename OutLayout_,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1,
index_t NumGroupsToMerge_ = 1,
bool EnableSplitImage_ = false,
bool ExplicitGemm_ = false>
struct GroupedConvTraits
{
private:
static constexpr auto generate_implicit_gemm_layout()
{
return generate_tuple([](auto) { return ck_tile::tensor_layout::gemm::RowMajor{}; },
number<DsLayout_::size()>{});
}
public:
// Fixed values for Implicit GEMM
struct FixedGemmParams
{
static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
static constexpr ck_tile::index_t TilePartitionerM01 = 4;
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr bool TransposeC = false;
static constexpr bool FixedVectorSize = true;
static constexpr bool UseStructuredSparsity = false;
static constexpr bool Persistent = false;
using ELayout = ck_tile::tensor_layout::gemm::RowMajor;
};
// Compile time parameters
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
static constexpr bool EnableSplitImage = EnableSplitImage_;
static constexpr bool ExplicitGemm = ExplicitGemm_;
static constexpr index_t NDimSpatial = NDimSpatial_;
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
using DsLayout = DsLayout_;
using OutLayout = OutLayout_;
// Forward Gemm Layouts
using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor;
using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
// Backward Data Gemm Layouts
using AsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
using BsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
using CLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
// Backward Weight Gemm Layouts
using AsLayoutBwdWeight = ck_tile::tensor_layout::gemm::ColumnMajor;
using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
template <GroupedConvDirection Direction>
struct GemmLayouts; // primary left undefined; only the specializations below are supported
template <>
struct GemmLayouts<GroupedConvDirection::FORWARD>
{
using AsLayout = AsLayoutFwd;
using BsLayout = BsLayoutFwd;
using CLayout = CLayoutFwd;
};
template <>
struct GemmLayouts<GroupedConvDirection::BACKWARD_DATA>
{
using AsLayout = AsLayoutBwdData;
using BsLayout = BsLayoutBwdData;
using CLayout = CLayoutBwdData;
};
template <>
struct GemmLayouts<GroupedConvDirection::BACKWARD_WEIGHT>
{
using AsLayout = AsLayoutBwdWeight;
using BsLayout = BsLayoutBwdWeight;
using CLayout = CLayoutBwdWeight;
};
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsFwd =
TileGemmTraits<true, true, true, AsLayoutFwd, BsLayoutFwd, CLayoutFwd, NumWaveGroups>;
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsBwdData = TileGemmTraits<true,
true,
true,
AsLayoutBwdData,
BsLayoutBwdData,
CLayoutBwdData,
NumWaveGroups>;
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsBwdWeight = TileGemmTraits<true,
true,
true,
AsLayoutBwdWeight,
BsLayoutBwdWeight,
CLayoutBwdWeight,
NumWaveGroups>;
static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
static constexpr ck_tile::index_t NumDTensor = DsLayout::size();
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
};
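// A hypothetical 2D NHWGC instantiation (layout and vector-size choices are illustrative):
//   using Traits = GroupedConvTraits<2,
//                                    ConvolutionSpecialization::Default,
//                                    tensor_layout::convolution::NHWGC,
//                                    tensor_layout::convolution::GKYXC,
//                                    ck_tile::tuple<>, // no extra D tensors
//                                    tensor_layout::convolution::NHWGK,
//                                    /*VectorSizeA=*/8, /*VectorSizeB=*/8, /*VectorSizeC=*/8>;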
/// @brief Helper struct for split-image piece information
///
/// @par Overview
/// Stores metadata for a single spatial piece in split-image convolution.
/// Used to track block ranges and spatial coordinates for each piece.
struct SplitImagePieceInfo
{
ck_tile::index_t block_start, block_end; ///< GPU block range for this piece
ck_tile::index_t d_start, h_start, w_start; ///< Spatial start coordinates (output space)
ck_tile::index_t d_size, h_size, w_size; ///< Spatial dimensions of this piece
};
/// @brief Calculate piece information for split-image convolution
///
/// @par Overview
/// Computes spatial coordinates, dimensions, and GPU block range for a single
/// piece in split-image convolution. Handles edge pieces that may have different
/// sizes due to non-uniform division.
///
/// @tparam TilePartitioner Type providing MPerBlock and NPerBlock constants
///
/// @param piece_idx Index of the piece to calculate (0-based)
/// @param num_d_pieces Number of pieces in D dimension
/// @param num_h_pieces Number of pieces in H dimension
/// @param num_w_pieces Number of pieces in W dimension
/// @param base_piece_d Base size of each D piece (may differ for last piece)
/// @param base_piece_h Base size of each H piece (may differ for last piece)
/// @param base_piece_w Base size of each W piece (may differ for last piece)
/// @param total_d Total D dimension size (output space)
/// @param total_h Total H dimension size (output space)
/// @param total_w Total W dimension size (output space)
/// @param N Batch size
/// @param K Output channels
/// @param total_blocks Accumulated block count from previous pieces
///
/// @return SplitImagePieceInfo containing all metadata for this piece
template <typename TilePartitioner>
CK_TILE_HOST SplitImagePieceInfo calculate_spatial_piece(ck_tile::index_t piece_idx,
ck_tile::index_t num_d_pieces,
ck_tile::index_t num_h_pieces,
ck_tile::index_t num_w_pieces,
ck_tile::index_t base_piece_d,
ck_tile::index_t base_piece_h,
ck_tile::index_t base_piece_w,
ck_tile::index_t total_d,
ck_tile::index_t total_h,
ck_tile::index_t total_w,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t total_blocks)
{
// Unflatten piece index into 3D coordinates (W-major, then H, then D)
const ck_tile::index_t w_idx = piece_idx % num_w_pieces;
const ck_tile::index_t h_idx = (piece_idx / num_w_pieces) % num_h_pieces;
const ck_tile::index_t d_idx = piece_idx / (num_w_pieces * num_h_pieces);
// Calculate spatial start positions
const ck_tile::index_t w_start = w_idx * base_piece_w;
const ck_tile::index_t h_start = h_idx * base_piece_h;
const ck_tile::index_t d_start = d_idx * base_piece_d;
// Calculate piece sizes (last piece may be larger to cover remainder)
const ck_tile::index_t w_size =
(w_idx == num_w_pieces - 1) ? (total_w - w_start) : base_piece_w;
const ck_tile::index_t h_size =
(h_idx == num_h_pieces - 1) ? (total_h - h_start) : base_piece_h;
const ck_tile::index_t d_size =
(d_idx == num_d_pieces - 1) ? (total_d - d_start) : base_piece_d;
// Calculate GEMM dimensions for this piece
const ck_tile::index_t piece_gemm_m = N * d_size * h_size * w_size;
const ck_tile::index_t piece_gemm_n = K;
// Calculate GPU grid size for this piece
const ck_tile::index_t piece_grid =
((piece_gemm_m + TilePartitioner::MPerBlock - 1) / TilePartitioner::MPerBlock) *
((piece_gemm_n + TilePartitioner::NPerBlock - 1) / TilePartitioner::NPerBlock);
return {
total_blocks, total_blocks + piece_grid, d_start, h_start, w_start, d_size, h_size, w_size};
}
} // namespace ck_tile
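A usage sketch for calculate_spatial_piece; DemoPartitioner and all sizes are illustrative stand-ins. Each returned block_end feeds the next piece's total_blocks, so the pieces occupy consecutive GPU block ranges:

namespace ck_tile_split_image_demo {
struct DemoPartitioner
{
    static constexpr ck_tile::index_t MPerBlock = 128;
    static constexpr ck_tile::index_t NPerBlock = 128;
};

inline void enumerate_pieces()
{
    const ck_tile::index_t total_d = 1, total_h = 70, total_w = 70, N = 4, K = 64;
    const ck_tile::index_t num_d = 1, num_h = 2, num_w = 2; // 1x2x2 D/H/W split
    const ck_tile::index_t base_d = total_d, base_h = total_h / num_h, base_w = total_w / num_w;
    ck_tile::index_t total_blocks = 0;
    for(ck_tile::index_t p = 0; p < num_d * num_h * num_w; ++p)
    {
        const auto piece = ck_tile::calculate_spatial_piece<DemoPartitioner>(
            p, num_d, num_h, num_w, base_d, base_h, base_w,
            total_d, total_h, total_w, N, K, total_blocks);
        total_blocks = piece.block_end; // next piece starts where this one ends
    }
}
} // namespace ck_tile_split_image_demo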


@@ -0,0 +1,81 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <algorithm>
#include <cmath>
#include <iostream>
#include <numeric>
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/host/kernel_launch.hpp"
namespace ck_tile {
template <index_t BlockSize, typename KernelArgs, typename KernelImpl>
CK_TILE_HOST index_t get_max_occupancy_for_kernel()
{
constexpr int dynamic_smem_size = 0;
constexpr int min_blocks_per_cu = 1;
const auto kernel_ptr = kentry<min_blocks_per_cu, KernelImpl, KernelArgs>;
int max_occupancy = 0;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy, kernel_ptr, BlockSize, dynamic_smem_size));
return static_cast<index_t>(max_occupancy);
}
CK_TILE_HOST index_t get_best_occupancy_k_batch_value(index_t max_occupancy, index_t grid_size)
{
static const index_t num_cus = get_num_cus();
const index_t max_capacity = max_occupancy * num_cus;
index_t k_batch = 1;
const auto optimal_split = static_cast<index_t>(std::floor((1.0 * max_capacity) / grid_size));
if(optimal_split > 1)
{
k_batch = optimal_split;
}
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: "
<< max_occupancy << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl;
}
return k_batch;
}
template <index_t BlockSize, typename KernelArgs, typename KernelImpl>
struct ActiveWorkgroupsPerCU
{
CK_TILE_HOST ActiveWorkgroupsPerCU()
{
max_occupancy_ = get_max_occupancy_for_kernel<BlockSize, KernelArgs, KernelImpl>();
}
index_t max_occupancy_{1};
};
template <index_t BlockSize, typename KernelImpl, typename TilePartitioner, typename KernelArgs>
CK_TILE_HOST index_t calculate_optimal_k_batch(const KernelArgs& kargs)
{
static ActiveWorkgroupsPerCU<BlockSize, KernelArgs, KernelImpl> active_workgroups_per_cu;
const auto grid_size = TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN) * kargs.GemmBatch;
auto optimal_k_batch =
get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size);
const auto max_allowed_k_batch = kargs.GemmK;
optimal_k_batch = std::min(optimal_k_batch, max_allowed_k_batch);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << optimal_k_batch << std::endl;
}
return optimal_k_batch;
}
} // namespace ck_tile
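The heuristic above in concrete numbers (all values illustrative): the occupancy query reports how many workgroups fit per CU, capacity is that count times the CU count, and any headroom over the output grid is spent on split-K:

#include <algorithm>
#include <cstdio>

int main()
{
    const int num_cus = 80, max_occupancy = 2; // occupancy as reported per CU
    const int grid_size = 40, GemmK = 4096;
    const int max_capacity = max_occupancy * num_cus;    // 160 schedulable workgroups
    int k_batch = std::max(1, max_capacity / grid_size); // floor(160 / 40) = 4
    k_batch = std::min(k_batch, GemmK);                  // cap, as in calculate_optimal_k_batch
    std::printf("k_batch = %d\n", k_batch);
    return 0;
}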

File diff suppressed because it is too large

File diff suppressed because it is too large