mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
[CK_TILE] Merge multiple fwd convolution groups into a single GEMM batch. (#3136)
* Merge fwd conv groups in CK Tile.
* Fix building CK fwd convs.
* Add number of merged groups to conv fwd kernel name.
* Get number of merged groups from conv config.
* Rename GemmConfig to ConvConfig.
* Clean-up TODOs.
* Check that number of conv groups must be divisible by the number of merged groups.
* Improve error handling in the conv fwd example.
* Fix clang-format.
* Fix group offsets.
* Fix merge problem.
* Address feedback from code review.
* Fix clang-formatting.
[ROCm/composable_kernel commit: 66832861ad]
This commit is contained in:
@@ -470,10 +470,10 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeADescriptor_M_K() const
|
||||
{
|
||||
IndexType NStrideTensorA_ = Wi_ * G_ * C_;
|
||||
IndexType WiStride_ = G_ * C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
|
||||
IndexType GStrideTensorA_ = C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
@@ -701,11 +701,11 @@ struct TransformConvFwdToGemm
|
||||
CK_TILE_HOST auto MakeADescriptor_M_K() const
|
||||
|
||||
{
|
||||
IndexType NStrideTensorA_ = Hi_ * Wi_ * G_ * C_;
|
||||
IndexType HiStride_ = Wi_ * G_ * C_;
|
||||
IndexType WiStride_ = G_ * C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
|
||||
IndexType GStrideTensorA_ = C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
@@ -960,12 +960,12 @@ struct TransformConvFwdToGemm
|
||||
CK_TILE_HOST auto MakeADescriptor_M_K() const
|
||||
|
||||
{
|
||||
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
|
||||
IndexType DiStride_ = Hi_ * Wi_ * G_ * C_;
|
||||
IndexType HiStride_ = Wi_ * G_ * C_;
|
||||
IndexType WiStride_ = G_ * C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
|
||||
IndexType GStrideTensorA_ = C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
@@ -1289,9 +1289,9 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeBDescriptor_N_K() const
|
||||
{
|
||||
IndexType CStrideTensorB_ = 1;
|
||||
IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_;
|
||||
IndexType GStrideTensorB_ = K_ * Z_ * Y_ * X_ * C_;
|
||||
IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_;
|
||||
IndexType CStrideTensorB_ = 1;
|
||||
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
|
||||
{
|
||||
@@ -1356,10 +1356,10 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
IndexType NStrideTensorC_ = Wo_ * G_ * K_;
|
||||
IndexType WoStride_ = G_ * K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
|
||||
IndexType GStrideTensorC_ = K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
|
||||
const IndexType NDoHoWo = N_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
@@ -1417,11 +1417,11 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
IndexType NStrideTensorC_ = Ho_ * Wo_ * G_ * K_;
|
||||
IndexType HoStride_ = Wo_ * G_ * K_;
|
||||
IndexType WoStride_ = G_ * K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
|
||||
IndexType GStrideTensorC_ = K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
|
||||
const IndexType NDoHoWo = N_ * Ho_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
@@ -1482,12 +1482,12 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
|
||||
IndexType DoStride_ = Ho_ * Wo_ * G_ * K_;
|
||||
IndexType HoStride_ = Wo_ * G_ * K_;
|
||||
IndexType WoStride_ = G_ * K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
|
||||
IndexType GStrideTensorC_ = K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
|
||||
const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
|
||||
Reference in New Issue
Block a user