[CK_TILE] Merge multiple fwd convolution groups into a single GEMM batch. (#3136)

* Merge fwd conv groups in CK Tile.

* Fix building CK fwd convs.

* Add number of merged groups to conv fwd kernel name.

* Get number of merged groups from conv config.

* Rename GemmConfig to ConvConfig.

* Clean-up TODOs.

* Check that number of conv groups must be divisible by the number of merged groups.

* Improve error handling in the conv fwd example.

* Fix clang-format.

* Fix group offsets.

* Fix merge problem.

* Address feedback from code review.

* Fix clang-formatting.
This commit is contained in:
Ville Pietilä
2025-12-02 15:23:32 +02:00
committed by GitHub
parent 2d3020e5b0
commit 66832861ad
4 changed files with 111 additions and 58 deletions

View File

@@ -470,10 +470,10 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeADescriptor_M_K() const
{
IndexType NStrideTensorA_ = Wi_ * G_ * C_;
IndexType WiStride_ = G_ * C_;
IndexType CStrideTensorA_ = 1;
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
IndexType GStrideTensorA_ = C_;
IndexType CStrideTensorA_ = 1;
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
@@ -701,11 +701,11 @@ struct TransformConvFwdToGemm
CK_TILE_HOST auto MakeADescriptor_M_K() const
{
IndexType NStrideTensorA_ = Hi_ * Wi_ * G_ * C_;
IndexType HiStride_ = Wi_ * G_ * C_;
IndexType WiStride_ = G_ * C_;
IndexType CStrideTensorA_ = 1;
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
IndexType GStrideTensorA_ = C_;
IndexType CStrideTensorA_ = 1;
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
@@ -960,12 +960,12 @@ struct TransformConvFwdToGemm
CK_TILE_HOST auto MakeADescriptor_M_K() const
{
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
IndexType DiStride_ = Hi_ * Wi_ * G_ * C_;
IndexType HiStride_ = Wi_ * G_ * C_;
IndexType WiStride_ = G_ * C_;
IndexType CStrideTensorA_ = 1;
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
IndexType GStrideTensorA_ = C_;
IndexType CStrideTensorA_ = 1;
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
@@ -1289,9 +1289,9 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeBDescriptor_N_K() const
{
IndexType CStrideTensorB_ = 1;
IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_;
IndexType GStrideTensorB_ = K_ * Z_ * Y_ * X_ * C_;
IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_;
IndexType CStrideTensorB_ = 1;
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
{
@@ -1356,10 +1356,10 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeCDescriptor_M_N() const
{
IndexType NStrideTensorC_ = Wo_ * G_ * K_;
IndexType WoStride_ = G_ * K_;
IndexType KStrideTensorC_ = 1;
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
IndexType GStrideTensorC_ = K_;
IndexType KStrideTensorC_ = 1;
const IndexType NDoHoWo = N_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
@@ -1417,11 +1417,11 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeCDescriptor_M_N() const
{
IndexType NStrideTensorC_ = Ho_ * Wo_ * G_ * K_;
IndexType HoStride_ = Wo_ * G_ * K_;
IndexType WoStride_ = G_ * K_;
IndexType KStrideTensorC_ = 1;
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
IndexType GStrideTensorC_ = K_;
IndexType KStrideTensorC_ = 1;
const IndexType NDoHoWo = N_ * Ho_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
@@ -1482,12 +1482,12 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeCDescriptor_M_N() const
{
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
IndexType DoStride_ = Ho_ * Wo_ * G_ * K_;
IndexType HoStride_ = Wo_ * G_ * K_;
IndexType WoStride_ = G_ * K_;
IndexType KStrideTensorC_ = 1;
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
IndexType GStrideTensorC_ = K_;
IndexType KStrideTensorC_ = 1;
const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
if constexpr(NumGroupsToMerge == 1)