Merge commit '60d3e8f504edd25569811b25b4b876d0a504b3b8' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-11 15:11:42 +00:00
parent 269824c6bb
commit 9541fc3ef3
22 changed files with 439 additions and 192 deletions

View File

@@ -353,5 +353,9 @@ int run_grouped_gemm_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_gemm_example<GemmConfigComputeV4_Wmma>(argc, argv);
#else
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv);
#endif
}

View File

@@ -17,10 +17,6 @@
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 4
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
#endif
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
@@ -190,6 +186,29 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr bool kPadK = true;
};
template <typename PrecType>
struct GemmConfigComputeV4_Wmma : public GemmConfigBase
{
// Compute V4 only support Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType>
struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase
{