Pass LSE/O strides in kernel argument

This commit is contained in:
PoYen, Chen
2024-06-11 19:45:21 +00:00
parent df4fc8f26c
commit 9d1243e7fa
3 changed files with 79 additions and 35 deletions

View File

@@ -374,9 +374,16 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
args.hdim_v,
args.num_splits,
args.scale_o,
args.stride_o_acc,
args.stride_o,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o);
args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
else
{ // create batch mode kernel arguments
@@ -391,11 +398,18 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
args.hdim_v,
args.num_splits,
args.scale_o,
args.stride_o_acc,
args.stride_o,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_lse,
args.batch_stride_o);
args.batch_stride_o,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
}();

View File

@@ -85,8 +85,18 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t hdim_v;
ck_tile::index_t num_splits;
ck_tile::index_t row_stride_o_acc;
ck_tile::index_t row_stride_o;
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
};
struct CommonLSEKargs
@@ -132,11 +142,18 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits,
float scale_o,
ck_tile::index_t row_stride_o_acc,
ck_tile::index_t row_stride_o,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o)
ck_tile::index_t batch_stride_o,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc)
{
Kargs kargs{{lse_acc_ptr,
o_acc_ptr,
@@ -147,10 +164,17 @@ struct FmhaFwdSplitKVCombineKernel
seqlen_q,
hdim_v,
num_splits,
row_stride_o_acc,
row_stride_o,
nhead_stride_o}, // args for common karg
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
batch_stride_o};
if constexpr(kStoreLSE)
@@ -180,9 +204,16 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits,
float scale_o,
ck_tile::index_t row_stride_o_acc,
ck_tile::index_t row_stride_o,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o)
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc)
{
Kargs kargs{{lse_acc_ptr,
o_acc_ptr,
@@ -193,10 +224,17 @@ struct FmhaFwdSplitKVCombineKernel
-1, // seqlen will be updated by another pointer
hdim_v,
num_splits,
row_stride_o_acc,
row_stride_o,
nhead_stride_o}, // args for common karg
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
if constexpr(kStoreLSE)
@@ -239,20 +277,18 @@ struct FmhaFwdSplitKVCombineKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_o_acc = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
const long_index_t batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
batch_offset_o_acc = static_cast<long_index_t>(i_batch) *
(kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v);
if constexpr(kStoreLSE)
{
batch_offset_lse =
@@ -273,10 +309,6 @@ struct FmhaFwdSplitKVCombineKernel
}
else
{
batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
batch_offset_o_acc = static_cast<long_index_t>(i_batch) *
(kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v);
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
@@ -285,13 +317,12 @@ struct FmhaFwdSplitKVCombineKernel
}
// for simplicity, batch stride we just modify the pointer
const LSEDataType* lse_acc_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_acc_ptr) +
static_cast<long_index_t>(i_nhead) * (kargs.max_seqlen_q) +
batch_offset_lse_acc;
const LSEDataType* lse_acc_ptr =
reinterpret_cast<const LSEDataType*>(kargs.lse_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lse_acc + batch_offset_lse_acc;
const OaccDataType* o_acc_ptr =
reinterpret_cast<const OaccDataType*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(i_nhead) * (kargs.max_seqlen_q * kargs.hdim_v) +
batch_offset_o_acc;
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc + batch_offset_o_acc;
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;
@@ -301,7 +332,7 @@ struct FmhaFwdSplitKVCombineKernel
const auto lse_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
lse_acc_ptr,
make_tuple(kargs.num_splits, kargs.seqlen_q),
make_tuple(kargs.batch * kargs.nhead * kargs.max_seqlen_q, 1),
make_tuple(kargs.split_stride_lse_acc, 1),
number<8>{},
number<1>{});
@@ -315,8 +346,7 @@ struct FmhaFwdSplitKVCombineKernel
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v),
make_tuple(
kargs.batch * kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v, kargs.hdim_v, 1),
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
@@ -390,8 +420,8 @@ struct FmhaFwdSplitKVCombineKernel
lse_dram_window,
identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
smem_ptr,
kargs.num_splits,
smem_ptr,
kargs.seqlen_q,
kargs.max_seqlen_q);
}
@@ -400,8 +430,8 @@ struct FmhaFwdSplitKVCombineKernel
return FmhaPipeline{}(lse_acc_dram_window,
o_acc_dram_window,
lse_dram_window,
smem_ptr,
kargs.num_splits,
smem_ptr,
kargs.seqlen_q,
kargs.max_seqlen_q);
}

View File

@@ -82,8 +82,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
LSEDramBlockWindowTmp& lse_dram_window_tmp,
const LSEElementFunction& lse_element_func,
const OaccElementFunction& o_acc_element_func,
void* smem_ptr,
index_t num_splits,
void* smem_ptr,
index_t real_seqlen_q,
index_t max_seqlen_q) const
{
@@ -311,8 +311,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window,
const OaccDramBlockWindow& o_acc_dram_block_window,
LSEDramBlockWindow& lse_dram_block_window,
void* smem_ptr,
index_t num_splits,
void* smem_ptr,
index_t real_seqlen_q,
index_t max_seqlen_q) const
{
@@ -321,8 +321,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
lse_dram_block_window,
identity{},
identity{},
smem_ptr,
num_splits,
smem_ptr,
real_seqlen_q,
max_seqlen_q);
}