mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Switch to universal gemm for batched and grouped gemms (#1919)
* switch to universal gemm for batched and grouped gemms * added reviewer comments * fixed grouped gemm tests
This commit is contained in:
@@ -56,6 +56,20 @@ struct GemmHostArgs : public GemmProblem
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
struct GemmKernelArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
void* c_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct GemmKernel
|
||||
{
|
||||
@@ -90,20 +104,6 @@ struct GemmKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
struct GemmKernelArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
void* c_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
|
||||
{
|
||||
return GemmKernelArgs{hostArgs.a_ptr,
|
||||
|
||||
Reference in New Issue
Block a user