mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Change in fwd-splitkv kernel to support num_splits=1 case (#1690)
* Change in fwd-splitkv kernel to support num_splits=1 case
* Update in codegen fwd-splitkv to make num_splits > 1 cases pass
* Specify instance traits in dispatch
* Fix link error for fp8 kernels
---------
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
[ROCm/composable_kernel commit: ce2bdf42a9]
This commit is contained in:
@@ -247,12 +247,22 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
if (t.has_lse) {{
|
||||
if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{
|
||||
return -1;
|
||||
}} else {{
|
||||
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, true, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
|
||||
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
|
||||
}}
|
||||
}} else {{
|
||||
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, false, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
@@ -614,27 +624,26 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
|
||||
# TODO: use async pipeline when compiler is more stable
|
||||
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
|
||||
# if True:
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
|
||||
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
else:
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
|
||||
if receipt == 1:
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/paged-kv kernels
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask))
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
@@ -35,6 +35,7 @@ struct FmhaFwdSplitKVKernel
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
|
||||
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
|
||||
|
||||
@@ -234,8 +235,10 @@ struct FmhaFwdSplitKVKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* lse_acc_ptr,
|
||||
void* o_acc_ptr,
|
||||
void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
|
||||
final lse */
|
||||
void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
|
||||
o */
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
|
||||
@@ -356,8 +359,10 @@ struct FmhaFwdSplitKVKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* lse_acc_ptr,
|
||||
void* o_acc_ptr,
|
||||
void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
|
||||
final lse */
|
||||
void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
|
||||
o */
|
||||
ck_tile::index_t batch,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
@@ -591,9 +596,9 @@ struct FmhaFwdSplitKVKernel
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
|
||||
batch_offset_v;
|
||||
|
||||
OaccDataType* o_acc_ptr = reinterpret_cast<OaccDataType*>(kargs.o_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
|
||||
batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
|
||||
ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
|
||||
batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = [&]() {
|
||||
|
||||
@@ -25,6 +25,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
@@ -48,7 +49,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = true; // always store LSE (acc)
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kStoreLSE_,
|
||||
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kIsPagedKV_,
|
||||
bool kHasUnevenSplits_,
|
||||
|
||||
Reference in New Issue
Block a user