Merge commit '66832861ad78cc63584c32e5d231fd29a99c57b3' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-02 14:14:02 +00:00
parent 9c9a022007
commit aef67fef38
4 changed files with 111 additions and 58 deletions

View File

@@ -470,10 +470,10 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeADescriptor_M_K() const
{
IndexType NStrideTensorA_ = Wi_ * G_ * C_;
IndexType WiStride_ = G_ * C_;
IndexType CStrideTensorA_ = 1;
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
IndexType GStrideTensorA_ = C_;
IndexType CStrideTensorA_ = 1;
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
@@ -701,11 +701,11 @@ struct TransformConvFwdToGemm
CK_TILE_HOST auto MakeADescriptor_M_K() const
{
IndexType NStrideTensorA_ = Hi_ * Wi_ * G_ * C_;
IndexType HiStride_ = Wi_ * G_ * C_;
IndexType WiStride_ = G_ * C_;
IndexType CStrideTensorA_ = 1;
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
IndexType GStrideTensorA_ = C_;
IndexType CStrideTensorA_ = 1;
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
@@ -960,12 +960,12 @@ struct TransformConvFwdToGemm
CK_TILE_HOST auto MakeADescriptor_M_K() const
{
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
IndexType DiStride_ = Hi_ * Wi_ * G_ * C_;
IndexType HiStride_ = Wi_ * G_ * C_;
IndexType WiStride_ = G_ * C_;
IndexType CStrideTensorA_ = 1;
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
IndexType GStrideTensorA_ = C_;
IndexType CStrideTensorA_ = 1;
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
@@ -1289,9 +1289,9 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeBDescriptor_N_K() const
{
IndexType CStrideTensorB_ = 1;
IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_;
IndexType GStrideTensorB_ = K_ * Z_ * Y_ * X_ * C_;
IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_;
IndexType CStrideTensorB_ = 1;
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
{
@@ -1356,10 +1356,10 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeCDescriptor_M_N() const
{
IndexType NStrideTensorC_ = Wo_ * G_ * K_;
IndexType WoStride_ = G_ * K_;
IndexType KStrideTensorC_ = 1;
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
IndexType GStrideTensorC_ = K_;
IndexType KStrideTensorC_ = 1;
const IndexType NDoHoWo = N_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
@@ -1417,11 +1417,11 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeCDescriptor_M_N() const
{
IndexType NStrideTensorC_ = Ho_ * Wo_ * G_ * K_;
IndexType HoStride_ = Wo_ * G_ * K_;
IndexType WoStride_ = G_ * K_;
IndexType KStrideTensorC_ = 1;
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
IndexType GStrideTensorC_ = K_;
IndexType KStrideTensorC_ = 1;
const IndexType NDoHoWo = N_ * Ho_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
@@ -1482,12 +1482,12 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeCDescriptor_M_N() const
{
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
IndexType DoStride_ = Ho_ * Wo_ * G_ * K_;
IndexType HoStride_ = Wo_ * G_ * K_;
IndexType WoStride_ = G_ * K_;
IndexType KStrideTensorC_ = 1;
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
IndexType GStrideTensorC_ = K_;
IndexType KStrideTensorC_ = 1;
const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
if constexpr(NumGroupsToMerge == 1)