mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
[rocm-libraries] ROCm/rocm-libraries#6838 (commit ff7a665)
[CK_TILE] Add depthwise conv2d forward kernel (FP16/FP32) (#6838) ## Motivation CK currently has no kernel optimized for depthwise convolution (G=C_in=C_out, C=K=1 per group) and existing generic paths perform poorly for this workload. This PR adds a dedicated depthwise conv forward kernel in CK Tile. ## Technical Details Adds a dedicated depthwise conv2d forward op to CK Tile that performs direct convolution rather than falling back to the generic GEMM path. The kernel is templatized by filter size, stride, and data type, and compiled into ~60 instances covering common configurations (kernel 3/5/7/9, stride 1/2, FP16/FP32). Supports both CDNA (gfx942/gfx950) and RDNA (gfx1100/gfx1200) architectures. ## Test Plan - [x] Correctness and performance validated on gfx942, gfx950, and gfx1100, with ckProfiler `grouped_conv_fwd` as baseline. - [ ] MI300A (gfx942) and gfx1200 validation. ## Submission Checklist - [x ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-1137
This commit is contained in:
committed by
assistant-librarian[bot]
parent
fe2e29fa68
commit
945849b0f5
@@ -24,18 +24,51 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Deferred type resolution: partial specialization stubs out types for the inactive path,
|
||||
// preventing member access on void template parameters at compile time.
|
||||
namespace detail {
|
||||
|
||||
template <typename ConvTraits, bool IsDepthwise = ConvTraits::IsDepthwise>
|
||||
struct ConvFwdGemmDescTypes;
|
||||
|
||||
template <typename T>
|
||||
struct ConvFwdGemmDescTypes<T, false>
|
||||
{
|
||||
using Transformer = TransformConvFwdToGemm<T::NDimSpatial,
|
||||
T::ConvSpecialization,
|
||||
T::VectorSizeA,
|
||||
T::VectorSizeB,
|
||||
T::VectorSizeC,
|
||||
T::NumGroupsToMerge,
|
||||
true>;
|
||||
using AGridDescMK = remove_cvref_t<
|
||||
decltype(Transformer{}.template MakeADescriptor_M_K<typename T::InLayout>())>;
|
||||
using BGridDescNK = remove_cvref_t<
|
||||
decltype(Transformer{}.template MakeBDescriptor_N_K<typename T::WeiLayout>())>;
|
||||
using CGridDescMN = remove_cvref_t<
|
||||
decltype(Transformer{}.template MakeCDescriptor_M_N<typename T::OutLayout>())>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ConvFwdGemmDescTypes<T, true>
|
||||
{
|
||||
using Transformer = int;
|
||||
using AGridDescMK = int;
|
||||
using BGridDescNK = int;
|
||||
using CGridDescMN = int;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief The Grouped Convolution kernel device arguments.
|
||||
template <typename GroupedConvTraitsType_, typename CDElementwise_>
|
||||
struct GroupedConvFwdKernelArgs
|
||||
{
|
||||
using ConvToGemmFwdTransformer =
|
||||
TransformConvFwdToGemm<GroupedConvTraitsType_::NDimSpatial,
|
||||
GroupedConvTraitsType_::ConvSpecialization,
|
||||
GroupedConvTraitsType_::VectorSizeA,
|
||||
GroupedConvTraitsType_::VectorSizeB,
|
||||
GroupedConvTraitsType_::VectorSizeC,
|
||||
GroupedConvTraitsType_::NumGroupsToMerge,
|
||||
true>; // Split N enabled
|
||||
static constexpr bool IsDepthwise_ = GroupedConvTraitsType_::IsDepthwise;
|
||||
|
||||
using GemmDescTypes_ = detail::ConvFwdGemmDescTypes<GroupedConvTraitsType_>;
|
||||
using ConvToGemmFwdTransformer = typename GemmDescTypes_::Transformer;
|
||||
|
||||
using CDElementwise = CDElementwise_;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
@@ -342,15 +375,9 @@ struct GroupedConvFwdKernelArgs
|
||||
<< ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl;
|
||||
}
|
||||
}
|
||||
using AGridDescMK = remove_cvref_t<
|
||||
decltype(ConvToGemmFwdTransformer{}
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
|
||||
using BGridDescNK = remove_cvref_t<
|
||||
decltype(ConvToGemmFwdTransformer{}
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>())>;
|
||||
using CGridDescMN = remove_cvref_t<
|
||||
decltype(ConvToGemmFwdTransformer{}
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>())>;
|
||||
using AGridDescMK = typename GemmDescTypes_::AGridDescMK;
|
||||
using BGridDescNK = typename GemmDescTypes_::BGridDescNK;
|
||||
using CGridDescMN = typename GemmDescTypes_::CGridDescMN;
|
||||
|
||||
static constexpr index_t NonSpatialDims = 3;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> in_g_n_c_wis_lengths;
|
||||
@@ -425,6 +452,54 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
index_t num_spatial_pieces = 1; // Number of spatial pieces (1 = no split)
|
||||
SplitImageInfo split_image; // Nested structure with common + per-piece data
|
||||
|
||||
// Depthwise-only: NGCHW/GKCYX/NGKHW packed strides
|
||||
static constexpr index_t kStrideDims = NonSpatialDims + GroupedConvTraitsType_::NDimSpatial;
|
||||
array<index_t, kStrideDims> dw_in_strides = {};
|
||||
array<index_t, kStrideDims> dw_wei_strides = {};
|
||||
array<index_t, kStrideDims> dw_out_strides = {};
|
||||
|
||||
template <typename Dummy = void,
|
||||
std::enable_if_t<IsDepthwise_ && std::is_void_v<Dummy>, bool> = true>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
|
||||
: elfunc(args.elfunc)
|
||||
{
|
||||
static_assert(GroupedConvTraitsType_::NDimSpatial == 2,
|
||||
"Depthwise only supports 2D convolution");
|
||||
const index_t G = static_cast<index_t>(args.G_);
|
||||
const index_t N = static_cast<index_t>(args.N_);
|
||||
const index_t C = static_cast<index_t>(args.C_);
|
||||
const index_t Hi = static_cast<index_t>(args.input_spatial_lengths_[0]);
|
||||
const index_t Wi = static_cast<index_t>(args.input_spatial_lengths_[1]);
|
||||
const index_t K = static_cast<index_t>(args.K_);
|
||||
const index_t Y = static_cast<index_t>(args.filter_spatial_lengths_[0]);
|
||||
const index_t X = static_cast<index_t>(args.filter_spatial_lengths_[1]);
|
||||
const index_t Ho = static_cast<index_t>(args.output_spatial_lengths_[0]);
|
||||
const index_t Wo = static_cast<index_t>(args.output_spatial_lengths_[1]);
|
||||
|
||||
in_g_n_c_wis_lengths = {G, N, C, Hi, Wi};
|
||||
wei_g_k_c_xs_lengths = {G, K, C, Y, X};
|
||||
out_g_n_k_wos_lengths = {G, N, K, Ho, Wo};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
|
||||
k_batch = 1;
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
out_ptr = args.out_ptr;
|
||||
GemmBatch = G;
|
||||
|
||||
dw_in_strides = {C * Hi * Wi, G * C * Hi * Wi, Hi * Wi, Wi, 1};
|
||||
dw_wei_strides = {K * C * Y * X, C * Y * X, Y * X, X, 1};
|
||||
dw_out_strides = {K * Ho * Wo, G * K * Ho * Wo, Ho * Wo, Wo, 1};
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief The Grouped Convolution Forward kernel template.
|
||||
@@ -436,14 +511,14 @@ struct GroupedConvFwdKernelArgs
|
||||
///
|
||||
/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator()
|
||||
/// function call operator" which determines the work scope of each workgroup.
|
||||
/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm.
|
||||
/// @li @b Pipeline - The core part @a "heart" of matrix multiplication algorithm.
|
||||
/// This is the place where each workgroup is loading data from global memory and
|
||||
/// carrying out dot products.
|
||||
/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation
|
||||
/// responsible for storing results to global memory. This is also the place where
|
||||
/// any additional operator fusion may take place.
|
||||
///
|
||||
/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_
|
||||
/// Additionally both @ref Pipeline_ "Pipeline" and @ref EpiloguePipeline_
|
||||
/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all
|
||||
/// internal details of those functional parts. You can think of it like both gemm and
|
||||
/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover
|
||||
@@ -456,49 +531,51 @@ struct GroupedConvFwdKernelArgs
|
||||
/// output data tile to be calculated. It determines the
|
||||
/// workgroup to data relationship (or in other words - which
|
||||
/// data would be processed and calculated by which workgroup).
|
||||
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
|
||||
/// @tparam Pipeline_ The type of class which provides the core part of matrix
|
||||
/// multiplication. This class should provide implementation of
|
||||
/// data loading from global memory and performing block-wise
|
||||
/// matrix multiplication. You can think of it as a work done by
|
||||
/// single workgroup point of view.
|
||||
/// matrix multiplication. For depthwise convolution, this is
|
||||
/// DepthwiseConvFwdPipeline instead.
|
||||
/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix
|
||||
/// multiplication implementation. It is responsible for storing
|
||||
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
|
||||
/// results calculated by @ref Pipeline_ "Pipeline" to
|
||||
/// the output C tensor in global memory.
|
||||
template <typename GroupedConvTraitsType_,
|
||||
typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename Pipeline_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionForwardKernel
|
||||
{
|
||||
static constexpr bool IsDepthwise = GroupedConvTraitsType_::IsDepthwise;
|
||||
using DwTraits = typename GroupedConvTraitsType_::DepthwiseTraits;
|
||||
static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType_::ConvSpecialization;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using GemmALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using GemmBLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using GemmCLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
using GemmALayout = remove_cvref_t<typename Pipeline::ALayout>;
|
||||
using GemmBLayout = remove_cvref_t<typename Pipeline::BLayout>;
|
||||
using GemmCLayout = remove_cvref_t<typename Pipeline::CLayout>;
|
||||
|
||||
using InLayout = remove_cvref_t<typename GroupedConvTraitsType_::InLayout>;
|
||||
using WeiLayout = remove_cvref_t<typename GroupedConvTraitsType_::WeiLayout>;
|
||||
using OutLayout = remove_cvref_t<typename GroupedConvTraitsType_::OutLayout>;
|
||||
using DsLayout = remove_cvref_t<typename GroupedConvTraitsType_::DsLayout>;
|
||||
|
||||
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline_::DsLayout>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr index_t kBlockSize = Pipeline::BlockSize;
|
||||
|
||||
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using CDElementwise = typename EpiloguePipeline::CDElementwise;
|
||||
using InDataType = remove_cvref_t<typename Pipeline::ADataType>;
|
||||
using WeiDataType = remove_cvref_t<typename Pipeline::BDataType>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline_::DsDataType>;
|
||||
using OutDataType = remove_cvref_t<typename EpiloguePipeline_::ODataType>;
|
||||
using CDElementwise = typename EpiloguePipeline_::CDElementwise;
|
||||
|
||||
using GroupedConvFwdKernelArgsSpecialized =
|
||||
GroupedConvFwdKernelArgs<GroupedConvTraitsType_, CDElementwise>;
|
||||
@@ -511,16 +588,25 @@ struct GroupedConvolutionForwardKernel
|
||||
static constexpr auto I3 = number<3>();
|
||||
static constexpr auto I5 = number<5>();
|
||||
|
||||
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor> ||
|
||||
GroupedConvTraitsType_::NumGroupsToMerge > 1,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
|
||||
GroupedConvTraitsType_::NumGroupsToMerge == 1,
|
||||
"Not supported!");
|
||||
static constexpr bool CheckGemmAsserts()
|
||||
{
|
||||
if constexpr(!IsDepthwise)
|
||||
{
|
||||
static_assert(Pipeline::kPadM && Pipeline::kPadN && Pipeline::kPadK, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor> ||
|
||||
GroupedConvTraitsType_::NumGroupsToMerge > 1,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
|
||||
"Not supported!");
|
||||
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
|
||||
GroupedConvTraitsType_::NumGroupsToMerge == 1,
|
||||
"Not supported!");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
static_assert(CheckGemmAsserts());
|
||||
|
||||
// Helper struct for spatial coordinates
|
||||
struct SpatialCoords
|
||||
@@ -595,26 +681,49 @@ struct GroupedConvolutionForwardKernel
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// clang-format off
|
||||
return concat('_', "grouped_convolution_forward",
|
||||
gemm_prec_str<InDataType, WeiDataType>(),
|
||||
InLayout::name,
|
||||
WeiLayout::name,
|
||||
OutLayout::name,
|
||||
"gemm",
|
||||
GemmPipeline::GetName(),
|
||||
"epilogue",
|
||||
EpiloguePipeline::GetName(),
|
||||
getConvSpecializationString(ConvSpecialization),
|
||||
"MergedGroups",
|
||||
NumGroupsToMerge,
|
||||
"SplitImage",
|
||||
EnableSplitImage,
|
||||
"ExplicitGemm",
|
||||
GroupedConvTraitsType_::ExplicitGemm
|
||||
);
|
||||
// clang-format on
|
||||
if constexpr(IsDepthwise)
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "grouped_convolution_forward_depthwise",
|
||||
gemm_prec_str<InDataType, WeiDataType>(),
|
||||
"bs", Pipeline::BlockSize,
|
||||
"th", Pipeline::TileOutH,
|
||||
"tw", Pipeline::TileOutW,
|
||||
"fh", Pipeline::FilterH,
|
||||
"fw", Pipeline::FilterW,
|
||||
"sh", Pipeline::StrideH,
|
||||
"sw", Pipeline::StrideW,
|
||||
"nb", Pipeline::NBatch,
|
||||
"sbh", Pipeline::SubTileH,
|
||||
"sbw", Pipeline::SubTileW,
|
||||
"iv", Pipeline::InVectorSize,
|
||||
"ov", Pipeline::OutVectorSize
|
||||
);
|
||||
// clang-format on
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// clang-format off
|
||||
return concat('_', "grouped_convolution_forward",
|
||||
gemm_prec_str<InDataType, WeiDataType>(),
|
||||
InLayout::name,
|
||||
WeiLayout::name,
|
||||
OutLayout::name,
|
||||
"gemm",
|
||||
Pipeline::GetName(),
|
||||
"epilogue",
|
||||
EpiloguePipeline::GetName(),
|
||||
getConvSpecializationString(ConvSpecialization),
|
||||
"MergedGroups",
|
||||
NumGroupsToMerge,
|
||||
"SplitImage",
|
||||
EnableSplitImage,
|
||||
"ExplicitGemm",
|
||||
GroupedConvTraitsType_::ExplicitGemm
|
||||
);
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
|
||||
@@ -634,8 +743,19 @@ struct GroupedConvolutionForwardKernel
|
||||
|
||||
CK_TILE_HOST static auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
{
|
||||
return dim3(
|
||||
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.n_splits);
|
||||
if constexpr(IsDepthwise)
|
||||
{
|
||||
const index_t G = kargs.in_g_n_c_wis_lengths[number<0>{}];
|
||||
const index_t N = kargs.in_g_n_c_wis_lengths[number<1>{}];
|
||||
const index_t num_batch_groups = integer_divide_ceil(N, DwTraits::NBatch);
|
||||
return dim3(G, num_batch_groups, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN),
|
||||
kargs.GemmBatch,
|
||||
kargs.n_splits);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto BlockSize()
|
||||
@@ -652,113 +772,305 @@ struct GroupedConvolutionForwardKernel
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
if constexpr(IsDepthwise)
|
||||
{
|
||||
return Pipeline_::GetSmemSize();
|
||||
}
|
||||
else
|
||||
{
|
||||
return max(Pipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool
|
||||
IsDepthwiseArgumentSupported(const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
{
|
||||
static constexpr index_t NBatch = DwTraits::NBatch;
|
||||
|
||||
// NBatch (batches processed per tile) must be a multiple of TilePerWave so that
|
||||
// each wave receives a whole number of batches with no remainder.
|
||||
if constexpr(NBatch % DwTraits::TilePerWave != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// Each sub-tile's input footprint in W (SubTileW * StrideW) must be aligned to
|
||||
// the internal vector load width, otherwise the vectorised load would straddle a
|
||||
// boundary and produce incorrect results.
|
||||
if constexpr(DwTraits::SubTileW * DwTraits::StrideW % DwTraits::InVectorSizeInternal != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// The kernel always pads the LDS tile to simplify boundary handling; a zero
|
||||
// PadW means there is no left padding to absorb and the tiling assumption breaks.
|
||||
if constexpr(DwTraits::PadW == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// The number of threads needed to load one LDS row (LdsTileW / InVectorSize) must
|
||||
// not exceed the block size; otherwise some rows would go unloaded.
|
||||
if constexpr(integer_divide_ceil(DwTraits::LdsTileW, DwTraits::InVectorSize) >
|
||||
DwTraits::BlockSize)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// The pipeline's shared memory requirement must fit within the hardware LDS limit.
|
||||
if constexpr(Pipeline_::GetSmemSize() > static_cast<index_t>(get_smem_capacity()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Depthwise conv requires K == C == 1 in the weight tensor (one filter per channel).
|
||||
if(kargs.wei_g_k_c_xs_lengths[number<1>{}] != 1 ||
|
||||
kargs.wei_g_k_c_xs_lengths[number<2>{}] != 1)
|
||||
return false;
|
||||
// Filter spatial dimensions must exactly match the compile-time tile specialisation.
|
||||
if(kargs.wei_g_k_c_xs_lengths[number<3>{}] != DwTraits::FilterH ||
|
||||
kargs.wei_g_k_c_xs_lengths[number<4>{}] != DwTraits::FilterW)
|
||||
return false;
|
||||
// Convolution strides must match the compile-time specialisation.
|
||||
if(kargs.conv_filter_strides[number<0>{}] != DwTraits::StrideH ||
|
||||
kargs.conv_filter_strides[number<1>{}] != DwTraits::StrideW)
|
||||
return false;
|
||||
// Dilations must match the compile-time specialisation.
|
||||
if(kargs.conv_filter_dilations[number<0>{}] != DwTraits::DilationH ||
|
||||
kargs.conv_filter_dilations[number<1>{}] != DwTraits::DilationW)
|
||||
return false;
|
||||
// Right padding is handled by boundary clamping; only left pad must match.
|
||||
if(kargs.input_left_pads[number<0>{}] != DwTraits::PadH ||
|
||||
kargs.input_left_pads[number<1>{}] != DwTraits::PadW)
|
||||
return false;
|
||||
// Batch count must be divisible by NBatch so work can be evenly partitioned across tiles.
|
||||
if(kargs.in_g_n_c_wis_lengths[number<1>{}] % NBatch != 0)
|
||||
return false;
|
||||
|
||||
// When multiple output tiles are processed per wave (TilePerWave > 1) the output
|
||||
// spatial dimensions must fit within a single tile; larger outputs need a different
|
||||
// specialisation.
|
||||
if constexpr(DwTraits::TilePerWave != 1)
|
||||
{
|
||||
if(kargs.out_g_n_k_wos_lengths[number<3>{}] > DwTraits::TileOutH ||
|
||||
kargs.out_g_n_k_wos_lengths[number<4>{}] > DwTraits::TileOutW)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
if constexpr(IsDepthwise)
|
||||
{
|
||||
if(get_device_name() != "gfx950")
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return IsDepthwiseArgumentSupported(kargs);
|
||||
}
|
||||
|
||||
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
|
||||
!IsSplitKSupported)
|
||||
else // GEMM path
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
|
||||
if constexpr(Pipeline_::Async)
|
||||
{
|
||||
if(get_device_name() != "gfx950")
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
|
||||
!IsSplitKSupported)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
|
||||
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
|
||||
|
||||
// check ConvolutionSpecialization
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
|
||||
const index_t ConvStride = kargs.conv_filter_strides[i];
|
||||
const index_t LeftPad = kargs.input_left_pads[i];
|
||||
const index_t RightPad = kargs.input_right_pads[i];
|
||||
|
||||
if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
// check if it's 1x1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
|
||||
const index_t LeftPad = kargs.input_left_pads[i];
|
||||
const index_t RightPad = kargs.input_right_pads[i];
|
||||
|
||||
if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
|
||||
{
|
||||
if(ConvC != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
|
||||
|
||||
if(filter_spatial_dim != I3)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
|
||||
ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
|
||||
CK_TILE_ERROR(
|
||||
"Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
|
||||
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
// check ConvolutionSpecialization
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
if constexpr(std::is_same_v<InLayout, ctc::NWGC> ||
|
||||
std::is_same_v<InLayout, ctc::NHWGC> ||
|
||||
std::is_same_v<InLayout, ctc::NDHWGC>)
|
||||
{
|
||||
const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
|
||||
const index_t ConvStride = kargs.conv_filter_strides[i];
|
||||
const index_t LeftPad = kargs.input_left_pads[i];
|
||||
const index_t RightPad = kargs.input_right_pads[i];
|
||||
|
||||
if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
|
||||
// Check access for A tensor
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0 &&
|
||||
GroupedConvTraitsType_::NumGroupsToMerge == 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Conv C is not a multiple of vector load size for input image!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
// check if it's 1x1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
|
||||
const index_t LeftPad = kargs.input_left_pads[i];
|
||||
const index_t RightPad = kargs.input_right_pads[i];
|
||||
|
||||
if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
|
||||
else if(GroupedConvTraitsType_::NumGroupsToMerge > 1)
|
||||
{
|
||||
return false;
|
||||
if(ConvC != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"ConvC must be equal to 1 for NumGroupsToMerge > 1 to allow "
|
||||
"vector reads on group dimension!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
|
||||
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
|
||||
{
|
||||
if(ConvC != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
|
||||
|
||||
if(filter_spatial_dim != I3)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
|
||||
ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
|
||||
std::is_same_v<InLayout, ctc::NDHWGC>)
|
||||
{
|
||||
// Check access for A tensor
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0 &&
|
||||
GroupedConvTraitsType_::NumGroupsToMerge == 1)
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
|
||||
CK_TILE_ERROR("Not supported input layout!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
else if(GroupedConvTraitsType_::NumGroupsToMerge > 1)
|
||||
|
||||
// check vector access of B
|
||||
// FIXME: layout
|
||||
if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
|
||||
std::is_same_v<WeiLayout, ctc::GKYXC> ||
|
||||
std::is_same_v<WeiLayout, ctc::GKZYXC>)
|
||||
{
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Not supported weight layout!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of E
|
||||
if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
|
||||
std::is_same_v<OutLayout, ctc::NHWGK> ||
|
||||
std::is_same_v<OutLayout, ctc::NDHWGK>)
|
||||
{
|
||||
if(ConvK % GroupedConvTraitsType_::VectorSizeC != 0)
|
||||
{
|
||||
// Try to read over G
|
||||
if(GroupedConvTraitsType_::NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
|
||||
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0 ||
|
||||
ConvG % GroupedConvTraitsType_::VectorSizeC != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"ConvG must be a multiple of NumGroupsToMerge to allow "
|
||||
"writing over G dimension");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"ConvK is not a multiple of vector store size for output image!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Not supported output layout!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
|
||||
{
|
||||
// currently group merging works only for C == 1 due to tensor transformation
|
||||
// limitations
|
||||
if(ConvC != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
@@ -779,108 +1091,10 @@ struct GroupedConvolutionForwardKernel
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Not supported input layout!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of B
|
||||
// FIXME: layout
|
||||
if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
|
||||
std::is_same_v<WeiLayout, ctc::GKYXC> ||
|
||||
std::is_same_v<WeiLayout, ctc::GKZYXC>)
|
||||
{
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Not supported weight layout!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
||||
// check vector access of E
|
||||
if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
|
||||
std::is_same_v<OutLayout, ctc::NHWGK> ||
|
||||
std::is_same_v<OutLayout, ctc::NDHWGK>)
|
||||
{
|
||||
if(ConvK % GroupedConvTraitsType_::VectorSizeC != 0)
|
||||
{
|
||||
// Try to read over G
|
||||
if(GroupedConvTraitsType_::NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
|
||||
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0 ||
|
||||
ConvG % GroupedConvTraitsType_::VectorSizeC != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge to allow "
|
||||
"writing over G dimension");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"ConvK is not a multiple of vector store size for output image!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Not supported output layout!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
|
||||
{
|
||||
// currently group merging works only for C == 1 due to tensor transformation
|
||||
// limitations
|
||||
if(ConvC != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("ConvC must be equal to 1 for NumGroupsToMerge > 1 to allow "
|
||||
"vector reads on group dimension!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
|
||||
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
} // else (GEMM path)
|
||||
}
|
||||
|
||||
template <typename ADescType>
|
||||
@@ -1068,8 +1282,8 @@ struct GroupedConvolutionForwardKernel
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
const auto& c_block_tile =
|
||||
Pipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline with k_batch dispatching
|
||||
if(k_batch == 1)
|
||||
@@ -1101,7 +1315,7 @@ struct GroupedConvolutionForwardKernel
|
||||
{
|
||||
static_assert(NumDTensor == 0, "Not supported!");
|
||||
using ExplicitBatchedGemmKernel =
|
||||
BatchedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
BatchedGemmKernel<TilePartitioner, Pipeline, EpiloguePipeline>;
|
||||
const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{
|
||||
{{kargs.in_ptr},
|
||||
{kargs.wei_ptr},
|
||||
@@ -1122,9 +1336,72 @@ struct GroupedConvolutionForwardKernel
|
||||
ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void CallDepthwiseConv(GroupedConvFwdKernelArgsSpecialized& kargs) const
|
||||
{
|
||||
static_assert(IsDepthwise);
|
||||
static constexpr index_t NBatch = DwTraits::NBatch;
|
||||
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const index_t batch_group = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
|
||||
// dw_*_strides layout: [G, N, C, H, W]
|
||||
const long_index_t in_g_stride = kargs.dw_in_strides[number<0>{}];
|
||||
const long_index_t in_n_stride = kargs.dw_in_strides[number<1>{}];
|
||||
const long_index_t in_h_stride = kargs.dw_in_strides[number<3>{}];
|
||||
const long_index_t in_w_stride = kargs.dw_in_strides[number<4>{}];
|
||||
|
||||
const long_index_t wei_g_stride = kargs.dw_wei_strides[number<0>{}];
|
||||
const long_index_t wei_y_stride = kargs.dw_wei_strides[number<3>{}];
|
||||
const long_index_t wei_x_stride = kargs.dw_wei_strides[number<4>{}];
|
||||
|
||||
const long_index_t out_g_stride = kargs.dw_out_strides[number<0>{}];
|
||||
const long_index_t out_n_stride = kargs.dw_out_strides[number<1>{}];
|
||||
const long_index_t out_h_stride = kargs.dw_out_strides[number<3>{}];
|
||||
const long_index_t out_w_stride = kargs.dw_out_strides[number<4>{}];
|
||||
|
||||
const auto* p_in_base = static_cast<const InDataType*>(kargs.in_ptr) +
|
||||
static_cast<long_index_t>(g_idx) * in_g_stride +
|
||||
static_cast<long_index_t>(batch_group * NBatch) * in_n_stride;
|
||||
|
||||
const auto* p_wei_base = static_cast<const WeiDataType*>(kargs.wei_ptr) +
|
||||
static_cast<long_index_t>(g_idx) * wei_g_stride;
|
||||
|
||||
auto* p_out_base = static_cast<OutDataType*>(kargs.out_ptr) +
|
||||
static_cast<long_index_t>(g_idx) * out_g_stride +
|
||||
static_cast<long_index_t>(batch_group * NBatch) * out_n_stride;
|
||||
|
||||
const index_t Hi = kargs.in_g_n_c_wis_lengths[number<3>{}];
|
||||
const index_t Wi = kargs.in_g_n_c_wis_lengths[number<4>{}];
|
||||
const index_t Ho = kargs.out_g_n_k_wos_lengths[number<3>{}];
|
||||
const index_t Wo = kargs.out_g_n_k_wos_lengths[number<4>{}];
|
||||
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
Pipeline_{}(p_in_base,
|
||||
p_wei_base,
|
||||
p_out_base,
|
||||
smem,
|
||||
Hi,
|
||||
Wi,
|
||||
Ho,
|
||||
Wo,
|
||||
static_cast<index_t>(in_h_stride),
|
||||
static_cast<index_t>(in_w_stride),
|
||||
static_cast<index_t>(in_n_stride),
|
||||
static_cast<index_t>(wei_y_stride),
|
||||
static_cast<index_t>(wei_x_stride),
|
||||
static_cast<index_t>(out_h_stride),
|
||||
static_cast<index_t>(out_w_stride),
|
||||
static_cast<index_t>(out_n_stride));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized& kargs) const
|
||||
{
|
||||
if constexpr(GroupedConvTraitsType_::ExplicitGemm)
|
||||
if constexpr(IsDepthwise)
|
||||
{
|
||||
CallDepthwiseConv(kargs);
|
||||
}
|
||||
else if constexpr(GroupedConvTraitsType_::ExplicitGemm)
|
||||
{
|
||||
CallExplicitGemm(kargs);
|
||||
}
|
||||
@@ -1246,7 +1523,7 @@ struct GroupedConvolutionForwardKernel
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// Disable Async for other archs than gfx950
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
if constexpr(Pipeline_::Async)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
RunGemm(a_ptr,
|
||||
|
||||
Reference in New Issue
Block a user