Merge commit '5f1ad09b610cb0e083f63988479ab022bda70588' into develop

This commit is contained in:
assistant-librarian[bot]
2025-06-13 01:40:16 +00:00
parent dfa4ac5eb3
commit c1e4ebf424
17 changed files with 727 additions and 110 deletions

View File

@@ -14,6 +14,7 @@
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
@@ -31,6 +32,10 @@
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV5
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV5
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
@@ -51,7 +56,8 @@ struct GemmConfig
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t NumWaveGroups = 1;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
// Compute friendly for Intrawave scheduler
@@ -67,7 +73,8 @@ struct GemmConfig
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t NumWaveGroups = 1;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
@@ -83,7 +90,29 @@ struct GemmConfig
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t NumWaveGroups = 1;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 2;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
// Available wavegroups will be split into `NumWaveGroups` and each of these wavegroups
// will be responsible for specific jobs. For instance, perform Global Memory read operations,
// perform block-gemm operation etc...
static constexpr ck_tile::index_t NumWaveGroups = 2;
#endif
static constexpr bool kPadM = false;

View File

@@ -50,7 +50,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
Persistent>;
Persistent,
GemmConfig::NumWaveGroups>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
@@ -96,7 +97,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
memory_operation,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -190,7 +193,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}