mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Merge multiple fwd convolution groups into a single GEMM batch. (#3136)
* Merge fwd conv groups in CK Tile. * Fix building CK fwd convs. * Add number of merged groups to conv fwd kernel name. * Get number of merged groups from conv config. * Rename GemmConfig to ConvConfig. * Clean-up TODOs. * Check that number of conv groups must be divisible by the number of merged groups. * Improve error handling in the conv fwd example. * Fix clang-format. * Fix group offsets. * Fix merge problem. * Address feedback from code review. * Fix clang-formatting.
This commit is contained in:
@@ -470,10 +470,10 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeADescriptor_M_K() const
|
||||
{
|
||||
IndexType NStrideTensorA_ = Wi_ * G_ * C_;
|
||||
IndexType WiStride_ = G_ * C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
|
||||
IndexType GStrideTensorA_ = C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
@@ -701,11 +701,11 @@ struct TransformConvFwdToGemm
|
||||
CK_TILE_HOST auto MakeADescriptor_M_K() const
|
||||
|
||||
{
|
||||
IndexType NStrideTensorA_ = Hi_ * Wi_ * G_ * C_;
|
||||
IndexType HiStride_ = Wi_ * G_ * C_;
|
||||
IndexType WiStride_ = G_ * C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
|
||||
IndexType GStrideTensorA_ = C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
@@ -960,12 +960,12 @@ struct TransformConvFwdToGemm
|
||||
CK_TILE_HOST auto MakeADescriptor_M_K() const
|
||||
|
||||
{
|
||||
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
|
||||
IndexType DiStride_ = Hi_ * Wi_ * G_ * C_;
|
||||
IndexType HiStride_ = Wi_ * G_ * C_;
|
||||
IndexType WiStride_ = G_ * C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
|
||||
IndexType GStrideTensorA_ = C_;
|
||||
IndexType CStrideTensorA_ = 1;
|
||||
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
@@ -1289,9 +1289,9 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeBDescriptor_N_K() const
|
||||
{
|
||||
IndexType CStrideTensorB_ = 1;
|
||||
IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_;
|
||||
IndexType GStrideTensorB_ = K_ * Z_ * Y_ * X_ * C_;
|
||||
IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_;
|
||||
IndexType CStrideTensorB_ = 1;
|
||||
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
|
||||
{
|
||||
@@ -1356,10 +1356,10 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
IndexType NStrideTensorC_ = Wo_ * G_ * K_;
|
||||
IndexType WoStride_ = G_ * K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
|
||||
IndexType GStrideTensorC_ = K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
|
||||
const IndexType NDoHoWo = N_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
@@ -1417,11 +1417,11 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
IndexType NStrideTensorC_ = Ho_ * Wo_ * G_ * K_;
|
||||
IndexType HoStride_ = Wo_ * G_ * K_;
|
||||
IndexType WoStride_ = G_ * K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
|
||||
IndexType GStrideTensorC_ = K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
|
||||
const IndexType NDoHoWo = N_ * Ho_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
@@ -1482,12 +1482,12 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
|
||||
IndexType DoStride_ = Ho_ * Wo_ * G_ * K_;
|
||||
IndexType HoStride_ = Wo_ * G_ * K_;
|
||||
IndexType WoStride_ = G_ * K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
|
||||
IndexType GStrideTensorC_ = K_;
|
||||
IndexType KStrideTensorC_ = 1;
|
||||
|
||||
const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
|
||||
Reference in New Issue
Block a user