mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Pass LSE/O strides in kernel argument
This commit is contained in:
@@ -374,9 +374,16 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.hdim_v,
|
||||
args.num_splits,
|
||||
args.scale_o,
|
||||
args.stride_o_acc,
|
||||
args.stride_o,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o);
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -391,11 +398,18 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.hdim_v,
|
||||
args.num_splits,
|
||||
args.scale_o,
|
||||
args.stride_o_acc,
|
||||
args.stride_o,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o);
|
||||
args.batch_stride_o,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -85,8 +85,18 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t num_splits;
|
||||
|
||||
ck_tile::index_t row_stride_o_acc;
|
||||
ck_tile::index_t row_stride_o;
|
||||
|
||||
ck_tile::index_t nhead_stride_lse_acc;
|
||||
ck_tile::index_t nhead_stride_o_acc;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
|
||||
ck_tile::index_t split_stride_lse_acc;
|
||||
ck_tile::index_t split_stride_o_acc;
|
||||
};
|
||||
|
||||
struct CommonLSEKargs
|
||||
@@ -132,11 +142,18 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits,
|
||||
float scale_o,
|
||||
ck_tile::index_t row_stride_o_acc,
|
||||
ck_tile::index_t row_stride_o,
|
||||
ck_tile::index_t nhead_stride_lse_acc,
|
||||
ck_tile::index_t nhead_stride_o_acc,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_lse_acc,
|
||||
ck_tile::index_t batch_stride_o_acc,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o)
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t split_stride_lse_acc,
|
||||
ck_tile::index_t split_stride_o_acc)
|
||||
{
|
||||
Kargs kargs{{lse_acc_ptr,
|
||||
o_acc_ptr,
|
||||
@@ -147,10 +164,17 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
seqlen_q,
|
||||
hdim_v,
|
||||
num_splits,
|
||||
row_stride_o_acc,
|
||||
row_stride_o,
|
||||
nhead_stride_o}, // args for common karg
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
batch_stride_o};
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -180,9 +204,16 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits,
|
||||
float scale_o,
|
||||
ck_tile::index_t row_stride_o_acc,
|
||||
ck_tile::index_t row_stride_o,
|
||||
ck_tile::index_t nhead_stride_lse_acc,
|
||||
ck_tile::index_t nhead_stride_o_acc,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o)
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_lse_acc,
|
||||
ck_tile::index_t batch_stride_o_acc,
|
||||
ck_tile::index_t split_stride_lse_acc,
|
||||
ck_tile::index_t split_stride_o_acc)
|
||||
{
|
||||
Kargs kargs{{lse_acc_ptr,
|
||||
o_acc_ptr,
|
||||
@@ -193,10 +224,17 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
-1, // seqlen will be updated by another pointer
|
||||
hdim_v,
|
||||
num_splits,
|
||||
row_stride_o_acc,
|
||||
row_stride_o,
|
||||
nhead_stride_o}, // args for common karg
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -239,20 +277,18 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
const long_index_t batch_offset_lse_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
const long_index_t batch_offset_o_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
batch_offset_lse_acc =
|
||||
static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
|
||||
batch_offset_o_acc = static_cast<long_index_t>(i_batch) *
|
||||
(kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v);
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse =
|
||||
@@ -273,10 +309,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_lse_acc =
|
||||
static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
|
||||
batch_offset_o_acc = static_cast<long_index_t>(i_batch) *
|
||||
(kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v);
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
@@ -285,13 +317,12 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
const LSEDataType* lse_acc_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * (kargs.max_seqlen_q) +
|
||||
batch_offset_lse_acc;
|
||||
const LSEDataType* lse_acc_ptr =
|
||||
reinterpret_cast<const LSEDataType*>(kargs.lse_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lse_acc + batch_offset_lse_acc;
|
||||
const OaccDataType* o_acc_ptr =
|
||||
reinterpret_cast<const OaccDataType*>(kargs.o_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * (kargs.max_seqlen_q * kargs.hdim_v) +
|
||||
batch_offset_o_acc;
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc + batch_offset_o_acc;
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
|
||||
batch_offset_o;
|
||||
@@ -301,7 +332,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
const auto lse_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
lse_acc_ptr,
|
||||
make_tuple(kargs.num_splits, kargs.seqlen_q),
|
||||
make_tuple(kargs.batch * kargs.nhead * kargs.max_seqlen_q, 1),
|
||||
make_tuple(kargs.split_stride_lse_acc, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -315,8 +346,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_acc_ptr,
|
||||
make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v),
|
||||
make_tuple(
|
||||
kargs.batch * kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v, kargs.hdim_v, 1),
|
||||
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentOacc>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -390,8 +420,8 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
|
||||
smem_ptr,
|
||||
kargs.num_splits,
|
||||
smem_ptr,
|
||||
kargs.seqlen_q,
|
||||
kargs.max_seqlen_q);
|
||||
}
|
||||
@@ -400,8 +430,8 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
return FmhaPipeline{}(lse_acc_dram_window,
|
||||
o_acc_dram_window,
|
||||
lse_dram_window,
|
||||
smem_ptr,
|
||||
kargs.num_splits,
|
||||
smem_ptr,
|
||||
kargs.seqlen_q,
|
||||
kargs.max_seqlen_q);
|
||||
}
|
||||
|
||||
@@ -82,8 +82,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp,
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const OaccElementFunction& o_acc_element_func,
|
||||
void* smem_ptr,
|
||||
index_t num_splits,
|
||||
void* smem_ptr,
|
||||
index_t real_seqlen_q,
|
||||
index_t max_seqlen_q) const
|
||||
{
|
||||
@@ -311,8 +311,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window,
|
||||
const OaccDramBlockWindow& o_acc_dram_block_window,
|
||||
LSEDramBlockWindow& lse_dram_block_window,
|
||||
void* smem_ptr,
|
||||
index_t num_splits,
|
||||
void* smem_ptr,
|
||||
index_t real_seqlen_q,
|
||||
index_t max_seqlen_q) const
|
||||
{
|
||||
@@ -321,8 +321,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
lse_dram_block_window,
|
||||
identity{},
|
||||
identity{},
|
||||
smem_ptr,
|
||||
num_splits,
|
||||
smem_ptr,
|
||||
real_seqlen_q,
|
||||
max_seqlen_q);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user