mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[CK_TILE] Blockwise GEMM pipeline v6 - port of v5 from old CK (#2955)
* First checkpoint * Second checkpoint - hot loop scheduler * Third checkpoint - init main operator * Fourth checkpoint - main loop ready * Fifth checkpoint - main loop fix * Sixth checkpoint - ReadWritecompFunc * Seventh checkpoint - Tail finished * [CK_TILE] Blockwise gemm pipeline v5 complete * Working * Working fixes 2 * Rename v5 to v77 temporarily * Data type adjustment * Data type adjustment 2 * [CK_TILE] Blockwise Gemm pipeline v5 add tests * [CK_TILE] Fix calculation error * TEMP: check pipeline * Fix name to V6 * naming and documentation changes * WIP dump * Try fixing v1 * Failing tests v5 * Debugging * Changes v2 * F16 tests working great * Working BlockwiseGemmPipelineV5 as V6 * Cleanup and format * Merging changes part1 * [CK_TILE] Blockwise Gemm Pipeline Comp V5/V6 * Remove commented code * Fix gfx950 build issues * Fix file formatting * Review changes, more concat info, add bf16 bf8 tests * Fix formatting * Add bf16 and bf8 tests --------- Co-authored-by: Adam Osewski <Adam.Osewski@amd.com>
This commit is contained in:
@@ -38,6 +38,7 @@ enum struct GemmPipelineType
|
||||
Mem,
|
||||
CompV3,
|
||||
CompV4,
|
||||
CompV6,
|
||||
CompAsync
|
||||
};
|
||||
|
||||
@@ -71,6 +72,15 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; }
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct GemmPipelineTypeSelector<GemmPipelineType::CompV6, Problem>
|
||||
{
|
||||
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6<Problem>;
|
||||
using pipeline = ck_tile::GemmPipelineAgBgCrCompV6<Problem>;
|
||||
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV6"; }
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct GemmPipelineTypeSelector<GemmPipelineType::CompAsync, Problem>
|
||||
{
|
||||
@@ -120,11 +130,13 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
constexpr bool kPadK = PadK;
|
||||
constexpr bool preshuffle = Preshuffle;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4 ||
|
||||
constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4 ||
|
||||
PipelineType == GemmPipelineType::CompAsync);
|
||||
constexpr bool TransposeC = false;
|
||||
static constexpr bool StructuredSparsity = false;
|
||||
static constexpr bool NumWaveGroup = 1;
|
||||
|
||||
// TODO: For now - but this should also be a test parameter
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
@@ -140,8 +152,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
static constexpr bool StructuredSparsity = false;
|
||||
static constexpr bool NumWaveGroup = 1;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
|
||||
Reference in New Issue
Block a user