[CK TILE] Grouped Convolution Forward Kernel (#2188)

* [CK TILE] Grouped Convolution Forward Kernel

* custom vector size

* fixes

* refactor

* rebase fixes

* fixes

* fixes
This commit is contained in:
Bartłomiej Kocot
2025-06-21 00:44:36 +02:00
committed by GitHub
parent 7378a51b4c
commit cebdee4d9e
17 changed files with 3096 additions and 19 deletions

View File

@@ -1,8 +1,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {

View File

@@ -121,7 +121,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M1 = Problem::VectorSizeA;
constexpr index_t M0 = MPerBlock / M1;
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % M1 == 0);
@@ -211,7 +211,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N1 = Problem::VectorSizeB;
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);

View File

@@ -14,7 +14,10 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
typename ComputeDataType_ = ADataType_>
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
struct GemmPipelineProblemBase
{
using Traits = remove_cvref_t<Traits_>;
@@ -24,6 +27,8 @@ struct GemmPipelineProblemBase
using CDataType = remove_cvref_t<CDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
static constexpr bool FixedVectorSize = FixedVectorSize_;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename Traits::ALayout>;
@@ -115,7 +120,11 @@ struct GemmPipelineProblemBase
}
static constexpr index_t VectorSizeA = []() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
if constexpr(FixedVectorSize)
{
return VectorSizeA_;
}
else if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return kPadK ? 1 : GetAlignmentA();
}
@@ -126,7 +135,11 @@ struct GemmPipelineProblemBase
}();
static constexpr index_t VectorSizeB = []() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
if constexpr(FixedVectorSize)
{
return VectorSizeB_;
}
else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return kPadN ? 1 : GetAlignmentB();
}
@@ -153,13 +166,19 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
typename ComputeDataType_ = ADataType_>
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>;
ComputeDataType_,
FixedVectorSize_,
VectorSizeA_,
VectorSizeB_>;
template <typename ADataType_,
typename BDataType_,
@@ -169,7 +188,10 @@ template <typename ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_>
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
struct UniversalGemmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
@@ -179,6 +201,10 @@ struct UniversalGemmPipelineProblem
using CDataType = remove_cvref_t<CDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeA = VectorSizeA_;
static constexpr index_t VectorSizeB = VectorSizeB_;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename Traits::ALayout>;

View File

@@ -426,10 +426,11 @@ struct UniversalGemmBasePolicy
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// Tile: MPerBlock X KPerBlock
@@ -461,10 +462,11 @@ struct UniversalGemmBasePolicy
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// Tile: KPerBlock X NPerBlock