mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
Change in fwd-splitkv kernel to support num_splits=1 case (#1690)
* Change in fwd-splitkv kernel to support num_splits=1 case * Update in codegen fwd-splitkv to make num_splits > 1 cases pass * Specify instance traits in dispatch * Fix link error for fp8 kernels --------- Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -35,6 +35,7 @@ struct FmhaFwdSplitKVKernel
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
|
||||
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
|
||||
|
||||
@@ -234,8 +235,10 @@ struct FmhaFwdSplitKVKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* lse_acc_ptr,
|
||||
void* o_acc_ptr,
|
||||
void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
|
||||
final lse */
|
||||
void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
|
||||
o */
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
|
||||
@@ -356,8 +359,10 @@ struct FmhaFwdSplitKVKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* lse_acc_ptr,
|
||||
void* o_acc_ptr,
|
||||
void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
|
||||
final lse */
|
||||
void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
|
||||
o */
|
||||
ck_tile::index_t batch,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
@@ -591,9 +596,9 @@ struct FmhaFwdSplitKVKernel
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
|
||||
batch_offset_v;
|
||||
|
||||
OaccDataType* o_acc_ptr = reinterpret_cast<OaccDataType*>(kargs.o_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
|
||||
batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
|
||||
ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
|
||||
batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = [&]() {
|
||||
|
||||
Reference in New Issue
Block a user