[CK TILE] Implement cschuflle algorithm (#1842)

* [CK TILE] Implement cschuflle algorithm

* Rebase

* Vector store size fixes

* fixes

* Fixes

* fixes

* fmha fix

* fixes

* fixes of fixes
This commit is contained in:
Bartłomiej Kocot
2025-01-30 11:57:39 +01:00
committed by GitHub
parent c5fff071e5
commit 25e2e0f04a
18 changed files with 408 additions and 371 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
@@ -65,9 +65,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>;
@@ -106,6 +103,20 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
GemmPipeline::BlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -244,7 +255,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
public:
std::vector<int> k_batches_;
void SetUp() override { k_batches_ = {1}; }
void SetUp() override { k_batches_ = {1, 2}; }
template <bool PadM = true, bool PadN = true, bool PadK = true>
void Run(const int M,