[CK_TILE] Add conv bwd weight two stage support (#2855)

* resolved conflicts

* add conv bwd weight twostage

* fix one file

* fixes after review

* fixes

* fixes

* Fix

---------

Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
This commit is contained in:
jakpiase
2025-09-22 15:31:25 +02:00
committed by GitHub
parent 4363a82bd6
commit 624c46866e
16 changed files with 864 additions and 361 deletions

View File

@@ -24,7 +24,10 @@ struct GroupedConvBwdDataKernelArgs
using ConvToGemmTransformer =
TransformConvBwdDataToGemm<GroupedConvTraitsType_::NDimSpatial,
GroupedConvTraitsType_::ConvSpecialization>;
GroupedConvTraitsType_::ConvSpecialization,
GroupedConvTraitsType_::VectorSizeA,
GroupedConvTraitsType_::VectorSizeB,
GroupedConvTraitsType_::VectorSizeC>;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
static constexpr auto I0 = number<0>();
@@ -468,6 +471,10 @@ template <typename GroupedConvTraitsType_,
typename EpiloguePipeline_>
struct GroupedConvolutionBackwardDataKernel
{
// TODO: Enable vector load size > 1 for A and B
static_assert(GroupedConvTraitsType_::VectorSizeA == 1 &&
GroupedConvTraitsType_::VectorSizeB == 1);
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
static constexpr ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
@@ -509,10 +516,13 @@ struct GroupedConvolutionBackwardDataKernel
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
"Not supported!");
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>,
"Not supported A GEMM layout!");
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>,
"Not supported B GEMM layout!");
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
// TODO: Change to the commented-out layout checks below and enable vector load
// static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>,
// "Not supported A GEMM layout!");
// static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>,
// "Not supported B GEMM layout!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
"Not supported C GEMM layout!");
@@ -548,7 +558,7 @@ struct GroupedConvolutionBackwardDataKernel
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
{
if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
!IsSplitKSupported)
{
@@ -625,7 +635,7 @@ struct GroupedConvolutionBackwardDataKernel
std::is_same_v<InLayout, ctc::NDHWGC>)
{
// Check access per C
if(ConvC % GemmPipeline::GetVectorSizeB() != 0)
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
{
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
return false;
@@ -637,13 +647,12 @@ struct GroupedConvolutionBackwardDataKernel
return false;
}
// check vector access of B
// FIXME: layout
if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
std::is_same_v<WeiLayout, ctc::GKYXC> ||
std::is_same_v<WeiLayout, ctc::GKZYXC>)
{
if(ConvC % EpiloguePipeline::GetVectorSizeC() != 0)
if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
{
CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
return false;
@@ -655,12 +664,11 @@ struct GroupedConvolutionBackwardDataKernel
return false;
}
// check vector access of E
if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
std::is_same_v<OutLayout, ctc::NHWGK> ||
std::is_same_v<OutLayout, ctc::NDHWGK>)
{
if(ConvK % GemmPipeline::GetVectorSizeA() != 0)
if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
{
CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
return false;
@@ -957,7 +965,7 @@ struct GroupedConvolutionBackwardDataKernel
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr,
@@ -975,7 +983,7 @@ struct GroupedConvolutionBackwardDataKernel
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id);

View File

@@ -23,7 +23,10 @@ struct GroupedConvBwdWeightKernelArgs
using ConvToGemmTransformer =
TransformConvBwdWeightToGemm<GroupedConvTraitsType_::NDimSpatial,
GroupedConvTraitsType_::ConvSpecialization>;
GroupedConvTraitsType_::ConvSpecialization,
GroupedConvTraitsType_::VectorSizeA,
GroupedConvTraitsType_::VectorSizeB,
GroupedConvTraitsType_::VectorSizeC>;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
template <
@@ -335,6 +338,10 @@ template <typename GroupedConvTraitsType_,
typename EpiloguePipeline_>
struct GroupedConvolutionBackwardWeightKernel
{
// TODO: Enable vector load size > 1 for A and B
static_assert(GroupedConvTraitsType_::VectorSizeA == 1 &&
GroupedConvTraitsType_::VectorSizeB == 1);
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
static constexpr ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
@@ -355,11 +362,10 @@ struct GroupedConvolutionBackwardWeightKernel
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using OutDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using InDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using WeiDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using GroupedConvBwdWeightKernelArgsSpecialized =
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>;
@@ -376,6 +382,10 @@ struct GroupedConvolutionBackwardWeightKernel
"Not supported!");
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
// TODO: Change to the commented-out layout checks below and enable vector load
// static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>,
//               "Not supported!");
// static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>,
//               "Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
@@ -453,8 +463,8 @@ struct GroupedConvolutionBackwardWeightKernel
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value) ||
!IsSplitKSupported)
{
if(kargs.k_batch != 1)
@@ -525,7 +535,7 @@ struct GroupedConvolutionBackwardWeightKernel
std::is_same_v<InLayout, ctc::NDHWGC>)
{
// Check access per C
if(ConvC % GemmPipeline::GetVectorSizeB() != 0)
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
{
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
return false;
@@ -537,13 +547,11 @@ struct GroupedConvolutionBackwardWeightKernel
return false;
}
// check vector access of B
// FIXME: layout
if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
std::is_same_v<WeiLayout, ctc::GKYXC> ||
std::is_same_v<WeiLayout, ctc::GKZYXC>)
{
if(ConvC % EpiloguePipeline::GetVectorSizeC() != 0)
if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
{
CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
return false;
@@ -555,12 +563,11 @@ struct GroupedConvolutionBackwardWeightKernel
return false;
}
// check vector access of E
if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
std::is_same_v<OutLayout, ctc::NHWGK> ||
std::is_same_v<OutLayout, ctc::NDHWGK>)
{
if(ConvK % GemmPipeline::GetVectorSizeA() != 0)
if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
{
CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
return false;
@@ -596,9 +603,8 @@ struct GroupedConvolutionBackwardWeightKernel
}();
const auto& c_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
kargs.c_grid_desc_m_n); // B: in
return make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr,
kargs.c_grid_desc_m_n);
}();
const auto& ds_tensor_view = generate_tuple(
@@ -607,11 +613,11 @@ struct GroupedConvolutionBackwardWeightKernel
"Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
"Not supported!");
static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, WeiDataType>,
"Not supported!");
return make_tensor_view<address_space_enum::global>(
static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
static_cast<WeiDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
},
number<NumDTensor>{});
@@ -829,8 +835,8 @@ struct GroupedConvolutionBackwardWeightKernel
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(a_ptr,
b_ptr,
@@ -848,8 +854,8 @@ struct GroupedConvolutionBackwardWeightKernel
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
{
RunGemm(
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, i_k);

View File

@@ -24,6 +24,9 @@ struct GroupedConvFwdKernelArgs
using ConvToGemmFwdTransformer =
TransformConvFwdToGemm<GroupedConvTraitsType_::NDimSpatial,
GroupedConvTraitsType_::ConvSpecialization,
GroupedConvTraitsType_::VectorSizeA,
GroupedConvTraitsType_::VectorSizeB,
GroupedConvTraitsType_::VectorSizeC,
true>; // Split N enabled
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
@@ -467,7 +470,7 @@ struct GroupedConvolutionForwardKernel
CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized& kargs)
{
if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
!IsSplitKSupported)
{
@@ -550,7 +553,7 @@ struct GroupedConvolutionForwardKernel
std::is_same_v<InLayout, ctc::NDHWGC>)
{
// Check access per C
if(ConvC % GemmPipeline::GetVectorSizeA() != 0)
if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0)
{
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
return false;
@@ -568,7 +571,7 @@ struct GroupedConvolutionForwardKernel
std::is_same_v<WeiLayout, ctc::GKYXC> ||
std::is_same_v<WeiLayout, ctc::GKZYXC>)
{
if(ConvC % GemmPipeline::GetVectorSizeB() != 0)
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
{
CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
return false;
@@ -585,7 +588,7 @@ struct GroupedConvolutionForwardKernel
std::is_same_v<OutLayout, ctc::NHWGK> ||
std::is_same_v<OutLayout, ctc::NDHWGK>)
{
if(ConvK % EpiloguePipeline::GetVectorSizeC() != 0)
if(ConvK % GroupedConvTraitsType_::VectorSizeC != 0)
{
CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
return false;
@@ -858,7 +861,7 @@ struct GroupedConvolutionForwardKernel
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(
@@ -868,7 +871,7 @@ struct GroupedConvolutionForwardKernel
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);

View File

@@ -49,7 +49,10 @@ template <index_t NDimSpatial_,
typename InLayout_,
typename WeiLayout_,
typename DsLayout_,
typename OutLayout_>
typename OutLayout_,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1>
struct GroupedConvTraits
{
private:
@@ -67,14 +70,38 @@ struct GroupedConvTraits
using WeiLayout = WeiLayout_;
using DsLayout = DsLayout_;
using OutLayout = OutLayout_;
using GroupedConvImplicitGemmTraits = TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
static constexpr index_t NumDTensor = DsLayout::size();
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
using GroupedConvImplicitGemmTraitsFwd =
TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
using GroupedConvImplicitGemmTraitsBwdData =
TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor,
// TODO: Change to the commented-out layouts below and enable vector load
// ck_tile::tensor_layout::gemm::RowMajor,
// ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
using GroupedConvImplicitGemmTraitsBwdWeight =
TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor,
// TODO: Change to the commented-out layouts below and enable vector load
// ck_tile::tensor_layout::gemm::ColumnMajor,
// ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
static constexpr index_t NumDTensor = DsLayout::size();
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
};
} // namespace ck_tile

View File

@@ -10,6 +10,9 @@ namespace ck_tile {
template <index_t NDimSpatial,
ConvolutionSpecialization ConvolutionSpecialization,
index_t VectorSizeA,
index_t VectorSizeB,
index_t VectorSizeC,
bool SplitN = false,
typename ADataType = float,
typename CDataType = float,
@@ -442,14 +445,17 @@ struct TransformConvBwdDataToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Wo_, K_),
make_tuple(NStride, WoStride, KStride));
make_tuple(NStride, WoStride, KStride),
number<VectorSizeA>{},
I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
CK_TILE_HOST auto make_wei_grid_desc() const
{
// GKXC
return make_naive_tensor_descriptor_packed(make_tuple(K_, X_, C_));
return make_naive_tensor_descriptor(
make_tuple(K_, X_, C_), make_tuple(X_ * C_, C_, I1), number<VectorSizeB>{}, I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
@@ -462,7 +468,9 @@ struct TransformConvBwdDataToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
make_tuple(NStride, WiStride, CStride));
make_tuple(NStride, WiStride, CStride),
number<VectorSizeC>{},
I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -477,7 +485,9 @@ struct TransformConvBwdDataToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, K_),
make_tuple(NStride, HoStride, WoStride, KStride));
make_tuple(NStride, HoStride, WoStride, KStride),
number<VectorSizeA>{},
I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -491,14 +501,19 @@ struct TransformConvBwdDataToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStride, HiStride, WiStride, CStride));
make_tuple(NStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
CK_TILE_HOST auto make_wei_grid_desc() const
{
// GKYXC
return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_));
return make_naive_tensor_descriptor(make_tuple(K_, Y_, X_, C_),
make_tuple(C_ * X_ * Y_, C_ * X_, C_, I1),
number<VectorSizeC>{},
I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -514,7 +529,9 @@ struct TransformConvBwdDataToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(
make_tuple(N_, Do_, Ho_, Wo_, K_),
make_tuple(NStride, DoStride, HoStride, WoStride, KStride));
make_tuple(NStride, DoStride, HoStride, WoStride, KStride),
number<VectorSizeA>{},
I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -529,14 +546,20 @@ struct TransformConvBwdDataToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
CK_TILE_HOST auto make_wei_grid_desc() const
{
// GKZYXC
return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_));
return make_naive_tensor_descriptor(
make_tuple(K_, Z_, Y_, X_, C_),
make_tuple(C_ * X_ * Y_ * Z_, C_ * X_ * Y_, C_ * X_, C_, I1),
number<VectorSizeC>{},
I1);
}
// TODO: implement ck_tile::tensor_layout::convolution that describes packed/strided dimension
// as properties

View File

@@ -10,6 +10,9 @@ namespace ck_tile {
template <index_t NDimSpatial,
ConvolutionSpecialization ConvolutionSpecialization,
index_t VectorSizeA,
index_t VectorSizeB,
index_t VectorSizeC,
bool SplitN = false,
typename ADataType = float,
typename CDataType = float,
@@ -420,7 +423,9 @@ struct TransformConvBwdWeightToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
make_tuple(KStride, NDoHoWoStride));
make_tuple(KStride, NDoHoWoStride),
number<VectorSizeA>{},
I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
@@ -433,7 +438,9 @@ struct TransformConvBwdWeightToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
make_tuple(NStride, WiStride, CStride));
make_tuple(NStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
@@ -444,7 +451,8 @@ struct TransformConvBwdWeightToGemm
constexpr auto CXStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride));
return make_naive_tensor_descriptor(
make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride), number<VectorSizeC>{}, I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -457,7 +465,9 @@ struct TransformConvBwdWeightToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_),
make_tuple(KStride, NDoHoWoStride));
make_tuple(KStride, NDoHoWoStride),
number<VectorSizeA>{},
I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -471,7 +481,9 @@ struct TransformConvBwdWeightToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStride, HiStride, WiStride, CStride));
make_tuple(NStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -482,8 +494,8 @@ struct TransformConvBwdWeightToGemm
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, Y_ * X_ * C_),
make_tuple(KStride, CStride));
return make_naive_tensor_descriptor(
make_tuple(K_, Y_ * X_ * C_), make_tuple(KStride, CStride), number<VectorSizeC>{}, I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -496,7 +508,9 @@ struct TransformConvBwdWeightToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_),
make_tuple(KStride, NDoHoWoStride));
make_tuple(KStride, NDoHoWoStride),
number<VectorSizeA>{},
I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -511,7 +525,9 @@ struct TransformConvBwdWeightToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -523,7 +539,9 @@ struct TransformConvBwdWeightToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_),
make_tuple(KStride, CStride));
make_tuple(KStride, CStride),
number<VectorSizeC>{},
I1);
}
// TODO: implement ck_tile::tensor_layout::convolution that describes packed/strided dimension as

View File

@@ -10,6 +10,9 @@ namespace ck_tile {
template <index_t NDimSpatial,
ConvolutionSpecialization ConvSpecialization,
index_t VectorSizeA,
index_t VectorSizeB,
index_t VectorSizeC,
bool SplitN = false,
typename ADataType = float,
typename CDataType = float,
@@ -446,7 +449,9 @@ struct TransformConvFwdToGemm
{
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wo_, C_),
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
@@ -458,7 +463,9 @@ struct TransformConvFwdToGemm
{
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wo_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
@@ -473,8 +480,11 @@ struct TransformConvFwdToGemm
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_), make_tuple(NStrideTensorA_, WiStride_));
const auto in_n_wi_c_desc =
make_naive_tensor_descriptor(make_tuple(N_, Wi_),
make_tuple(NStrideTensorA_, WiStride_),
number<VectorSizeA>{},
I1);
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
@@ -502,7 +512,9 @@ struct TransformConvFwdToGemm
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_, NumGroupsToMerge),
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_));
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
@@ -535,7 +547,9 @@ struct TransformConvFwdToGemm
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_, C_),
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_wo_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
@@ -556,7 +570,9 @@ struct TransformConvFwdToGemm
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_wo_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
@@ -581,7 +597,9 @@ struct TransformConvFwdToGemm
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_, C_),
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
@@ -611,7 +629,9 @@ struct TransformConvFwdToGemm
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
@@ -661,7 +681,9 @@ struct TransformConvFwdToGemm
{
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(N_, Ho_, Wo_, C_),
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
@@ -675,7 +697,9 @@ struct TransformConvFwdToGemm
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_),
make_tuple(
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
@@ -689,8 +713,11 @@ struct TransformConvFwdToGemm
{
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_), make_tuple(NStrideTensorA_, HiStride_, WiStride_));
const auto in_n_hi_wi_c_desc =
make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_),
make_tuple(NStrideTensorA_, HiStride_, WiStride_),
number<VectorSizeA>{},
I1);
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
@@ -721,7 +748,9 @@ struct TransformConvFwdToGemm
{
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge),
make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_));
make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
in_n_hi_wi_groups_c_desc,
@@ -757,7 +786,9 @@ struct TransformConvFwdToGemm
{
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
@@ -780,7 +811,9 @@ struct TransformConvFwdToGemm
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor(
in_n_hi_wi_groups_c_desc,
@@ -808,7 +841,9 @@ struct TransformConvFwdToGemm
{
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
@@ -843,7 +878,9 @@ struct TransformConvFwdToGemm
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
in_n_hi_wi_groups_c_desc,
@@ -904,7 +941,9 @@ struct TransformConvFwdToGemm
{
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(N_, Do_, Ho_, Wo_, C_),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
@@ -922,7 +961,9 @@ struct TransformConvFwdToGemm
HiStride_,
WiStride_,
GStrideTensorA_,
CStrideTensorA_));
CStrideTensorA_),
number<VectorSizeA>{},
I1);
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
@@ -939,7 +980,9 @@ struct TransformConvFwdToGemm
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_));
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_),
number<VectorSizeA>{},
I1);
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
@@ -975,7 +1018,9 @@ struct TransformConvFwdToGemm
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_));
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
@@ -1022,7 +1067,9 @@ struct TransformConvFwdToGemm
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
@@ -1052,7 +1099,9 @@ struct TransformConvFwdToGemm
HiStride_,
WiStride_,
GStrideTensorA_,
CStrideTensorA_));
CStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
@@ -1090,7 +1139,9 @@ struct TransformConvFwdToGemm
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
@@ -1138,7 +1189,9 @@ struct TransformConvFwdToGemm
HiStride_,
WiStride_,
GStrideTensorA_,
CStrideTensorA_));
CStrideTensorA_),
number<VectorSizeA>{},
I1);
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
@@ -1217,14 +1270,19 @@ struct TransformConvFwdToGemm
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor_packed(make_tuple(K_, FilterSizeNumType{}));
return make_naive_tensor_descriptor(make_tuple(K_, FilterSizeNumType{}),
make_tuple(FilterSizeNumType{}, I1),
number<VectorSizeB>{},
I1);
}
else
{
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}),
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_),
number<VectorSizeB>{},
I1);
return transform_tensor_descriptor(
wei_gemmn_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
@@ -1237,13 +1295,18 @@ struct TransformConvFwdToGemm
{
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor_packed(make_tuple(K_, ZYX_ * C_));
return make_naive_tensor_descriptor(make_tuple(K_, ZYX_ * C_),
make_tuple(ZYX_ * C_, I1),
number<VectorSizeB>{},
I1);
}
else
{
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(K_, NumGroupsToMerge, ZYX_ * C_),
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_),
number<VectorSizeB>{},
I1);
return transform_tensor_descriptor(
wei_gemmn_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
@@ -1270,14 +1333,18 @@ struct TransformConvFwdToGemm
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
make_tuple(WoStride_, KStrideTensorC_));
make_tuple(WoStride_, KStrideTensorC_),
number<VectorSizeC>{},
I1);
}
else
{
const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1),
make_tuple(
NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_),
number<VectorSizeC>{},
I1);
// Pad 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
nhwo_groups_k_1_desc,
@@ -1328,7 +1395,9 @@ struct TransformConvFwdToGemm
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
make_tuple(WoStride_, KStrideTensorC_));
make_tuple(WoStride_, KStrideTensorC_),
number<VectorSizeC>{},
I1);
}
else
{
@@ -1339,7 +1408,9 @@ struct TransformConvFwdToGemm
WoStride_,
GStrideTensorC_,
KStrideTensorC_,
GStrideTensorC_));
GStrideTensorC_),
number<VectorSizeC>{},
I1);
// Pad 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
nhwo_groups_k_1_desc,
@@ -1390,7 +1461,9 @@ struct TransformConvFwdToGemm
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
make_tuple(WoStride_, KStrideTensorC_));
make_tuple(WoStride_, KStrideTensorC_),
number<VectorSizeC>{},
I1);
}
else
{
@@ -1402,7 +1475,9 @@ struct TransformConvFwdToGemm
WoStride_,
GStrideTensorC_,
KStrideTensorC_,
GStrideTensorC_));
GStrideTensorC_),
number<VectorSizeC>{},
I1);
// Pad 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
nhwo_groups_k_1_desc,