Replace CK_TILE_PIPELINE macros with a common enum

This change replaces pipeline macros like CK_TILE_PIPELINE_COMPUTE_V3,
CK_TILE_PIPELINE_MEMORY, etc in the CK Tile examples with a common enum
called GemmPipeline to reduce code duplication.
This commit is contained in:
Emily Martins
2025-10-27 17:26:04 +00:00
committed by Emily Martins
parent afe1ff618d
commit 2ec57a8e70
13 changed files with 220 additions and 238 deletions

View File

@@ -11,11 +11,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/utility/json_dump.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 4
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
@@ -87,7 +82,7 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool Persistent = true;
@@ -109,8 +104,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr int kBlockPerCu = 1;
};
@@ -132,8 +127,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
static constexpr int kBlockPerCu = 2;
};
@@ -155,8 +150,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
static constexpr int kBlockPerCu = 2;
};
@@ -178,12 +173,12 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr bool kPadK = true;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool Persistent = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool Persistent = true;
static constexpr bool DoubleSmemBuffer = true;
};
template <typename PrecType>
@@ -201,12 +196,12 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr bool kPadK = true;
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr bool kPadK = true;
};
template <typename PrecType>
@@ -226,8 +221,8 @@ struct GemmConfigComputeV4_Wmma : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
static constexpr int kBlockPerCu = 2;
};
@@ -249,18 +244,18 @@ struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase
static constexpr bool kPadK = true;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
};
template <ck_tile::index_t PipelineId>
template <ck_tile::GemmPipeline PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
@@ -269,7 +264,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
@@ -278,7 +273,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
@@ -287,7 +282,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;

View File

@@ -11,10 +11,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/utility/json_dump.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
@@ -44,8 +40,8 @@ struct GemmConfigBase
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr bool Preshuffle = false; // currently preshuffle == true is not supported yet
static constexpr bool Persistent = false; // currently persistent == true is not supported yet
static constexpr bool DoubleSmemBuffer =
@@ -67,10 +63,10 @@ struct GemmConfigMemory : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 8;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool Persistent = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool Persistent = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
struct GemmConfigV3 : public GemmConfigBase
@@ -88,10 +84,10 @@ struct GemmConfigV3 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool Persistent = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool Persistent = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
struct GemmConfigV4 : public GemmConfigBase
{
@@ -109,10 +105,10 @@ struct GemmConfigV4 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool Persistent = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool Persistent = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
struct GemmConfigV3_Wmma : public GemmConfigBase
@@ -130,16 +126,16 @@ struct GemmConfigV3_Wmma : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
template <ck_tile::index_t PipelineId>
template <ck_tile::GemmPipeline PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
@@ -148,7 +144,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
@@ -157,7 +153,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;