mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Replace CK_TILE_PIPELINE macros with a common enum
This change replaces pipeline macros like CK_TILE_PIPELINE_COMPUTE_V3, CK_TILE_PIPELINE_MEMORY, etc in the CK Tile examples with a common enum called GemmPipeline to reduce code duplication.
This commit is contained in:
committed by
Emily Martins
parent
afe1ff618d
commit
2ec57a8e70
@@ -11,11 +11,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 4
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
@@ -87,7 +82,7 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool Persistent = true;
|
||||
@@ -109,8 +104,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
};
|
||||
@@ -132,8 +127,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
@@ -155,8 +150,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
@@ -178,12 +173,12 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -201,12 +196,12 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr bool kPadK = true;
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr bool kPadK = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -226,8 +221,8 @@ struct GemmConfigComputeV4_Wmma : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
@@ -249,18 +244,18 @@ struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase
|
||||
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
@@ -269,7 +264,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
@@ -278,7 +273,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
@@ -287,7 +282,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
|
||||
@@ -11,10 +11,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
@@ -44,8 +40,8 @@ struct GemmConfigBase
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr bool Preshuffle = false; // currently preshuffle == true is not supported yet
|
||||
static constexpr bool Persistent = false; // currently persistent == true is not supported yet
|
||||
static constexpr bool DoubleSmemBuffer =
|
||||
@@ -67,10 +63,10 @@ struct GemmConfigMemory : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3 : public GemmConfigBase
|
||||
@@ -88,10 +84,10 @@ struct GemmConfigV3 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
struct GemmConfigV4 : public GemmConfigBase
|
||||
{
|
||||
@@ -109,10 +105,10 @@ struct GemmConfigV4 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3_Wmma : public GemmConfigBase
|
||||
@@ -130,16 +126,16 @@ struct GemmConfigV3_Wmma : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
@@ -148,7 +144,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
@@ -157,7 +153,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
|
||||
Reference in New Issue
Block a user