From 443a528adc6774133f2d4e1cabae630eacca83fe Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 7 Aug 2024 04:27:15 +0000 Subject: [PATCH] Add block_table kernel args for appendkv kernel --- example/ck_tile/01_fmha/fmha_fwd.hpp | 6 ++++++ .../ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp | 16 ++++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 5b3d0989e0..cd05efa85b 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -502,6 +502,9 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) args.rotary_cos_ptr, args.rotary_sin_ptr, args.rotary_dim, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, args.stride_q, args.stride_k, args.stride_knew, @@ -532,6 +535,9 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) args.rotary_cos_ptr, args.rotary_sin_ptr, args.rotary_dim, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, args.stride_q, args.stride_k, args.stride_knew, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index dea410e4b0..b9de3f044e 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -97,11 +97,11 @@ struct FmhaFwdAppendKVKernel // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k // if this param is larger than 1, indicate MQA/GQA case ck_tile::index_t nhead_ratio_qk; - /* + const void* block_table_ptr; ck_tile::index_t batch_stride_block_table; ck_tile::index_t page_block_size; - */ + ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_knew; @@ -160,6 +160,9 @@ struct FmhaFwdAppendKVKernel const void* rotary_cos_ptr, const void* rotary_sin_ptr, ck_tile::index_t rotary_dim, + const void* block_table_ptr, + ck_tile::index_t batch_stride_block_table, + ck_tile::index_t page_block_size, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_knew, @@ -188,6 +191,9 @@ struct FmhaFwdAppendKVKernel hdim_v, num_head_q, nhead_ratio_qk, + block_table_ptr, + batch_stride_block_table, + page_block_size, stride_q, stride_k, stride_knew, @@ -233,6 +239,9 @@ struct FmhaFwdAppendKVKernel const void* rotary_cos_ptr, const void* rotary_sin_ptr, ck_tile::index_t rotary_dim, + const void* block_table_ptr, + ck_tile::index_t batch_stride_block_table, + ck_tile::index_t page_block_size, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_knew, @@ -258,6 +267,9 @@ struct FmhaFwdAppendKVKernel hdim_v, num_head_q, nhead_ratio_qk, + block_table_ptr, + batch_stride_block_table, + page_block_size, stride_q, stride_k, stride_knew,