diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index f6a0e93262..faf98e95e0 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -21,6 +21,14 @@ from codegen.ops.fmha_fwd import ( ) +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + FMHA_FWD_SPLITKV_PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync", @@ -51,7 +59,6 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, {F_bias}, false, {F_lse}, - {F_dropout}, {F_squant}, {F_pagedkv}, kHasUnevenSplits, @@ -64,7 +71,6 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< typename FmhaFwdTypeConfig::SaccDataType, typename FmhaFwdTypeConfig::SMPLComputeDataType, typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, typename FmhaFwdTypeConfig::LSEDataType, typename FmhaFwdTypeConfig::PDataType, typename FmhaFwdTypeConfig::OaccDataType, @@ -99,7 +105,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) }} using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -227,9 +233,9 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const }} """ -FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; return fmha_fwd_splitkv_(s, a); @@ -237,12 +243,78 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F """ @dataclass -class FmhaFwdSplitKVApiTrait(FmhaFwdApiTrait): +class FmhaFwdSplitKVApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0blen : int + vlayout : str + mask : str + bias : str # + lse : str # + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str pagedkv : str @property def name(self) -> str: - return FmhaFwdApiTrait.name + f'-{self.pagedkv}' + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ + f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ + f'{self.dvpad}-{self.pagedkv}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {self.bk0blen} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {self.bk0blen} == 0' + else: assert False @dataclass class FmhaFwdSplitKVPipeline: @@ -255,7 +327,6 @@ class FmhaFwdSplitKVPipeline: F_dvpad : str # F_bias : str # true/false F_lse : str # - F_dropout : str # F_squant : str # F_pagedkv : str # t/f F_mask : str # value from MASK_MAP @@ -279,7 +350,6 @@ class FmhaFwdSplitKVPipeline: else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' if self.F_lse == 't' : n += '_lse' - if self.F_dropout == 't' : n += '_dropout' if self.F_squant == 't' : n += '_squant' if self.F_pagedkv == 't' : n += '_pagedkv' return n @@ -335,7 +405,7 @@ class FmhaFwdSplitKVApiPool: inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], + F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, @@ -396,7 +466,6 @@ class FmhaFwdSplitKVKernel: F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], F_squant = BOOL_MAP[self.F_pipeline.F_squant], F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], F_occupancy = self.F_tile.F_occupancy, @@ -431,7 +500,6 @@ class FmhaFwdSplitKVKernel: mask=self.F_pipeline.F_mask, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, - dropout=self.F_pipeline.F_dropout, squant=self.F_pipeline.F_squant, pagedkv=self.F_pipeline.F_pagedkv, spad=self.F_pipeline.F_spad, @@ -525,26 +593,25 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - # splitkv kernel donot support dropout - for mask, bias, lse, dropout, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"], ["t", "f"]): + for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): # TODO: use async pipeline when compiler is more stable if hdim == 256 or hdim in [32, 64, 128]: # if True: - pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) else: - pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) if receipt == 1: - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) # TODO: cover arbitraty hdim - pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: - # no need lse/dropout/paged-kv kernels + # no need lse/paged-kv kernels for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, 'f', mask)) else: diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index ba04c73342..263ea6d48b 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -504,6 +504,14 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cerr << "num_splits greater than 128 is not supported" << std::endl; return false; } +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(0 < p_drop && (1 < num_splits || 0 < page_block_size)) + { + std::cerr << "dropout is not supoprted in split-kv kernels. ignoring the option 'p_drop'" + << std::endl; + p_drop = 0.0f; + } +#endif auto get_lengths = [&](bool permute, ck_tile::index_t b /*batch*/, @@ -839,17 +847,6 @@ bool run(const ck_tile::ArgParser& arg_parser) } #endif - auto fmha_traits = fmha_fwd_traits{hdim_q, - hdim_v, - data_type, - mode == mode_enum::group, - is_v_rowmajor, - mask.type, - bias.type, - lse, - p_drop > 0.0f, - squant}; - auto p_compute_element_func = [&]() { if constexpr(std::is_same_v) return ck_tile::scales{scale_p}; @@ -991,9 +988,30 @@ bool run(const ck_tile::ArgParser& arg_parser) #if CK_TILE_FMHA_FWD_SPLITKV_API if(1 < num_splits || 0 < page_block_size) { - return fmha_fwd_splitkv(fmha_traits, fmha_args, stream_config); + auto fmha_splitkv_traits = fmha_fwd_splitkv_traits{hdim_q, + hdim_v, + data_type, + mode == mode_enum::group, + is_v_rowmajor, + mask.type, + bias.type, + lse, + squant}; + + return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_args, stream_config); } #endif + auto fmha_traits = fmha_fwd_traits{hdim_q, + hdim_v, + data_type, + mode == mode_enum::group, + is_v_rowmajor, + mask.type, + bias.type, + lse, + p_drop > 0.0f, + squant}; + return fmha_fwd(fmha_traits, fmha_args, stream_config); }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 54f426d14e..dcec1c68ce 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -314,7 +314,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.k_ptr, args.v_ptr, args.bias_ptr, - args.rand_val_ptr, args.lse_acc_ptr, args.o_acc_ptr, args.batch, @@ -336,13 +335,11 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.stride_k, args.stride_v, args.stride_bias, - args.stride_randval, args.stride_o_acc, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, - args.nhead_stride_randval, args.nhead_stride_lse_acc, args.nhead_stride_o_acc, args.batch_stride_k, @@ -353,10 +350,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.split_stride_o_acc, args.window_size_left, args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + args.mask_type); } else { // create batch mode kernel arguments @@ -364,7 +358,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.k_ptr, args.v_ptr, args.bias_ptr, - args.rand_val_ptr, args.lse_acc_ptr, args.o_acc_ptr, args.batch, @@ -385,30 +378,24 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.stride_k, args.stride_v, args.stride_bias, - args.stride_randval, args.stride_o_acc, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, - args.nhead_stride_randval, args.nhead_stride_lse_acc, args.nhead_stride_o_acc, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, args.batch_stride_bias, - args.batch_stride_randval, args.batch_stride_lse_acc, args.batch_stride_o_acc, args.split_stride_lse_acc, args.split_stride_o_acc, args.window_size_left, args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + args.mask_type); } }(); @@ -627,7 +614,6 @@ template ; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; - static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kPadS = kPadS_; static constexpr bool kPadSK = kPadSK_; @@ -746,8 +731,22 @@ struct fmha_fwd_traits }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); -using fmha_fwd_splitkv_traits = fmha_fwd_traits; -float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, fmha_fwd_splitkv_args, const ck_tile::stream_config&); +struct fmha_fwd_splitkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_lse; + bool do_fp8_static_quant; + // TODO: padding check is inside this api +}; +float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, + fmha_fwd_splitkv_args, + const ck_tile::stream_config&); struct fmha_fwd_appendkv_traits { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 5dd8781e09..e08c9248b0 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -32,8 +32,6 @@ struct FmhaFwdSplitKVKernel using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; using BiasDataType = ck_tile::remove_cvref_t; - using RandValOutputDataType = - ck_tile::remove_cvref_t; using LSEDataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; using OaccDataType = remove_cvref_t; @@ -46,7 +44,6 @@ struct FmhaFwdSplitKVKernel static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; - static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; using FmhaMask = ck_tile::remove_cvref_t; @@ -87,8 +84,7 @@ struct FmhaFwdSplitKVKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "") + (kDoFp8StaticQuant ? "_squant" : "") + - (kIsPagedKV ? "_pagedkv" : "" ); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" ); #undef _SS_ #undef _TS_ // clang-format on @@ -180,34 +176,6 @@ struct FmhaFwdSplitKVKernel float scale_p; }; - struct CommonDropoutKargs - { - void init_dropout(const float p_drop, - const std::tuple& drop_seed_offset) - { - float p_undrop = 1.0 - p_drop; - p_undrop_in_uint8_t = - uint8_t(std::floor(p_undrop * std::numeric_limits::max())); - rp_undrop = 1.0 / p_undrop; - - drop_seed = std::get<0>(drop_seed_offset); - drop_offset = std::get<1>(drop_seed_offset); - } - float rp_undrop = 1; - uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); - bool is_store_randval = false; - uint64_t drop_seed = 1; - uint64_t drop_offset = 0; - void* rand_val_ptr = nullptr; - - ck_tile::index_t stride_randval = 0; - ck_tile::index_t nhead_stride_randval = 0; - }; - struct BatchModeDropoutKargs : CommonDropoutKargs - { - ck_tile::index_t batch_stride_randval = 0; - }; - struct BatchModeKargs : CommonKargs, std::conditional_t>>, std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -232,8 +199,7 @@ struct FmhaFwdSplitKVKernel AlibiKargs, EmptyKargs<0>>>, std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -251,7 +217,6 @@ struct FmhaFwdSplitKVKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, - void* rand_val_ptr, void* lse_acc_ptr, void* o_acc_ptr, ck_tile::index_t batch, @@ -272,30 +237,24 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, - ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - const std::tuple& drop_seed_offset) + ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, k_ptr, @@ -335,7 +294,6 @@ struct FmhaFwdSplitKVKernel {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for fp8_static_quant args - {}, // placeholder for dropout batch_stride_q, batch_stride_k, batch_stride_v}; @@ -362,15 +320,6 @@ struct FmhaFwdSplitKVKernel { kargs.scale_p = scale_p; } - if constexpr(kHasDropout) - { - kargs.init_dropout(p_drop, drop_seed_offset); - kargs.rand_val_ptr = rand_val_ptr; - kargs.stride_randval = stride_randval; - kargs.nhead_stride_randval = nhead_stride_randval; - kargs.batch_stride_randval = batch_stride_randval; - kargs.is_store_randval = s_randval; - } return kargs; } @@ -381,7 +330,6 @@ struct FmhaFwdSplitKVKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, - void* rand_val_ptr, void* lse_acc_ptr, void* o_acc_ptr, ck_tile::index_t batch, @@ -403,13 +351,11 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_k, @@ -420,10 +366,7 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - const std::tuple& drop_seed_offset) + ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, k_ptr, @@ -463,7 +406,6 @@ struct FmhaFwdSplitKVKernel {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for fp8_static_quant args - {}, // placeholder for dropout reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr), @@ -491,14 +433,6 @@ struct FmhaFwdSplitKVKernel { kargs.scale_p = scale_p; } - if constexpr(kHasDropout) - { - kargs.init_dropout(p_drop, drop_seed_offset); - kargs.rand_val_ptr = rand_val_ptr; - kargs.stride_randval = stride_randval; - kargs.nhead_stride_randval = nhead_stride_randval; - kargs.is_store_randval = s_randval; - } return kargs; } @@ -531,11 +465,10 @@ struct FmhaFwdSplitKVKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_randval = 0; + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; const long_index_t batch_offset_lse_acc = static_cast(i_batch) * kargs.batch_stride_lse_acc; const long_index_t batch_offset_o_acc = @@ -561,10 +494,6 @@ struct FmhaFwdSplitKVKernel { batch_offset_bias = query_start * kargs.stride_bias + key_start; } - if constexpr(kHasDropout) - { - batch_offset_randval = query_start * kargs.stride_randval; - } // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; @@ -597,11 +526,6 @@ struct FmhaFwdSplitKVKernel { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; } - if constexpr(kHasDropout) - { - batch_offset_randval = - static_cast(i_batch) * kargs.batch_stride_randval; - } } auto k_tile_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { @@ -854,62 +778,6 @@ struct FmhaFwdSplitKVKernel return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0}); }(); - // dropout - float rp_undrop = 1; - uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); - uint64_t drop_seed = 0; - uint64_t drop_offset = 0; - bool is_store_randval = false; - - if constexpr(kHasDropout) - { - rp_undrop = kargs.rp_undrop; - p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; - drop_seed = kargs.drop_seed; - drop_offset = kargs.drop_offset; - is_store_randval = kargs.is_store_randval; - } - BlockDropout dropout(i_batch, - i_nhead, - kargs.num_head_q, - drop_seed, - drop_offset, - rp_undrop, - p_undrop_in_uint8_t, - is_store_randval); - - auto randval_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto randval_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(kHasDropout) - { - RandValOutputDataType* rand_val_ptr = - reinterpret_cast(kargs.rand_val_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_randval + - batch_offset_randval; - - const auto randval_dram = [&]() { - const auto randval_dram_naive = - make_naive_tensor_view( - rand_val_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_randval, 1), - number<1>{}, - number<1>{}); - - return pad_tensor_view(randval_dram_naive, - randval_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(randval_dram_window_lengths); - } - }(); - FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( @@ -966,7 +834,6 @@ struct FmhaFwdSplitKVKernel identity{}, // v_element_func bias_dram_window, identity{}, // bias_element_func - randval_dram_window, lse_acc_dram_window, identity{}, // lse_element_func identity{}, // s_acc_element_func @@ -978,7 +845,6 @@ struct FmhaFwdSplitKVKernel position_encoding, kargs.scale_s, smem_ptr, - dropout, k_tile_navigator, v_tile_navigator); } @@ -988,7 +854,6 @@ struct FmhaFwdSplitKVKernel k_dram_window, v_dram_window, bias_dram_window, - randval_dram_window, lse_acc_dram_window, kargs.num_splits, i_split_, @@ -996,7 +861,6 @@ struct FmhaFwdSplitKVKernel position_encoding, kargs.scale_s, smem_ptr, - dropout, k_tile_navigator, v_tile_navigator); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index d70e6322e6..450343b828 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -6,7 +6,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" -#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -15,19 +14,18 @@ namespace ck_tile { template struct BlockFmhaFwdSplitKVPipelineQRKSVS { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using RandValOutputDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -49,8 +47,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = true; // always store LSE (acc) - static constexpr bool kHasDropout = false; // ignore this flag + static constexpr bool kStoreLSE = true; // always store LSE (acc) static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; @@ -110,7 +107,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename RandValDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, @@ -132,7 +128,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, - RandValDramBlockWindowTmp& randval_dram_block_window_tmp, LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile const LSEaccElementFunction& lse_acc_element_func, const SAccElementFunction& s_acc_element_func, @@ -144,7 +139,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout, const KTileWindowNavigator& k_tile_navigator, const VTileWindowNavigator& v_tile_navigator) const { @@ -264,9 +258,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS {bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); - auto randval_dram_window = dropout.MakeRandvalDramWindow( - randval_dram_block_window_tmp, adjusted_seqlen_k_start); - auto [i_block1, v_dram_window] = v_tile_navigator.make_tile_window( v_dram_block_window_tmp, {0, adjusted_seqlen_k_start}, // TODO: hdim split? @@ -526,12 +517,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS }); }); - if constexpr(kHasDropout) - { - dropout.Run( - smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); - } - block_sync_lds(); if constexpr(std::is_same_v) { @@ -649,7 +634,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename RandValDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, typename PositionEncoding, typename KTileWindowNavigator, @@ -659,7 +643,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile index_t num_splits, index_t i_split, @@ -667,7 +650,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout, const KTileWindowNavigator& k_tile_navigator, const VTileWindowNavigator& v_tile_navigator) const { @@ -679,7 +661,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS identity{}, bias_dram_block_window_tmp, identity{}, - randval_dram_block_window_tmp, lse_acc_dram_block_window_tmp, identity{}, identity{}, @@ -691,7 +672,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS position_encoding, scale_s, smem_ptr, - dropout, k_tile_navigator, v_tile_navigator); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp index da9a973bba..d99bd058d1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp @@ -7,7 +7,6 @@ #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp" -#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -16,19 +15,18 @@ namespace ck_tile { template struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using RandValOutputDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -54,8 +52,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = true; // always store LSE (acc) - static constexpr bool kHasDropout = false; // ignore this flag + static constexpr bool kStoreLSE = true; // always store LSE (acc) static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; @@ -142,7 +139,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, - RandValDramBlockWindowTmp& randval_dram_block_window_tmp, LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile const LSEaccElementFunction& lse_acc_element_func, const SAccElementFunction& s_acc_element_func, @@ -153,8 +149,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr, - BlockDropout& dropout) const + void* smem_ptr) const { static_assert( std::is_same_v> && @@ -302,9 +297,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); - auto randval_dram_window = dropout.MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_k_start); - auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), @@ -585,17 +577,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync }); }); - if constexpr(kHasDropout) - { - auto randval_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - dropout.Run( - randval_ptr, - seqlen_k_start + i_total_loops * kN0, - p_compute, - randval_dram_window); - } - const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); @@ -726,7 +707,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename RandValDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto @@ -734,15 +714,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile index_t num_splits, index_t i_split, FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr, - BlockDropout& dropout) const + void* smem_ptr) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -752,7 +730,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync identity{}, bias_dram_block_window_tmp, identity{}, - randval_dram_block_window_tmp, lse_acc_dram_block_window_tmp, identity{}, identity{}, @@ -763,8 +740,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync mask, position_encoding, scale_s, - smem_ptr, - dropout); + smem_ptr); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 95d2898355..a2c9faea88 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -54,39 +54,50 @@ struct BlockFmhaPipelineProblem static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; -template -struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem +template +struct BlockFmhaFwdSplitKVPipelineProblem { - static constexpr bool kIsPagedKV = Traits::kIsPagedKV; - static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr auto BiasEnum = Traits::BiasEnum; + static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; + static constexpr bool kIsPagedKV = Traits::kIsPagedKV; + static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout() + CK_TILE_HOST_DEVICE static constexpr std:: + enable_if_t, ck_tile::index_t> + GetSmemSizeDropout() { if constexpr(Problem::kHasDropout) { @@ -736,6 +739,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout(...) + { + return 0; + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() { diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index ed2b9ca564..9a9196f273 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -33,32 +33,31 @@ struct TileFmhaTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; }; -template -struct TileFmhaFwdSplitKVTraits : TileFmhaTraits + index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */> +struct TileFmhaFwdSplitKVTraits { - static constexpr bool kIsPagedKV = kIsPagedKV_; + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadSeqLenK = kPadSeqLenK_; + static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kIsPagedKV = kIsPagedKV_; // determine if some split (length) is not divisible by tile size static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; }; template