mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[CK TILE] Refactor GemmKernel to be reused by other GEMM related operators (#1730)
* Gemm Kernel Refactor part1 * Gemm Kernel Refactor common gemm pipeline part2 * [CK TILE] Refactor batched gemm to reuse GemmKernel * [CK TILE] Refactor GemmKernel - review changes part1 * [CK TILE] Refactor GemmKernel - references fix * [CK TILE] Refactor GemmKernel - naming changes, add problem * [CK_TILE] Refactor GemmKernel - update tests * [CK_TILE] Refactor GemmKernel - review changes * [CK_TILE] Refactor GemmKernel - update test * [CK_TILE] Refactor GemmKernel - constness fixes * [CK_TILE] Refactor GemmKernel - update tests
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
#include "gemm_basic.hpp"
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadM = false;
|
||||
@@ -79,17 +79,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args.p_a,
|
||||
args.p_b,
|
||||
args.p_c,
|
||||
args.M,
|
||||
args.N,
|
||||
args.K,
|
||||
args.stride_A,
|
||||
args.stride_B,
|
||||
args.stride_C);
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
|
||||
Reference in New Issue
Block a user