mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK TILE] Refactor GemmKernel to be reused by other GEMM related operators (#1730)
* Gemm Kernel Refactor part1 * Gemm Kernel Refactor common gemm pipeline part2 * [CK TILE] Refactor batched gemm to reuse GemmKernel * [CK TILE] Refactor GemmKernel - review changes part1 * [CK TILE] Refactor GemmKernel - references fix * [CK TILE] Refactor GemmKernel - naming changes, add problem * [CK_TILE] Refactor GemmKernel - update tests * [CK_TILE] Refactor GemmKernel - review changes * [CK_TILE] Refactor GemmKernel - update test * [CK_TILE] Refactor GemmKernel - constness fixes * [CK_TILE] Refactor GemmKernel - update tests
This commit is contained in:
@@ -16,11 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
gemm_basic_args args;
|
||||
args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.kbatch = kbatch;
|
||||
ck_tile::GemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
|
||||
Reference in New Issue
Block a user