[CK_BUILDER] Add grouped conv fwd ck tile traits (#3183)

* [CK BUILDER] Add grouped conv fwd ck tile traits

* Update instance_traits_tile_grouped_convolution_forward.hpp

* Update grouped_convolution_forward_kernel.hpp
This commit is contained in:
Bartłomiej Kocot
2025-11-11 22:55:33 +01:00
committed by GitHub
parent b145a5fe80
commit 92c1f4981a
18 changed files with 433 additions and 15 deletions

View File

@@ -164,6 +164,13 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "COMPUTE_ASYNC";
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();

View File

@@ -170,6 +170,13 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
using Base::PrefetchStages;
using Base::UsePersistentKernel;
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "COMPUTE_V3";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off

View File

@@ -172,6 +172,13 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "COMPUTE_V4";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off

View File

@@ -99,6 +99,13 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
static constexpr index_t NumWarps = BlockGemmShape::NumWarps;
static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{});
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "COMPUTE_V5";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off

View File

@@ -159,6 +159,13 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
static constexpr auto is_a_load_tr_v = bool_constant<BasePImpl::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<BasePImpl::is_b_load_tr>{};
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "COMPUTE_V6";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off

View File

@@ -214,6 +214,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "MEMORY";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off

View File

@@ -70,6 +70,13 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr index_t kLdsAlignmentInBytes = 16;
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "BASIC_V1";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off

View File

@@ -70,6 +70,13 @@ struct GemmPipelineAGmemBGmemCRegV2
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
static constexpr bool DoubleSmemBuffer = false;
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "BASIC_V2";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off

View File

@@ -176,6 +176,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "PRESHUFFLE_V2";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off