[rocm-libraries] ROCm/rocm-libraries#4357 (commit ff3e982)

[CK_TILE] Add support and tests for V6 pipeline in conv fwd
 (#4357)

Added support for conv v6 pipeline in ck tile's convolution forward
kernel. CK Tile v6 pipeline is the equivalent to old ck's V5 pipeline
and should be faster than other pipelines for some cases. This PR also
adds tests inside profiler that's currently inside experimental
directory, so now we should be able to detect regressions easier.
This commit is contained in:
jakpiase
2026-02-08 19:57:53 +00:00
committed by assistant-librarian[bot]
parent 57d26db844
commit 731afe535a
13 changed files with 97 additions and 42 deletions

View File

@@ -207,6 +207,26 @@ struct ConvConfigComputeV5 : public ConvConfigBase
static constexpr ck_tile::index_t NumWaveGroups = 2;
};
template <typename PrecType>
struct ConvConfigComputeV6 : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6;
static constexpr ck_tile::index_t NumWaveGroups = 1;
};
template <typename PrecType>
struct ConvConfigComputeV3_merged_groups : public ConvConfigBase
{
@@ -310,3 +330,12 @@ struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V5>
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V6>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6<PipelineProblem>;
};