Merge flatmm Operator with universal gemm (#2434)

* Initial commit

* Adding new tile partitioner to flatmm

* intermediate changes

* debugging kernels

* Updating flatmm example to universal gemm example

* updated flatmm kernel to run via gemmKernel

* update universal gemm to incorporate flatmm

* debug

* Fix flatmm call

* Fixing other kernels and tests for API changes

* clang formatted

* fixing gemm tests

* added test for flatmm and simplify kernel arguments

* adding flatmm test

* fix test for flatmm

* simplify gemm kernel with flatmm

* remove flatmm related files

* addressing review comments and code clean up

* resolving empty file

* resolving empty file

* clang formatted

* addressing review comments

* enable persistent kernel for flatmm

* reverted the removed files for flatmm

* reverted the removed files for flatmm

* changed flatmm to weightPReshuffle; removed the _1 added in teh faltmm example

* some more renames

* clang formatted
This commit is contained in:
Khushbu Agarwal
2025-07-11 08:27:55 -07:00
committed by GitHub
parent 45904b8fd7
commit d239b91fd5
34 changed files with 2736 additions and 338 deletions

View File

@@ -90,7 +90,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::tuple_element_or_default_t<Tuple, 9, std::false_type>::value;
// TODO: expose tile size through test t-param ?
template <bool PadM, bool PadN, bool PadK>
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
void invoke_gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args,
const ck_tile::stream_config& s)
{
@@ -107,9 +107,10 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool kPadM = PadM;
constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK;
constexpr bool kPadM = PadM;
constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK;
constexpr bool preshuffle = Preshuffle;
constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false;
@@ -131,7 +132,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
static constexpr bool StructuredSparsity = false;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
static constexpr bool NumWaveGroup = 1;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
@@ -140,7 +143,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
CLayout,
TransposeC,
StructuredSparsity,
Persistent>;
Persistent,
NumWaveGroup,
preshuffle>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
@@ -261,7 +266,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
}
}
template <bool PadM = true, bool PadN = true, bool PadK = true>
template <bool PadM = true, bool PadN = true, bool PadK = true, bool Preshuffle = false>
void Run(const int M,
const int N,
const int K,
@@ -271,11 +276,11 @@ class TestCkTileGemmPipeline : public ::testing::Test
{
for(auto kb : k_batches_)
{
RunSingle<PadM, PadN, PadK>(M, N, K, StrideA, StrideB, StrideC, kb);
RunSingle<PadM, PadN, PadK, Preshuffle>(M, N, K, StrideA, StrideB, StrideC, kb);
}
}
template <bool PadM, bool PadN, bool PadK>
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
void RunSingle(const int M,
const int N,
const int K,
@@ -352,7 +357,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
args.stride_B = stride_B;
args.stride_E = stride_C;
invoke_gemm<PadM, PadN, PadK>(args, ck_tile::stream_config{nullptr, false});
invoke_gemm<PadM, PadN, PadK, Preshuffle>(args, ck_tile::stream_config{nullptr, false});
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;