[CK TILE] Grouped Convolution Forward Kernel (#2188)

* [CK TILE] Grouped Convolution Forward Kernel

* custom vector size

* fixes

* refactor

* rebase fixes

* fixes

* fixes
This commit is contained in:
Bartłomiej Kocot
2025-06-21 00:44:36 +02:00
committed by GitHub
parent 7378a51b4c
commit cebdee4d9e
17 changed files with 3096 additions and 19 deletions

View File

@@ -27,7 +27,9 @@ template <typename ADataType_,
index_t KPerXdl_,
bool isCTransposed_,
memory_operation_enum MemoryOperation_,
index_t kNumWaveGroups_ = 1>
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1>
struct CShuffleEpilogueProblem
{
using ADataType = remove_cvref_t<ADataType_>;
@@ -48,6 +50,8 @@ struct CShuffleEpilogueProblem
static constexpr index_t KPerXdl = KPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -80,6 +84,8 @@ struct CShuffleEpilogue
static constexpr index_t NPerXdl = Problem::NPerXdl;
static constexpr index_t KPerXdl = Problem::KPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
@@ -98,6 +104,10 @@ struct CShuffleEpilogue
*/
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
{
if constexpr(FixedVectorSize)
{
return VectorSizeC;
}
constexpr index_t max_vector_size = 16;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{