mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
fix(grouped_gemm): pipeline selection when tail_num varies per group and leads to numerical error (#2863)
* fix(grouped_gemm): numerical errors on gfx950 by correctly calculating the tail num * WIP: add temp config to stress test numerical error correction * refactor: remove comments
This commit is contained in:
@@ -91,7 +91,7 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool Persistent = false;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
};
|
||||
|
||||
@@ -139,6 +139,29 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4_V2 : public GemmConfigBase
|
||||
{
|
||||
// Compute V4 only support Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user