From 19ef22e5673b7896b267f8abd8dba89bccd9e4f8 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Tue, 12 Aug 2025 17:02:52 +0800 Subject: [PATCH] [CK_TILE] FMHA BWD Decode Pipeline (#2643) * Fix distr * Duplicate block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr * decode 16x16 o2 [ROCm/composable_kernel commit: 8e1eb0c1ee36cad0292c960fc346625a0d82a167] --- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 4 +- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 207 +++-- example/ck_tile/01_fmha/fmha_bwd.cpp | 22 + example/ck_tile/01_fmha/fmha_bwd.hpp | 36 +- .../ops/epilogue/default_2d_epilogue.hpp | 4 +- include/ck_tile/ops/fmha.hpp | 1 + .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 119 ++- ...ck_fmha_bwd_dq_dk_dv_pipeline_selector.hpp | 6 +- ...wd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | 743 ++++++++++++++++++ ...mha_bwd_pipeline_trload_default_policy.hpp | 65 +- .../ops/fmha/pipeline/tile_fmha_shape.hpp | 9 +- 11 files changed, 1051 insertions(+), 165 deletions(-) create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 9e15a822ef..6fca800c90 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -127,5 +127,7 @@ PIPELINE_ENUM_MAP = { BOOL_MAP = { "t" : "true", - "f" : "false" + "f" : "false", + True : "true", + False : "false", } diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 8ca917cb6c..bb3a0587e7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -7,7 +7,7 @@ from dataclasses import dataclass import fnmatch import itertools from pathlib import Path -from typing import List, Optional, Tuple, Dict, Literal +from typing import List, Tuple, Dict, Literal, Any from collections import defaultdict from 
codegen.cmake_config import * @@ -31,6 +31,7 @@ using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>; using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>; +using fmha_warp_tile2_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, ck_tile::min({F_wk0}, {F_bk4})>; // TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape // G0&G2 -> GSdP @@ -46,7 +47,8 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; + fmha_warp_tile2_{F_idx}, + {F_maxq}>; using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits>; +using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType, + false, + {F_dpad}>>; + using fmha_bwd_dq_dk_dv_kernel_{F_idx} = ck_tile::FmhaBwdDQDKDVKernel; + fmha_bwd_dv_epilogue_{F_idx}, + fmha_bwd_dq_epilogue_{F_idx}>; using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, @@ -115,7 +124,8 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dpad}, {F_dvpad}, {F_deterministic}, - {F_trload}>; + {F_trload}, + {F_maxq}>; #include @@ -144,6 +154,13 @@ void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_co ck_tile::stream_config{{s.stream_id_}}); }} +template <> +int fmha_bwd_dq_dk_dv_maxq_() +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::kMaxSeqLenQ; +}} + template <> std::string fmha_bwd_dq_dk_dv_get_name_() {{ @@ -159,13 +176,25 @@ FMHA_BWD_API=""" template float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ - if(s.log_level_ > 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_convert_dq_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; - return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ 
fmha_bwd_dot_do_o_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} - ); + if constexpr (!std::is_same_v) + {{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_convert_dq_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} + ); + }} + else + {{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} + ); + }} }} template <> @@ -177,28 +206,25 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf }} """ -FMHA_BWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ -{F_body} - }} -""" +def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) -> str: + lines = [ + f"{'if' if if_ == 0 else 'else if'}({F_cond})", + "{", + *[' ' + line for line in F_body.split('\n') if line.strip() != ''], + "}", + ] + return '\n'.join(' ' * indent + line for line in lines) + '\n' -FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_body} - }} -""" -FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ -{F_body} - }} -""" -FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && - ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && 
(t.is_deterministic == {F_deterministic})) {{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>; - using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>; - r = fmha_bwd_(s, a); - return r; - }} +FMHA_BWD_API_INNER_DISPATCH=""" +{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && + ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>; + r = fmha_bwd_>(s, a); + return r; +}} """ # M0 size for 1d kernels (dot/convert) @@ -237,11 +263,13 @@ class FmhaBwdDQDKDVTileSize: F_wn1 : int # warp size along n in gemm1/gemm3 F_wk1 : int # warp size along k in gemm1/gemm3 F_occupancy : int # occupancy + max_seq_q : int = 0 + @property def name(self) -> str: return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ - f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}" + 
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}_maxq{self.max_seq_q}" @dataclass(frozen=True) class FmhaBwdDQDKDVKernel: @@ -301,6 +329,7 @@ class FmhaBwdDQDKDVKernel: F_mode = MODE_MAP[self.F_mode], F_deterministic = BOOL_MAP[self.F_deterministic], F_trload = BOOL_MAP[self.F_trload], + F_maxq = self.F_tile.max_seq_q ) @property @@ -345,21 +374,23 @@ class FmhaBwdDQDKDVKernel: # TODO: design a more practical way to do it # this is current supported tile size. -def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str, tr_load: str) -> Optional[dict]: +def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': - return { - '32' : FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - '64' : FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - '128' : FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - # '160' : FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - '256' : FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - } + return [ + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + # FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 
32, 16, 16, 16, 1), + ] elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't': - return { - '128' : FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), - } + return [ + FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), + # FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16), + FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16), + ] else: - return None + return [] FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; @@ -537,6 +568,7 @@ class FmhaBwdConvertQGradKernel: F_mode : str # value from MODE_MAP F_occupancy : int # F_deterministic : str # + disabled : bool # sometimes this kernel is not used @property def template(self) -> str: @@ -590,7 +622,7 @@ class FmhaBwdApiTrait: dvpad : str deterministic : str mask_impl : str - tr_load : bool + tr_load : str @property def bm0(self) -> int: @@ -650,17 +682,17 @@ class FmhaBwdApiTrait: return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), - F_deterministic=self.deterministic) + F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0) class FmhaBwdApiPool: def __init__(self, mask_impl): - self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))) self.mask_impl = mask_impl def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None: # TODO: do we need to check duplication? 
- self.dq_dk_dv_pool[trait.tr_load][trait.dtype][trait.hdim].append(copy.copy(trait)) + self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][trait.hdim].append(copy.copy(trait)) @staticmethod def if_(i: int) -> str: @@ -675,40 +707,68 @@ class FmhaBwdApiPool: F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load]) + F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, + F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled]) i += 1 return inners + @staticmethod + def trload_sort_key(tf): + return 0 if tf == 't' else 1 # sort 't' before 'f' + + @staticmethod + def max_seq_q_sort_key(max_seq_q): + return max_seq_q if max_seq_q != 0 else 1000000 # sort 0 to the end + + @staticmethod + def max_seq_q_cond(max_seq_q: int) -> str: + if max_seq_q == 0: + return 'true /* no seqlen_q limit */' + else: + return f'a.seqlen_q <= {max_seq_q}' + + @staticmethod + def dtype_cond(dtype: str) -> str: + return f't.data_type.compare("{dtype}") == 0' + + @staticmethod + def hdim_cond(hdim: int) -> str: + return f't.hdim_q <= {hdim} && t.hdim_v <= {hdim}' + @property def api(self) -> str: tr_load_cond_map = { "t": "has_load_tr", - "f": "true" + "f": "true /* no trload requirement */" } per_tr_load = '' - for tr_load in ["t", "f"]: - per_dtypes = '' - for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load]): - per_hdim_case = '' - for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][dtype]): - traits = self.dq_dk_dv_pool[tr_load][dtype][hdim] - inners = self._api_innders(traits) - per_hdim_case = per_hdim_case + 
FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(k), F_hdim=hdim, F_body=inners) - per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(j), F_dtype=dtype, F_body=per_hdim_case) - per_tr_load += FMHA_BWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_body=per_dtypes) + for tr_load in sorted(self.dq_dk_dv_pool.keys(), key=self.trload_sort_key): + per_max_seq_q = '' + for max_seq_q in sorted(self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key): + per_dtypes = '' + for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q]): + per_hdim_case = '' + for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q][dtype]): + traits = self.dq_dk_dv_pool[tr_load][max_seq_q][dtype][hdim] + inners = self._api_innders(traits) + per_hdim_case += FMHA_BWD_API_COND_STATEMENT(if_=k, F_cond=self.hdim_cond(hdim), F_body=inners) + per_dtypes += FMHA_BWD_API_COND_STATEMENT(if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case) + per_max_seq_q += FMHA_BWD_API_COND_STATEMENT(F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes) + per_tr_load += FMHA_BWD_API_COND_STATEMENT(F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4) if not per_tr_load: # empty string we add some ignore to suppress warning in api per_tr_load += ' (void)t ; (void)s ; (void)a;' - return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load) + result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load) + return result.replace('\n\n', '\n') def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]: if filter_list == '': filter_list = '*@*@*' - filter_list = filter_list.split('@') - filter_list.extend(['*'] * (3 - len(filter_list))) - filter_dot_do_o = filter_list[0] - filter_convert_dq = filter_list[1] - filter_dq_dk_dv = filter_list[2] + filters = filter_list.split('@') + 
filters.extend(['*'] * (3 - len(filters))) + filter_dot_do_o = filters[0] + filter_convert_dq = filters[1] + filter_dq_dk_dv = filters[2] # use dict as ordered set gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {} @@ -717,14 +777,14 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm api_pool = FmhaBwdApiPool(mask_impl) for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype, tr_load) - if d is None: - continue - for hdim_str, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)): - tile = d[hdim_str] - hdim = int(hdim_str) + tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load) + for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)): + assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize" + hdim = tile.F_bhdq if (mode == "group") and (spad1d == "f"): continue + if (mode == "group" or ('no' not in mask)) and tile.max_seq_q != 0: + continue if ((bias == "no" or bias == "alibi") and dbias == "t"): continue if ("wg32" in dropout): @@ -788,7 +848,8 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm continue gen_dot_do_o[t.dot_do_o_kernel] = True gen_dq_dk_dv[t.dq_dk_dv_kernel] = True - gen_convert_dq[t.convert_dq_kernel] = True + if not t.convert_dq_kernel.disabled: + gen_convert_dq[t.convert_dq_kernel] = True api_pool.register_dq_dk_dv_traits(t) return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys()) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index b6de5ea621..9c2907778f 
100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -793,6 +793,14 @@ bool run(const ck_tile::ArgParser& arg_parser) } } + // set to bad values to check if the kernel writes to these buffers + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_host); + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dk_host); + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dv_host); + dq_buf.ToDevice(dq_host.data()); + dk_buf.ToDevice(dk_host.data()); + dv_buf.ToDevice(dv_host.data()); + o_buf.ToDevice(o_host.data()); lse_buf.ToDevice(lse_host.data()); dq_buf.SetZero(); @@ -801,6 +809,20 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::stream_config stream_config_v{ nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; + + printf("\nfmha_bwd_traits: hdim_q=%d, hdim_v=%d, data_type=%s, is_group_mode=%d, mask_type=%d, " + "bias_type=%d, has_dbias=%d, has_dropout=%d, is_store_randval=%d, is_deterministic=%d\n", + fmha_traits.hdim_q, + fmha_traits.hdim_v, + fmha_traits.data_type.c_str(), + fmha_traits.is_group_mode, + static_cast(fmha_traits.mask_type), + static_cast(fmha_traits.bias_type), + fmha_traits.has_dbias, + fmha_traits.has_dropout, + fmha_traits.is_store_randval, + fmha_traits.is_deterministic); + fflush(stdout); fmha_bwd(fmha_traits, fmha_args, stream_config_v); dq_buf.FromDevice(dq_host.data()); diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index bd63c96eb1..8d35b2d12c 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -156,6 +156,12 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { + constexpr bool dq_uss_acc = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0; + const auto dq_ptr = dq_uss_acc ? args.dq_acc_ptr : args.dq_ptr; + const auto stride_dq = dq_uss_acc ? 
args.stride_dq_acc : args.stride_dq; + const auto nhead_stride_dq = dq_uss_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq; + const auto batch_stride_dq = dq_uss_acc ? args.batch_stride_dq_acc : args.batch_stride_dq; + // create group mode kernel arguments if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) { @@ -170,7 +176,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.dk_ptr, args.dv_ptr, args.dbias_ptr, - args.dq_acc_ptr, + dq_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, args.seqlen_k_ptr, @@ -185,7 +191,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.stride_bias, args.stride_randval, args.stride_do, - args.stride_dq_acc, + stride_dq, args.stride_dk, args.stride_dv, args.stride_dbias, @@ -196,7 +202,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.nhead_stride_randval, args.nhead_stride_do, args.nhead_stride_lsed, - args.nhead_stride_dq_acc, + nhead_stride_dq, args.nhead_stride_dk, args.nhead_stride_dv, args.nhead_stride_dbias, @@ -220,7 +226,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.dk_ptr, args.dv_ptr, args.dbias_ptr, - args.dq_acc_ptr, + dq_ptr, args.seqlen_q, args.seqlen_k, args.hdim_q, @@ -234,7 +240,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.stride_bias, args.stride_randval, args.stride_do, - args.stride_dq_acc, + stride_dq, args.stride_dk, args.stride_dv, args.stride_dbias, @@ -245,7 +251,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.nhead_stride_randval, args.nhead_stride_do, args.nhead_stride_lsed, - args.nhead_stride_dq_acc, + nhead_stride_dq, args.nhead_stride_dk, args.nhead_stride_dv, args.nhead_stride_dbias, @@ -256,7 +262,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.batch_stride_randval, args.batch_stride_do, args.batch_stride_lsed, - args.batch_stride_dq_acc, + batch_stride_dq, args.batch_stride_dk, args.batch_stride_dv, 
args.batch_stride_dbias, @@ -365,20 +371,10 @@ template + bool kUseTrLoad_, + ck_tile::index_t MaxSeqLenQ_> struct fmha_bwd_dq_dk_dv_traits_ { - static constexpr ck_tile::index_t HDim = HDim_; - using DataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; - using FmhaMask = ck_tile::remove_cvref_t; - using FmhaDropout = ck_tile::remove_cvref_t; - static constexpr auto BiasEnum = BiasEnum_; - static constexpr bool kHasBiasGrad = kHasBiasGrad_; - static constexpr bool kPadD = kPadD_; - static constexpr bool kPadDv = kPadDv_; - static constexpr bool kIsDeterministic = kIsDeterministic_; - static constexpr bool kUseTrLoad = kUseTrLoad_; }; template @@ -389,6 +385,8 @@ void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); template std::string fmha_bwd_dq_dk_dv_get_name_(); +template +int fmha_bwd_dq_dk_dv_maxq_(); template struct fmha_bwd_dot_do_o_traits_ diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index ff41ac0d61..fdbe2e7a6d 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -73,7 +73,7 @@ struct Default2DEpilogue // how do we fix this ? 
template CK_TILE_DEVICE auto - operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) + operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) const { // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) @@ -105,7 +105,7 @@ struct Default2DEpilogue CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, const DsDramWindows& /* unused */, - void* = nullptr) + void* = nullptr) const { return operator()(o_dram_window_tmp, o_acc_tile); } diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 276ec4852f..d8dd5db12e 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -26,6 +26,7 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 595e2cfccf..8750c8b377 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp" #include #include @@ -26,14 +27,22 @@ namespace ck_tile { -template +template struct FmhaBwdDQDKDVKernel { using FmhaPipeline = 
ck_tile::remove_cvref_t; using KGradEpiloguePipeline = ck_tile::remove_cvref_t; using VGradEpiloguePipeline = ck_tile::remove_cvref_t; + using QGradEpiloguePipeline = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static constexpr bool kUseQrQtrDorPipeline = + ck_tile::fmha_bwd_qr_qtr_dor_pipeline_c; + static_assert(!kUseQrQtrDorPipeline || !std::is_same_v, + "QrQtrDorPipeline needs QGradEpiloguePipeline"); using QDataType = ck_tile::remove_cvref_t; using KDataType = ck_tile::remove_cvref_t; @@ -63,6 +72,8 @@ struct FmhaBwdDQDKDVKernel static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval; static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic; static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad; + static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ; + static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0)); #if defined(__gfx950__) static constexpr bool kIsAvialable = true; #else @@ -128,7 +139,7 @@ struct FmhaBwdDQDKDVKernel const void* lse_ptr; const void* do_ptr; const void* d_ptr; - void* dq_acc_ptr; + void* dq_acc_ptr; // can be dq_ptr for qrqtrdor pipeline void* dk_ptr; void* dv_ptr; @@ -335,7 +346,7 @@ struct FmhaBwdDQDKDVKernel void* dk_ptr, void* dv_ptr, void* dbias_ptr, - void* dq_acc_ptr, + void* dq_acc_ptr, // can be dq_acc_ptr for qrqtrdor pipeline ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, @@ -482,7 +493,7 @@ struct FmhaBwdDQDKDVKernel } } - if constexpr(kIsDeterministic) + if constexpr(kIsDeterministic && !kUseQrQtrDorPipeline) { kargs.split_stride_dq_acc = split_stride_dq_acc; } @@ -640,7 +651,9 @@ struct FmhaBwdDQDKDVKernel GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) { return dim3( - ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), nhead_, batch_size_); + 
kUseQrQtrDorPipeline ? 1 : ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), + nhead_, + batch_size_); } CK_TILE_DEVICE static constexpr auto GetTileIndex() @@ -735,10 +748,9 @@ struct FmhaBwdDQDKDVKernel // # of required blocks is different in each groups, terminate unnecessary blocks // earlier - if(kargs.seqlen_k <= i_n0) - { - return; - } + if constexpr(!kUseQrQtrDorPipeline) + if(kargs.seqlen_k <= i_n0) + return; } else { @@ -786,12 +798,10 @@ struct FmhaBwdDQDKDVKernel const OGradDataType* do_ptr = reinterpret_cast(kargs.do_ptr) + static_cast(i_nhead) * kargs.nhead_stride_do + batch_offset_do; - KGradDataType* dk_ptr = reinterpret_cast(kargs.dk_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_dk + - batch_offset_dk; - VGradDataType* dv_ptr = reinterpret_cast(kargs.dv_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_dv + - batch_offset_dv; + auto dk_ptr = reinterpret_cast(kargs.dk_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_dk + batch_offset_dk; + auto dv_ptr = reinterpret_cast(kargs.dv_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_dv + batch_offset_dv; // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window const auto q_dram_naive = make_naive_tensor_view( @@ -868,8 +878,11 @@ struct FmhaBwdDQDKDVKernel {0, 0}); auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() { - AccDataType* dq_acc_ptr = reinterpret_cast(kargs.dq_acc_ptr) + [&]() { - if constexpr(kIsDeterministic) + constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic; + using DType = std::conditional_t; + + auto dq_acc_ptr = reinterpret_cast(kargs.dq_acc_ptr) + [&]() { + if constexpr(kUseKSplit) return static_cast(i_nhead_) * kargs.nhead_stride_dq_acc + static_cast(i_tile_n_) * kargs.split_stride_dq_acc + batch_offset_dq_acc; @@ -878,7 +891,7 @@ struct FmhaBwdDQDKDVKernel batch_offset_dq_acc; }(); - constexpr auto DstInMemOp = conditional_expr( + constexpr auto DstInMemOp = conditional_expr( memory_operation_enum::set, 
memory_operation_enum::atomic_add); const auto dq_acc_dram_naive = make_naive_tensor_view( @@ -1063,25 +1076,6 @@ struct FmhaBwdDQDKDVKernel return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); - auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - randval_dram_window, - do_dram_window, - lse_dram_window, - d_dram_window, - dq_dram_window, - dbias_dram_window, - mask, - position_encoding, - kargs.raw_scale, - kargs.scale, - rp_undrop, - scale_rp_undrop, - smem_ptr, - dropout); - auto dk_dram = [&]() { const auto dk_dram_naive = make_naive_tensor_view( dk_ptr, @@ -1119,9 +1113,56 @@ struct FmhaBwdDQDKDVKernel dv_dram, make_tuple(number{}, number{}), {i_n0, 0}); + if constexpr(!kUseQrQtrDorPipeline) + { + auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + do_dram_window, + lse_dram_window, + d_dram_window, + dq_dram_window, + dbias_dram_window, + mask, + position_encoding, + kargs.raw_scale, + kargs.scale, + rp_undrop, + scale_rp_undrop, + smem_ptr, + dropout); - KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile); - VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile); + KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile); + VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile); + } + else + { + FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + do_dram_window, + lse_dram_window, + d_dram_window, + dq_dram_window, + dk_dram_window, + dv_dram_window, + dbias_dram_window, + QGradEpiloguePipeline{}, + KGradEpiloguePipeline{}, + VGradEpiloguePipeline{}, + mask, + position_encoding, + kargs.raw_scale, + kargs.scale, + rp_undrop, + scale_rp_undrop, + smem_ptr, + dropout); + } } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp 
index bf38c3c07d..c3e84df934 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp" namespace ck_tile { @@ -14,12 +15,15 @@ template class BlockFmhaBwdDQDKDVPipelineSelector { static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV; + static constexpr bool is_decode = Problem::BlockFmhaShape::kMaxSeqLenQ > 0; public: template using type_ = std::conditional_t, + std::conditional_t, + BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR>, std::conditional_t, BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP>>; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp new file mode 100644 index 0000000000..65f70c4f62 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -0,0 +1,743 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR +{ + static constexpr auto is_qr_qtr_dor_pipeline = true; + + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using FmhaDropout = remove_cvref_t; + // using HotLoopScheduler = typename Policy::template HotLoopScheduler; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + 
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + static constexpr bool kUseTrLoad = Problem::kUseTrLoad; + static_assert(kUseTrLoad, "This pipeline uses trload!"); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = 1; + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = 1; + + static constexpr const char* name = "trload_qr_qtr_dor"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static LSEDataType get_validated_lse(const LSEDataType raw_lse) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || FmhaMask::IsMasking) + return (raw_lse == -numeric::infinity()) // + ? 
type_convert(0.f) + : raw_lse; + else + return raw_lse; + }; + + template + CK_TILE_DEVICE auto operator()( // + const QDramBlockWindowTmp& q_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const KGradDramBlockWindowTmp& dk_dram_block_window_tmp, + const VGradDramBlockWindowTmp& dv_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + const QGradEpilogue& dq_epilogue, + const KGradEpilogue& dk_epilogue, + const VGradEpilogue& dv_epilogue, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, + float scale, + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + FmhaDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == 
BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + const auto q_origin = q_dram_block_window_tmp.get_window_origin(); + + // Early termination + const auto [seqlen_kv_start, seqlen_kv_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_kv_end - seqlen_kv_start, kN0); + + // K, HBM ->LDS ->Reg + auto k_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + k_dram_block_window_tmp.get_bottom_tensor_view()), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_kv_start, 0}, + Policy::template MakeKDramTileDistribution()); + + // LDS allocation + const auto smem_ptr_ = + reinterpret_cast(smem_ptr); // cast to char* to do pointer arithmetic + + const auto k_lds_ptr = reinterpret_cast(smem_ptr_); + const auto v_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeK()); + + const auto do_lds_ptr = reinterpret_cast(smem_ptr_); + const auto q_lds_ptr = reinterpret_cast( // + smem_ptr_ + Policy::template GetSmemSizeOGrad()); + const auto lse_lds_ptr = reinterpret_cast( // + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ()); + const auto d_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeLSE()); + + const auto ds_lds_ptr = + reinterpret_cast(smem_ptr_ + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeV()); + const auto bias_lds_ptr = reinterpret_cast(ds_lds_ptr); + + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template 
MakeKLdsWriteBlockDescriptor()); + auto k_lds_write_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + //------------------------------------------------------------------ + // V, HBM ->LDS ->Reg + auto v_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + v_dram_block_window_tmp.get_bottom_tensor_view()), + v_dram_block_window_tmp.get_window_lengths(), + {seqlen_kv_start, 0}, + Policy::template MakeVDramTileDistribution()); + auto v_lds = make_tensor_view( + v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); + auto v_lds_write_window = + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + + //------------------------------------------------------------------ + // KT, HBM -> LDS --trload-->Reg + + //------------------------------------------------------------------ + // Pre-Load KV into Registers + auto k_lds_read = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsReadBlockDescriptor()); + auto k_lds_read_window = + make_tile_window(k_lds_read, + make_tuple(number{}, number{}), + k_lds_write_window.get_window_origin(), + Policy::template MakeKRegBlockDescriptor()); + + auto kt_lds_read_window = + make_tile_window(k_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeKTRegBlockDescriptor()); + + auto v_lds_read = make_tensor_view( + v_lds_ptr, Policy::template MakeVLdsReadBlockDescriptor()); + auto v_lds_read_window = + make_tile_window(v_lds_read, + make_tuple(number{}, number{}), + v_lds_write_window.get_window_origin(), + Policy::template MakeVRegBlockDescriptor()); + + //---------------------------- Loop Load in ----------------------------// + // Q: HBM -->LDS + auto q_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + q_dram_block_window_tmp.get_bottom_tensor_view()), + q_dram_block_window_tmp.get_window_lengths(), + {0, 0}, + Policy::template MakeQDramTileDistribution()); + + auto q_lds = make_tensor_view( + q_lds_ptr, 
Policy::template MakeQLdsWriteBlockDescriptor()); + auto q_lds_write_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + auto q_lds_read = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsReadBlockDescriptor()); + auto q_lds_read_window = + make_tile_window(q_lds_read, + make_tuple(number{}, number{}), + q_lds_write_window.get_window_origin(), + Policy::template MakeQRegSliceBlockDescriptor()); + auto qt_lds_read_window = + make_tile_window(q_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeQTRegSliceBlockDescriptor()); + + // dO: HBM ->LDS ---load--> Reg + // dOT: \-loadtr-> Reg + auto do_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + do_dram_block_window_tmp.get_bottom_tensor_view()), + do_dram_block_window_tmp.get_window_lengths(), + {0, 0}, + Policy::template MakeOGradDramTileDistribution()); + + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsWriteBlockDescriptor()); + auto do_lds_write_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + auto do_lds_read = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsReadBlockDescriptor()); + auto do_lds_read_window = + make_tile_window(do_lds_read, + make_tuple(number{}, number{}), + do_lds_write_window.get_window_origin(), + Policy::template MakeOGradRegSliceBlockDescriptor()); + auto dot_lds_read_window = + make_tile_window(do_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeOGradTRegSliceBlockDescriptor()); + + // dS: Reg -> Reg -> LDS + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // transform it to make it from col-major to row-major; prepared for load_tile_transpose + auto ds_lds_t = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto 
ds_lds_read_window = + make_tile_window(ds_lds_t, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeSGradRegSliceBlockDescriptor()); + + // Bias: HBM ->Reg ->Reg ->LDS + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + + auto bias_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + bias_dram_block_window_tmp.get_bottom_tensor_view()), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_kv_start}, + Policy::template MakeBiasTileDistribution()); + + auto bias_lds = make_tensor_view( + bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor()); + auto bias_lds_write_window = + make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); + + auto bias_lds_read = make_tensor_view( + bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor()); + auto bias_s_lds_read_window = + make_tile_window(bias_lds_read, + make_tuple(number{}, number{}), + bias_lds_write_window.get_window_origin(), + Policy::template MakeBiasSTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // LSE: HBM -> LDS ->Reg + auto lse_dram_window = make_tile_window( + lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {0}, + Policy::template MakeLSEDDramTileDistribution()); + + auto lse_lds = make_tensor_view( + lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + + auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number{}), {0}); + + auto lse_lds_read_window = make_tile_window( + lse_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); + + // D: HBM ->Reg + auto d_dram_window = make_tile_window( + d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {0}, + Policy::template MakeLSEDDramTileDistribution()); + + auto d_lds = make_tensor_view( + 
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number{}), {0}); + auto d_lds_read_window = make_tile_window( + d_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); + + // RandVal: HBM ->Reg + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_kv_start); + + // BiasGrad + // Reg ->LDS ->Reg ->HBM + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + + auto dbias_dram_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {dbias_origin.at(number<0>{}), seqlen_kv_start}); // M/N + + auto dbias_lds_read_window = + make_tile_window(bias_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + // ----------------------------Loop write out------------------------------// + auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {0, 0}); + auto dk_dram_window = make_tile_window(dk_dram_block_window_tmp.get_bottom_tensor_view(), + dk_dram_block_window_tmp.get_window_lengths(), + {0, 0}); + auto dv_dram_window = make_tile_window(dv_dram_block_window_tmp.get_bottom_tensor_view(), + dv_dram_block_window_tmp.get_window_lengths(), + {0, 0}); + + index_t i_total_loops = 0; + index_t seqlen_kv_step = seqlen_kv_start; + static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0"); + static_assert(kM0 == kK1, "kM0 should equal to kK1"); + static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2"); + static_assert(kM0 == kK3, "kM0 should equal to kK3"); + constexpr index_t k4_loops = kN0 / kK4; + + __builtin_amdgcn_sched_barrier(0); + + decltype(load_tile(q_lds_read_window)) q_reg_tensor; + decltype(load_tile(lse_lds_read_window)) 
lse; + decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor; + decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor_next; + decltype(load_tile(do_lds_read_window)) do_reg_tensor; + decltype(load_tile_transpose(dot_lds_read_window)) dot_reg_tensor; + decltype(load_tile(d_lds_read_window)) d; + decltype(load_tile_transpose(qt_lds_read_window)) qt_reg_tensor; + decltype(gemm_0.MakeCBlockTile()) s_acc, p; + decltype(gemm_2.MakeCBlockTile()) dp_acc, ds; + decltype(gemm_4.MakeCBlockTile()) dq_acc; + clear_tile(dq_acc); + + decltype(load_tile(lse_dram_window)) lse_block_tile; + decltype(load_tile(d_dram_window)) d_block_tile; + + async_load_tile(q_lds_write_window, q_dram_window); + async_load_tile(do_lds_write_window, do_dram_window); + __builtin_amdgcn_s_waitcnt(0); + qt_reg_tensor = load_tile_transpose(qt_lds_read_window); + q_reg_tensor = load_tile(q_lds_read_window); + dot_reg_tensor = load_tile_transpose(dot_lds_read_window); + do_reg_tensor = load_tile(do_lds_read_window); + + lse_block_tile = load_tile(lse_dram_window); + d_block_tile = load_tile(d_dram_window); + __builtin_amdgcn_s_waitcnt(0); + store_tile(lse_lds_write_window, lse_block_tile); + store_tile(d_lds_write_window, d_block_tile); + __builtin_amdgcn_s_waitcnt(0); + lse = load_tile(lse_lds_read_window); + d = load_tile(d_lds_read_window); + + auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable { + constexpr bool is_prologue = is_prologue_.value; + constexpr bool is_epilogue = is_epilogue_.value; + static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true"); + constexpr bool is_main_body = is_prologue && is_epilogue; + + // init VGrad & KGrad + decltype(gemm_1.MakeCBlockTile()) dv_acc; + decltype(gemm_3.MakeCBlockTile()) dk_acc; + + decltype(load_tile(k_lds_read_window)) k_reg_tensor; + decltype(load_tile(v_lds_read_window)) v_reg_tensor; + decltype(load_tile_transpose(kt_lds_read_window)) kt_reg_tensor; + + if constexpr(is_epilogue) + { + 
async_load_tile(k_lds_write_window, k_dram_window); + move_tile_window(k_dram_window, {kN0, 0}); + async_load_tile(v_lds_write_window, v_dram_window); + move_tile_window(v_dram_window, {kN0, 0}); + // __builtin_amdgcn_s_waitcnt(0); + k_reg_tensor = load_tile(k_lds_read_window); + v_reg_tensor = load_tile(v_lds_read_window); + kt_reg_tensor = load_tile_transpose(kt_lds_read_window); + } + if constexpr(is_epilogue) + { + // STAGE 1, Q@K Gemm0 + s_acc = gemm_0(q_reg_tensor, k_reg_tensor); + } + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm0(); + __builtin_amdgcn_sched_barrier(0); + if constexpr(is_epilogue) + { + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + async_load_tile(bias_lds_write_window, bias_dram_window); + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + auto bias_s_tile = load_tile(bias_s_lds_read_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { + x = scale * x + log2e_v * type_convert(y); + }, + s_acc, + bias_s_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = tile_idx.at(number<0>{}); + const auto col = seqlen_kv_step + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + + { + bool need_perpixel_check = + mask.IsEdgeTile(0, seqlen_kv_step, number{}, number{}); + if(need_perpixel_check) + { + set_tile_if(s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto 
row = tile_idx.at(number<0>{}); + const auto col = seqlen_kv_step + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + constexpr auto p_spans = decltype(p)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); + + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse); + else + p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse); + }); + }); + + if constexpr(FmhaDropout::IsDropout) + { + dropout.template Run( + 0, seqlen_kv_step, p, randval_dram_window); + } + const auto p_gemm = [&]() { // dropout / type conversion + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [](const auto& x) { + return type_convert(x > 0.f ? 
x : 0.f); + }, + p); + } + else + { + return cast_tile(p); + } + }(); + + // STAGE 4, OGrad@V Gemm2 + dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); + + // STAGE 3, P^T@OGrad^T Gemm1 + auto pt_reg_tensor = make_static_distributed_tensor( + Policy::template MakePTRegSliceBlockDescriptor()); + pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer(); + + dv_acc = gemm_1(pt_reg_tensor, dot_reg_tensor); + } + block_sync_lds(); + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm12(); + __builtin_amdgcn_sched_barrier(0); + if constexpr(is_epilogue) + { + // STAGE 5, P^T(PGrad^T - D) + constexpr auto ds_spans = decltype(ds)::get_distributed_spans(); + sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = p[i_j_idx] >= 0; + ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag + ? 
(dp_acc[i_j_idx] - d[i_idx]) + : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbias = [&]() { + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + ds); + } + else + { + return cast_tile(ds); + } + }(); + store_tile(bias_lds_write_window, dbias); + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); + auto dbias_tile = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbias_tile, shuffled_dbias_tile); + store_tile(dbias_dram_window, dbias_tile); + move_tile_window(dbias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); + } + } + if constexpr(is_epilogue) + { + // STAGE 6, SGrad^T@Q^T Gemm3 + const auto ds_gemm = cast_tile(ds); + auto dst_reg_tensor = make_static_distributed_tensor( + Policy::template MakeSGradTRegSliceBlockDescriptor()); + dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); + dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor); + + store_tile(ds_lds_window, ds_gemm); + } + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + if constexpr(is_epilogue) + { + ds_reg_tensor = load_tile_transpose(ds_lds_read_window); + move_tile_window(ds_lds_read_window, {kK4, 0}); + } + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm3(); + __builtin_amdgcn_sched_barrier(0); + if constexpr(is_epilogue) + { + // STAGE7 SGrad@K^T Gemm4 + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); + move_tile_window(ds_lds_read_window, {kK4, 0}); + } + auto kt_reg_tensor_slice = get_slice_tile( // + kt_reg_tensor, + sequence<0, i_k4 * kK4>{}, + sequence{}); + gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice); + + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor.get_thread_buffer() = 
ds_reg_tensor_next.get_thread_buffer(); + } + }); + move_tile_window(ds_lds_read_window, {-kN0, 0}); + } + block_sync_lds(); + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm4(); + if constexpr(is_epilogue) + { + // Results Scale + if constexpr(FmhaDropout::IsDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + + dk_epilogue(dk_dram_window, dk_acc); + move_tile_window(dk_dram_window, {kN0, 0}); + dv_epilogue(dv_dram_window, dv_acc); + move_tile_window(dv_dram_window, {kN0, 0}); + } + }; + + for(index_t i = 0; i < seqlen_kv_start; i += kN0) + { + dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0}); + move_tile_window(dk_dram_window, {kN0, 0}); + dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0}); + move_tile_window(dv_dram_window, {kN0, 0}); + } + + main_body(std::true_type{}, std::false_type{}); + // Hot loop + if(num_total_loop > 1) + { + do + { + main_body(std::true_type{}, std::true_type{}); + i_total_loops += 1; + seqlen_kv_step += kN0; + } while(i_total_loops < num_total_loop - 1); + } + main_body(std::false_type{}, std::true_type{}); + seqlen_kv_step += kN0; + + const auto k_length = k_dram_block_window_tmp.get_window_lengths(); + const auto seqlen_kv_length = k_length.at(number<0>{}); + for(; seqlen_kv_step < seqlen_kv_length; seqlen_kv_step += kN0) + { + dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0}); + move_tile_window(dk_dram_window, {kN0, 0}); + dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0}); + move_tile_window(dv_dram_window, {kN0, 0}); + } + + // QGrad Scale + if constexpr(FmhaDropout::IsDropout) + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + else + 
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + // static_assert(kIsDeterministic); + dq_epilogue(dq_dram_window, dq_acc); + return; + } +}; + +template +concept fmha_bwd_qr_qtr_dor_pipeline_c = T::is_qr_qtr_dor_pipeline; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp index 6cef1db730..d1fb1669c9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp @@ -65,7 +65,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy typename Problem::BlockFmhaShape::Gemm2BlockWarps, typename Problem::BlockFmhaShape::Gemm2WarpTile>>; - using WarpGemm = WarpGemmMfmaDispatcher< + constexpr auto SwizzleA = false; + using WarpGemm = WarpGemmMfmaDispatcher< // typename Problem::OGradDataType, typename Problem::VDataType, typename Problem::AccDataType, @@ -73,7 +74,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}), Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}), false, - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>; + SwizzleA>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy>; - using WarpGemm = WarpGemmMfmaDispatcher{}), - BlockFmhaShape::Gemm4WarpTile::at(number<1>{}), - BlockFmhaShape::Gemm4WarpTile::at(number<2>{}), - false, - false, - false, - WGAttrNumAccessEnum::Double>; + using WarpGemm = WarpGemmMfmaDispatcher< // + typename Problem::GemmDataType, + typename Problem::KDataType, + typename Problem::AccDataType, + BlockFmhaShape::Gemm4WarpTile::at(number<0>{}), + BlockFmhaShape::Gemm4WarpTile::at(number<1>{}), + BlockFmhaShape::Gemm4WarpTile::at(number<2>{}), + false, + false, + false, + (Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}) == 32) + ? 
WGAttrNumAccessEnum ::Double + : WGAttrNumAccessEnum ::Single>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy(); - constexpr index_t K1 = WarpAlignmentBytes / sizeof(T) / K2; - constexpr index_t K0 = ColsPerBlock / K1 / K2; - static_assert((K0 * K1 * K2 == ColsPerBlock) && K1 * K2 * sizeof(T) == WarpAlignmentBytes, + constexpr index_t K3 = GetAlignmentK(); // 8 + constexpr index_t K2 = WarpAlignmentBytes / sizeof(T) / K3; // 8 + constexpr index_t K_remain = ColsPerBlock / K2 / K3; + constexpr index_t K1 = min(kWarps, K_remain); + constexpr index_t K0 = K_remain / K1; + static_assert((K0 * K1 * K2 * K3 == ColsPerBlock) && + K2 * K3 * sizeof(T) == WarpAlignmentBytes, "ColsPerBlock notdivisible"); - constexpr index_t N2 = get_warp_size() / K1; - constexpr index_t N1 = kWarps / K0; + constexpr index_t N2 = get_warp_size() / K2; // 8 + constexpr index_t N1 = max(1, kWarps / K1); constexpr index_t N0 = RowsPerBlock / N1 / N2; - static_assert((N0 * N1 * N2 == RowsPerBlock) && (K0 * N1 == kWarps) && - (K1 * N2 == get_warp_size()), + static_assert((N0 * N1 * N2 == RowsPerBlock) && (K1 * N1 == kWarps) && + (K2 * N2 == get_warp_size()), "RowsPerBlock not divisible"); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, // K0 N1, N2 K1 - tuple, sequence<2, 1>>, - sequence<1, 2>, // N0 K2 - sequence<0, 2>>{}); + tuple, sequence>, + tuple, sequence<1, 2>>, // K1 N1, N2 K2 + tuple, sequence<2, 2>>, + sequence<1, 2, 2>, // N0 K0 K3 + sequence<0, 0, 3>>{}); } template @@ -961,13 +968,15 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t N1 = GetAlignmentBias(); + constexpr index_t N1 = min(static_cast(GetAlignmentBias()), + kMPerBlock * kNPerBlock / kBlockSize); constexpr index_t N0 = kNPerBlock / N1; - constexpr 
index_t M2 = GetTransposedAlignmentBias(); - constexpr index_t M1 = get_warp_size() / N0; constexpr index_t M0 = kBlockSize / get_warp_size(); + constexpr index_t M1 = get_warp_size() / N0; + constexpr index_t M2 = kMPerBlock / M1 / M0; return make_static_tile_distribution( tile_distribution_encoding, diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index 570cff8bf0..41a744ea91 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -74,7 +74,8 @@ template + typename Gemm4WarpTile_, + index_t kMaxSeqLenQ_ = 0> struct TileFmhaBwdShape { using BlockTile = remove_cvref_t; @@ -111,6 +112,10 @@ struct TileFmhaBwdShape // K/K^T at once static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipeline // that need load V at once + + static constexpr index_t kMaxSeqLenQ = kMaxSeqLenQ_; + static_assert(kMaxSeqLenQ == kM0 || kMaxSeqLenQ == 0, + "kMaxSeqLenQ should be equal to kM0 or 0, if 0, it means seq len Q is unlimited"); }; } // namespace ck_tile