Replace CK_TILE_PIPELINE macros with a common enum

This change replaces pipeline macros like CK_TILE_PIPELINE_COMPUTE_V3,
CK_TILE_PIPELINE_MEMORY, etc in the CK Tile examples with a common enum
called GemmPipeline to reduce code duplication.
This commit is contained in:
Emily Martins
2025-10-27 17:26:04 +00:00
committed by Emily Martins
parent afe1ff618d
commit 2ec57a8e70
13 changed files with 220 additions and 238 deletions

View File

@@ -7,12 +7,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using D0DataType = ck_tile::half_t;
@@ -36,9 +33,9 @@ struct GemmConfigMemory
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 8;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
struct GemmConfigV3
@@ -56,9 +53,9 @@ struct GemmConfigV3
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
struct GemmConfigV4
@@ -77,9 +74,9 @@ struct GemmConfigV4
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
struct GemmConfigV3_Wmma
@@ -97,16 +94,16 @@ struct GemmConfigV3_Wmma
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
template <ck_tile::index_t PipelineId>
template <ck_tile::GemmPipeline PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
@@ -115,7 +112,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
@@ -124,7 +121,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;