[CK_TILE] Switch to universal gemm for batched and grouped gemms (#1919)

* switch to universal gemm for batched and grouped gemms

* added reviewer comments

* fixed grouped gemm tests
This commit is contained in:
jakpiase
2025-03-20 11:17:04 +01:00
committed by GitHub
parent b819c217e4
commit 0e91d32c61
13 changed files with 853 additions and 359 deletions

View File

@@ -56,6 +56,20 @@ struct GemmHostArgs : public GemmProblem
index_t k_batch;
};
struct GemmKernelArgs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t k_batch;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel
{
@@ -90,20 +104,6 @@ struct GemmKernel
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
struct GemmKernelArgs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t k_batch;
};
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
{
return GemmKernelArgs{hostArgs.a_ptr,