From 8cb8da53c96471d39b27d9564579107648dbb67a Mon Sep 17 00:00:00 2001 From: Yi DING Date: Tue, 12 Aug 2025 11:11:55 +0800 Subject: [PATCH] [CK_TILE] FMHA BWD Optimization For GFX950 (#2628) * simplify fmha_bwd_kernel MakeKargs & dq_dram_window * simply duplicate * trload pipeline * Try two-stage * add prefetch * optimize & iglp [ROCm/composable_kernel commit: 4fde1646e534415221edf81146d41f85fbf33e63] --- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 96 +- example/ck_tile/01_fmha/fmha_bwd.hpp | 5 +- .../core/numeric/integral_constant.hpp | 12 +- .../ck_tile/core/tensor/tensor_adaptor.hpp | 22 +- .../ck_tile/core/tensor/tensor_descriptor.hpp | 28 +- include/ck_tile/host/device_prop.hpp | 6 + include/ck_tile/ops/fmha.hpp | 2 + .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 557 +------- ...k_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp | 4 +- ...a_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 10 +- ...ck_fmha_bwd_dq_dk_dv_pipeline_selector.hpp | 20 +- ...bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 760 ++++++++++ ...block_fmha_bwd_pipeline_default_policy.hpp | 16 +- .../block_fmha_bwd_pipeline_problem.hpp | 2 + ...mha_bwd_pipeline_trload_default_policy.hpp | 1220 +++++++++++++++++ .../block/block_gemm_areg_breg_creg_v1.hpp | 42 +- 16 files changed, 2216 insertions(+), 586 deletions(-) create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 47cf6b3ad4..8ca917cb6c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -83,6 +83,7 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< {F_deterministic}, fmha_mask_{F_idx}, fmha_dropout_{F_idx}, + {F_trload}, fmha_bwd_trait_{F_idx}>; using fmha_bwd_pipeline_{F_idx} = 
ck_tile::BlockFmhaBwdDQDKDVPipeline; @@ -113,7 +114,8 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dbias}, {F_dpad}, {F_dvpad}, - {F_deterministic}>; + {F_deterministic}, + {F_trload}>; #include @@ -168,29 +170,35 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) template <> float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ + const bool has_load_tr = ck_tile::is_load_tr_supported(); float r = -1; {F_dispatch} return r; }} """ -FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} +FMHA_BWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ +{F_body} }} """ -FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ -{F_inner_dispatch} - }} + +FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_body} + }} +""" +FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_body} + }} """ -FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && - ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>; - using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>; - r = fmha_bwd_(s, a); - return r; - }} +FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && + ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) 
{{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>; + r = fmha_bwd_(s, a); + return r; + }} """ # M0 size for 1d kernels (dot/convert) @@ -250,6 +258,7 @@ class FmhaBwdDQDKDVKernel: F_mode : str # value from MODE_MAP F_deterministic : str # mask_impl : str # + F_trload : str # @property def template(self) -> str: @@ -291,6 +300,7 @@ class FmhaBwdDQDKDVKernel: F_mask = get_mask_map(self.mask_impl)[self.F_mask], F_mode = MODE_MAP[self.F_mode], F_deterministic = BOOL_MAP[self.F_deterministic], + F_trload = BOOL_MAP[self.F_trload], ) @property @@ -324,6 +334,9 @@ class FmhaBwdDQDKDVKernel: if self.F_deterministic == 't' : n += '_deterministic' else: n += '_ndeterministic' + + if self.F_trload == 't' : n += '_trload' + else: n += '_ntrload' return n @property @@ -332,8 +345,8 @@ class FmhaBwdDQDKDVKernel: # TODO: design a more practical way to do it # this is current supported tile size. 
-def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': +def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str, tr_load: str) -> Optional[dict]: + if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': return { '32' : FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), '64' : FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), @@ -341,6 +354,10 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict # '160' : FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), '256' : FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), } + elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't': + return { + '128' : FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), + } else: return None @@ -573,6 +590,7 @@ class FmhaBwdApiTrait: dvpad : str deterministic : str mask_impl : str + tr_load : bool @property def bm0(self) -> int: @@ -620,7 +638,7 @@ class FmhaBwdApiTrait: def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile, F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout, - F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl) + F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl, F_trload=self.tr_load) @property def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel: @@ -636,12 +654,13 @@ class FmhaBwdApiTrait: class FmhaBwdApiPool: def __init__(self, mask_impl): - self.dq_dk_dv_pool = defaultdict(lambda: 
defaultdict(list)) + self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + self.mask_impl = mask_impl def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None: # TODO: do we need to check duplication? - self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + self.dq_dk_dv_pool[trait.tr_load][trait.dtype][trait.hdim].append(copy.copy(trait)) @staticmethod def if_(i: int) -> str: @@ -656,24 +675,31 @@ class FmhaBwdApiPool: F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_deterministic=BOOL_MAP[trait.deterministic]) + F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load]) i += 1 return inners @property def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.dq_dk_dv_pool): - per_hdim_case=str() - for j, hdim in enumerate(self.dq_dk_dv_pool[dtype]): - traits=self.dq_dk_dv_pool[dtype][hdim] - inners = self._api_innders(traits) - per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(j), F_hdim=hdim, F_inner_dispatch=inners) - per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(i), F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: + tr_load_cond_map = { + "t": "has_load_tr", + "f": "true" + } + per_tr_load = '' + for tr_load in ["t", "f"]: + per_dtypes = '' + for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load]): + per_hdim_case = '' + for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][dtype]): + traits = self.dq_dk_dv_pool[tr_load][dtype][hdim] + inners = self._api_innders(traits) + per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(k), F_hdim=hdim, F_body=inners) + per_dtypes += 
FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(j), F_dtype=dtype, F_body=per_hdim_case) + per_tr_load += FMHA_BWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_body=per_dtypes) + if not per_tr_load: # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) + per_tr_load += ' (void)t ; (void)s ; (void)a;' + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load) def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]: if filter_list == '': @@ -690,8 +716,8 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {} api_pool = FmhaBwdApiPool(mask_impl) - for dtype in BWD_DTYPE_MAP.keys(): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) + for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype, tr_load) if d is None: continue for hdim_str, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)): @@ -703,7 +729,9 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm continue if ("wg32" in dropout): continue - t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl) + if tr_load == "t" and (dpad == "t" or dvpad == "t"): + continue # tr_load cannot work with dpad or dvpad + t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, 
mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load) if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): continue diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index c999cf750e..bd63c96eb1 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/host/device_prop.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/fmha.hpp" #include "ck_tile/ops/epilogue.hpp" @@ -363,7 +364,8 @@ template + bool kIsDeterministic_, + bool kUseTrLoad_> struct fmha_bwd_dq_dk_dv_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -376,6 +378,7 @@ struct fmha_bwd_dq_dk_dv_traits_ static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; static constexpr bool kIsDeterministic = kIsDeterministic_; + static constexpr bool kUseTrLoad = kUseTrLoad_; }; template diff --git a/include/ck_tile/core/numeric/integral_constant.hpp b/include/ck_tile/core/numeric/integral_constant.hpp index 2ba2fd10c6..1eec80828a 100644 --- a/include/ck_tile/core/numeric/integral_constant.hpp +++ b/include/ck_tile/core/numeric/integral_constant.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -83,4 +83,14 @@ CK_TILE_BINARY_OP(<=) #undef CK_TILE_LEFT_UNARY_OP #undef CK_TILE_BINARY_OP +template +struct is_constant : std::false_type +{ +}; +template +struct is_constant> : std::true_type +{ +}; +template +inline constexpr bool is_constant_v = is_constant::value; } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp index ec5538d79c..eb226debfd 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -259,6 +259,7 @@ struct tensor_adaptor CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); } + template CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides( const array& guaranteed_vector_lengths, const array& guaranteed_vector_strides) @@ -266,7 +267,9 @@ struct tensor_adaptor auto vector_lengths = guaranteed_vector_lengths; auto vector_strides = guaranteed_vector_strides; - static_for<0, get_num_of_transform(), 1>{}([&](auto itran) { + static_for<0, + Internal ? 
std::min(Internal, get_num_of_transform()) : get_num_of_transform(), + 1>{}([&](auto itran) { constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran); constexpr auto up_dims = get_upper_dimension_hidden_idss().at(itran); @@ -298,11 +301,16 @@ struct tensor_adaptor set_container_subset(vector_lengths, up_dims, up_vector_lengths); set_container_subset(vector_strides, up_dims, up_vector_strides); }); - - constexpr auto top_dims = TopDimensionHiddenIds{}; - - return make_tuple(get_container_subset(vector_lengths, top_dims), - get_container_subset(vector_strides, top_dims)); + if constexpr(Internal > 0) + { + return make_tuple(vector_lengths, vector_strides); + } + else + { + constexpr auto top_dims = TopDimensionHiddenIds{}; + return make_tuple(get_container_subset(vector_lengths, top_dims), + get_container_subset(vector_strides, top_dims)); + } } private: diff --git a/include/ck_tile/core/tensor/tensor_descriptor.hpp b/include/ck_tile/core/tensor/tensor_descriptor.hpp index 0e4787a2f1..3b372d45dd 100644 --- a/include/ck_tile/core/tensor/tensor_descriptor.hpp +++ b/include/ck_tile/core/tensor/tensor_descriptor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -133,9 +133,10 @@ struct tensor_descriptor : public tensor_adaptor CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides() { - return Base::get_top_dimension_safe_vector_length_strides( + return Base::template get_top_dimension_safe_vector_length_strides( to_array(GuaranteedVectorLengths{}), to_array(GuaranteedVectorStrides{})); } @@ -377,12 +378,29 @@ make_naive_tensor_descriptor_packed(const tuple& lengths, const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{}); + constexpr index_t first_dim_length = []() { + if constexpr(is_constant_v>) + return decltype(element_space_size)::value; + else + return -1; + }(); + using last_t = remove_cvref_t())>; + constexpr index_t last_dim_length = []() { + if constexpr(is_constant_v) + return std::max(last_t::value, GuaranteedLastDimensionVectorLength); + else + return -1; + }(); + using GuaranteedVectorLengths = - typename sequence_merge::type, - sequence>::type; + typename sequence_merge, + typename uniform_sequence_gen::type, + sequence>::type; using GuaranteedVectorStrides = - typename sequence_merge::type, sequence<1>>::type; + typename sequence_merge, + typename uniform_sequence_gen::type, + sequence<1>>::type; return tensor_descriptor, remove_cv_t, diff --git a/include/ck_tile/host/device_prop.hpp b/include/ck_tile/host/device_prop.hpp index d33b298369..0d8f89ea31 100644 --- a/include/ck_tile/host/device_prop.hpp +++ b/include/ck_tile/host/device_prop.hpp @@ -51,6 +51,12 @@ inline std::string get_device_name() default: return name; } } + +inline bool is_load_tr_supported() +{ + // Check if load transpose is supported. 
+ return get_device_name() == "gfx950"; +} } // namespace ck_tile #endif diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 313de5f29a..276ec4852f 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -25,8 +25,10 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 8b184b18f3..595e2cfccf 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -62,6 +62,12 @@ struct FmhaBwdDQDKDVKernel static constexpr bool kHasDropout = FmhaDropout::IsDropout; static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval; static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic; + static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad; +#if defined(__gfx950__) + static constexpr bool kIsAvialable = true; +#else + static constexpr bool kIsAvialable = !kUseTrLoad; +#endif // clang-format off template struct t2s; @@ -99,7 +105,7 @@ struct FmhaBwdDQDKDVKernel ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? 
"_npad" : "_" + pn) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) + - (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ); + (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload"); #undef _SS_ #undef _TS_ // clang-format on @@ -298,6 +304,24 @@ struct FmhaBwdDQDKDVKernel using Kargs = std::conditional_t; + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr Kargs + MakeKargs(Ts... args, const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr Kargs + MakeKargs(Ts... 
args, const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + template CK_TILE_HOST static constexpr std::enable_if_t MakeKargsImpl(const void* q_ptr, @@ -466,248 +490,6 @@ struct FmhaBwdDQDKDVKernel return kargs; } - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - const void* lse_ptr, - const void* do_ptr, - const void* d_ptr, - void* rand_val_ptr, - void* dk_ptr, - void* dv_ptr, - void* dbias_ptr, - void* dq_acc_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_do, - ck_tile::index_t stride_dq_acc, - ck_tile::index_t stride_dk, - ck_tile::index_t stride_dv, - ck_tile::index_t stride_dbias, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_do, - ck_tile::index_t nhead_stride_lsed, - ck_tile::index_t nhead_stride_dq_acc, - ck_tile::index_t nhead_stride_dk, - ck_tile::index_t nhead_stride_dv, - ck_tile::index_t nhead_stride_dbias, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_do, - ck_tile::index_t batch_stride_lsed, - ck_tile::index_t batch_stride_dq_acc, - ck_tile::index_t batch_stride_dk, - ck_tile::index_t 
batch_stride_dv, - ck_tile::index_t batch_stride_dbias, - ck_tile::index_t split_stride_dq_acc, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - const std::tuple& drop_seed_offset) - { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - bias_ptr, - lse_ptr, - do_ptr, - d_ptr, - rand_val_ptr, - dk_ptr, - dv_ptr, - dbias_ptr, - dq_acc_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_do, - stride_dq_acc, - stride_dk, - stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, - nhead_stride_dk, - nhead_stride_dv, - nhead_stride_dbias, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_do, - batch_stride_lsed, - batch_stride_dq_acc, - batch_stride_dk, - batch_stride_dv, - batch_stride_dbias, - split_stride_dq_acc, - window_size_left, - window_size_right, - mask_type, - p_drop, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); - } - - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - const void* lse_ptr, - const void* do_ptr, - const void* d_ptr, - void* rand_val_ptr, - void* dk_ptr, - void* dv_ptr, - void* dbias_ptr, - void* dq_acc_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t 
stride_randval, - ck_tile::index_t stride_do, - ck_tile::index_t stride_dq_acc, - ck_tile::index_t stride_dk, - ck_tile::index_t stride_dv, - ck_tile::index_t stride_dbias, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_do, - ck_tile::index_t nhead_stride_lsed, - ck_tile::index_t nhead_stride_dq_acc, - ck_tile::index_t nhead_stride_dk, - ck_tile::index_t nhead_stride_dv, - ck_tile::index_t nhead_stride_dbias, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_do, - ck_tile::index_t batch_stride_lsed, - ck_tile::index_t batch_stride_dq_acc, - ck_tile::index_t batch_stride_dk, - ck_tile::index_t batch_stride_dv, - ck_tile::index_t batch_stride_dbias, - ck_tile::index_t split_stride_dq_acc, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - const std::tuple& drop_seed_offset) - { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - bias_ptr, - lse_ptr, - do_ptr, - d_ptr, - rand_val_ptr, - dk_ptr, - dv_ptr, - dbias_ptr, - dq_acc_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_do, - stride_dq_acc, - stride_dk, - stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, - nhead_stride_dk, - nhead_stride_dv, - nhead_stride_dbias, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_do, - batch_stride_lsed, - batch_stride_dq_acc, - batch_stride_dk, - batch_stride_dv, - 
batch_stride_dbias, - split_stride_dq_acc, - window_size_left, - window_size_right, - mask_type, - p_drop, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); - } - template CK_TILE_HOST static constexpr std::enable_if_t MakeKargsImpl(const void* q_ptr, @@ -854,208 +636,6 @@ struct FmhaBwdDQDKDVKernel return kargs; } - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - const void* lse_ptr, - const void* do_ptr, - const void* d_ptr, - void* rand_val_ptr, - void* dk_ptr, - void* dv_ptr, - void* dbias_ptr, - void* dq_acc_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_do, - ck_tile::index_t stride_dq_acc, - ck_tile::index_t stride_dk, - ck_tile::index_t stride_dv, - ck_tile::index_t stride_dbias, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_do, - ck_tile::index_t nhead_stride_lsed, - ck_tile::index_t nhead_stride_dq_acc, - ck_tile::index_t nhead_stride_dk, - ck_tile::index_t nhead_stride_dv, - ck_tile::index_t nhead_stride_dbias, - ck_tile::index_t split_stride_dq_acc, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - const std::tuple& drop_seed_offset) - { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - bias_ptr, - lse_ptr, - do_ptr, - d_ptr, - 
rand_val_ptr, - dk_ptr, - dv_ptr, - dbias_ptr, - dq_acc_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_do, - stride_dq_acc, - stride_dk, - stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, - nhead_stride_dk, - nhead_stride_dv, - nhead_stride_dbias, - split_stride_dq_acc, - window_size_left, - window_size_right, - mask_type, - p_drop, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); - } - - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - const void* lse_ptr, - const void* do_ptr, - const void* d_ptr, - void* rand_val_ptr, - void* dk_ptr, - void* dv_ptr, - void* dbias_ptr, - void* dq_acc_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_do, - ck_tile::index_t stride_dq_acc, - ck_tile::index_t stride_dk, - ck_tile::index_t stride_dv, - ck_tile::index_t stride_dbias, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_do, - ck_tile::index_t nhead_stride_lsed, - ck_tile::index_t nhead_stride_dq_acc, - ck_tile::index_t nhead_stride_dk, - 
ck_tile::index_t nhead_stride_dv, - ck_tile::index_t nhead_stride_dbias, - ck_tile::index_t split_stride_dq_acc, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - const std::tuple& drop_seed_offset) - { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - bias_ptr, - lse_ptr, - do_ptr, - d_ptr, - rand_val_ptr, - dk_ptr, - dv_ptr, - dbias_ptr, - dq_acc_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_do, - stride_dq_acc, - stride_dk, - stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, - nhead_stride_dk, - nhead_stride_dv, - nhead_stride_dbias, - split_stride_dq_acc, - window_size_left, - window_size_right, - mask_type, - p_drop, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); - } - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) { @@ -1082,6 +662,12 @@ struct FmhaBwdDQDKDVKernel } CK_TILE_DEVICE void operator()(Kargs kargs) const + { + if constexpr(kIsAvialable) + run_(std::move(kargs)); + } + + CK_TILE_DEVICE void run_(Kargs kargs) const { // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; @@ -1282,62 +868,33 @@ struct FmhaBwdDQDKDVKernel {0, 0}); auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() { - if constexpr(kIsDeterministic) - { - AccDataType* dq_acc_ptr = - reinterpret_cast(kargs.dq_acc_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_dq_acc + - static_cast(i_tile_n_) * kargs.split_stride_dq_acc + - batch_offset_dq_acc; + AccDataType* dq_acc_ptr = reinterpret_cast(kargs.dq_acc_ptr) + [&]() { + if constexpr(kIsDeterministic) + return static_cast(i_nhead_) * 
kargs.nhead_stride_dq_acc + + static_cast(i_tile_n_) * kargs.split_stride_dq_acc + + batch_offset_dq_acc; + else + return static_cast(i_nhead_) * kargs.nhead_stride_dq_acc + + batch_offset_dq_acc; + }(); - auto dq_acc_dram = [&]() { - const auto dq_acc_dram_naive = - make_naive_tensor_view( - dq_acc_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_dq_acc, 1), - number{}, - number<1>{}); - - return pad_tensor_view( - dq_acc_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - return make_tile_window( - dq_acc_dram, - make_tuple(number{}, number{}), - {0, 0}); - } - else - { - AccDataType* dq_acc_ptr = - reinterpret_cast(kargs.dq_acc_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_dq_acc + - batch_offset_dq_acc; - - auto dq_acc_dram = [&]() { - const auto dq_acc_dram_naive = - make_naive_tensor_view( - dq_acc_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_dq_acc, 1), - number{}, - number<1>{}); - - return pad_tensor_view( - dq_acc_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - return make_tile_window( - dq_acc_dram, - make_tuple(number{}, number{}), - {0, 0}); - } + constexpr auto DstInMemOp = conditional_expr( + memory_operation_enum::set, memory_operation_enum::atomic_add); + const auto dq_acc_dram_naive = + make_naive_tensor_view( + dq_acc_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_dq_acc, 1), + number{}, + number<1>{}); + const auto dq_acc_dram = pad_tensor_view( + dq_acc_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + return make_tile_window( + dq_acc_dram, + make_tuple(number{}, number{}), + {0, 0}); }(); auto lse_dram_window = diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 1f11569533..d36f8ad724 100644 --- 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -54,6 +54,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + static constexpr bool kUseTrLoad = Problem::kUseTrLoad; + static_assert(!kUseTrLoad, "This pipeline does not use trload!"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 967fe2362d..88fb1281aa 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -54,6 +54,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + static constexpr bool kUseTrLoad = Problem::kUseTrLoad; + static_assert(!kUseTrLoad, "This pipeline does not use trload!"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. 
tensor dist should able to overwrite this @@ -654,9 +656,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP }(); // STAGE 3, P^T@OGrad^T Gemm1 - Policy::template PTFromGemm0CToGemm1A(pt_reg_tensor, p_gemm); + Policy::template PTFromGemm0CToGemm1A(pt_reg_tensor, p_gemm); gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); auto qt_reg_tensor = load_tile(qt_lds_read_window); @@ -728,9 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP // STAGE 6, SGrad^T@Q^T Gemm3 const auto ds_gemm = cast_tile(ds); - Policy::template SGradTFromGemm2CToGemm3A(dst_reg_tensor, ds_gemm); + Policy::template SGradTFromGemm2CToGemm3A(dst_reg_tensor, ds_gemm); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp index 80c311de86..bf38c3c07d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp @@ -6,22 +6,30 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp" namespace ck_tile { -template +template class BlockFmhaBwdDQDKDVPipelineSelector { static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV; public: - using type = std::conditional_t, - BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP>; + template + using type_ = + std::conditional_t, + std::conditional_t, + BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP>>; + using type = std::conditional_t, // + type_, + type_>; }; -template -class BlockFmhaBwdDQDKDVPipeline : public BlockFmhaBwdDQDKDVPipelineSelector::type +template +class BlockFmhaBwdDQDKDVPipeline : public 
BlockFmhaBwdDQDKDVPipelineSelector::type { public: static constexpr const char* name = "auto"; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp new file mode 100644 index 0000000000..1d95bc2801 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -0,0 +1,760 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using FmhaDropout = remove_cvref_t; + // using HotLoopScheduler = typename Policy::template HotLoopScheduler; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr 
index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + static constexpr bool kUseTrLoad = Problem::kUseTrLoad; + static_assert(kUseTrLoad, "This pipeline uses trload!"); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = 1; + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 
1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = 1; + + static constexpr const char* name = "trload_kr_ktr_vr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static LSEDataType get_validated_lse(const LSEDataType raw_lse) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || FmhaMask::IsMasking) + return (raw_lse == -numeric::infinity()) // + ? type_convert(0.f) + : raw_lse; + else + return raw_lse; + }; + + template + CK_TILE_DEVICE auto operator()( // + const QDramBlockWindowTmp& q_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, + float scale, + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + FmhaDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == 
LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + // init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + // K, HBM ->LDS ->Reg + auto k_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + k_dram_block_window_tmp.get_bottom_tensor_view()), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); + + const auto k_origin = k_dram_window.get_window_origin(); + + // Early termination + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. 
+ return make_tuple(dk_acc, dv_acc); + } + } + + // LDS allocation + const auto smem_ptr_ = + reinterpret_cast(smem_ptr); // cast to char* to do pointer arithmetic + + const auto k_lds_ptr = reinterpret_cast(smem_ptr_); + const auto v_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeK()); + + const auto do_lds_ptr0 = reinterpret_cast(smem_ptr_); + const auto do_lds_ptr1 = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad()); + const auto q_lds_ptr0 = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad()); + const auto q_lds_ptr1 = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ()); + const auto lse_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ()); + const auto d_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeLSE()); + const auto ds_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeLSE() + Policy::template GetSmemSizeD()); + const auto bias_lds_ptr = reinterpret_cast(ds_lds_ptr); + + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); + auto k_lds_write_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + //------------------------------------------------------------------ + // V, HBM ->LDS ->Reg + auto v_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + v_dram_block_window_tmp.get_bottom_tensor_view()), + 
v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVDramTileDistribution()); + auto v_lds = make_tensor_view( + v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); + auto v_lds_write_window = + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + + //------------------------------------------------------------------ + // KT, HBM -> LDS --trload-->Reg + async_load_tile(k_lds_write_window, k_dram_window); + async_load_tile(v_lds_write_window, v_dram_window); + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + + //------------------------------------------------------------------ + // Pre-Load KV into Registers + auto k_lds_read = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsReadBlockDescriptor()); + auto k_lds_read_window = + make_tile_window(k_lds_read, + make_tuple(number{}, number{}), + k_lds_write_window.get_window_origin(), + Policy::template MakeKRegBlockDescriptor()); + auto k_reg_tensor = load_tile(k_lds_read_window); + + auto kt_lds_read_window = + make_tile_window(k_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeKTRegBlockDescriptor()); + + auto kt_reg_tensor = load_tile_transpose(kt_lds_read_window); + + auto v_lds_read = make_tensor_view( + v_lds_ptr, Policy::template MakeVLdsReadBlockDescriptor()); + auto v_lds_read_window = + make_tile_window(v_lds_read, + make_tuple(number{}, number{}), + v_lds_write_window.get_window_origin(), + Policy::template MakeVRegBlockDescriptor()); + auto v_reg_tensor = load_tile(v_lds_read_window); + + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + //---------------------------- Loop Load in ----------------------------// + // Q: HBM -->LDS + auto q_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + q_dram_block_window_tmp.get_bottom_tensor_view()), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + Policy::template 
MakeQDramTileDistribution()); + + auto q_lds = make_tensor_view( + q_lds_ptr0, Policy::template MakeQLdsWriteBlockDescriptor()); + auto q_lds_write_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + auto q_lds_read = make_tensor_view( + q_lds_ptr0, Policy::template MakeQLdsReadBlockDescriptor()); + auto q_lds_read_window = + make_tile_window(q_lds_read, + make_tuple(number{}, number{}), + q_lds_write_window.get_window_origin(), + Policy::template MakeQRegSliceBlockDescriptor()); + auto qt_lds_read_window = + make_tile_window(q_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeQTRegSliceBlockDescriptor()); + + // dO: HBM ->LDS ---load--> Reg + // dOT: \-loadtr-> Reg + auto do_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + do_dram_block_window_tmp.get_bottom_tensor_view()), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + Policy::template MakeOGradDramTileDistribution()); + + auto do_lds = make_tensor_view( + do_lds_ptr0, Policy::template MakeOGradLdsWriteBlockDescriptor()); + auto do_lds_write_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + auto do_lds_read = make_tensor_view( + do_lds_ptr0, Policy::template MakeOGradLdsReadBlockDescriptor()); + auto do_lds_read_window = + make_tile_window(do_lds_read, + make_tuple(number{}, number{}), + do_lds_write_window.get_window_origin(), + Policy::template MakeOGradRegSliceBlockDescriptor()); + auto dot_lds_read_window = + make_tile_window(do_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeOGradTRegSliceBlockDescriptor()); + + // dS: Reg -> Reg -> LDS + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // transform it to make it from col-major to row-major; prepared for load_tile_transpose + auto ds_lds_t = 
make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_read_window = + make_tile_window(ds_lds_t, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeSGradRegSliceBlockDescriptor()); + + // Bias: HBM ->Reg ->Reg ->LDS + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + + auto bias_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + bias_dram_block_window_tmp.get_bottom_tensor_view()), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}, + Policy::template MakeBiasTileDistribution()); + + auto bias_lds = make_tensor_view( + bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor()); + auto bias_lds_write_window = + make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); + + auto bias_lds_read = make_tensor_view( + bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor()); + auto bias_s_lds_read_window = + make_tile_window(bias_lds_read, + make_tuple(number{}, number{}), + bias_lds_write_window.get_window_origin(), + Policy::template MakeBiasSTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // LSE: HBM -> LDS ->Reg + auto lse_dram_window = make_tile_window( + lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); + + auto lse_lds = make_tensor_view( + lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + + auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number{}), {0}); + + auto lse_lds_read_window = make_tile_window( + lse_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); + + // D: HBM ->Reg + auto d_dram_window = make_tile_window( + d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + 
{seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); + + auto d_lds = make_tensor_view( + d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number{}), {0}); + auto d_lds_read_window = make_tile_window( + d_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); + + // RandVal: HBM ->Reg + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + // BiasGrad + // Reg ->LDS ->Reg ->HBM + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + + auto dbias_dram_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto dbias_lds_read_window = + make_tile_window(bias_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + // ----------------------------Loop write out------------------------------// + auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + index_t i_total_loops = 0; + index_t seqlen_q_step = seqlen_q_start; + static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0"); + static_assert(kM0 == kK1, "kM0 should equal to kK1"); + static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2"); + static_assert(kM0 == kK3, "kM0 should equal to kK3"); + constexpr index_t k4_loops = kN0 / kK4; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + __builtin_amdgcn_sched_barrier(0); + + decltype(load_tile(q_lds_read_window)) q_reg_tensor; + decltype(load_tile(lse_lds_read_window)) lse; + decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor; + decltype(load_tile_transpose(ds_lds_read_window)) 
ds_reg_tensor_next; + decltype(load_tile(do_lds_read_window)) do_reg_tensor; + decltype(load_tile_transpose(dot_lds_read_window)) dot_reg_tensor; + decltype(load_tile(d_lds_read_window)) d; + decltype(load_tile_transpose(qt_lds_read_window)) qt_reg_tensor; + decltype(gemm_0.MakeCBlockTile()) s_acc, p; + decltype(gemm_2.MakeCBlockTile()) dp_acc, ds; + decltype(gemm_4.MakeCBlockTile()) dq_acc; + + decltype(load_tile(lse_dram_window)) lse_block_tile; + decltype(load_tile(d_dram_window)) d_block_tile; + + index_t i_total_bodys = 0; + auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable { + const bool is_even = (i_total_bodys % 2 == 0); + QDataType* const __restrict__ q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0; + QDataType* const __restrict__ q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1; + OGradDataType* const __restrict__ do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0; + OGradDataType* const __restrict__ do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1; + + constexpr bool is_prologue = is_prologue_.value; + constexpr bool is_epilogue = is_epilogue_.value; + static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true"); + constexpr bool is_main_body = is_prologue && is_epilogue; + + if constexpr(is_prologue) + { + q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next); + async_load_tile(q_lds_write_window, q_dram_window); + move_tile_window(q_dram_window, {kM0, 0}); + + lse_block_tile = load_tile(lse_dram_window); + move_tile_window(lse_dram_window, {kM0}); + + do_lds_write_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next); + async_load_tile(do_lds_write_window, do_dram_window); + move_tile_window(do_dram_window, {kM0, 0}); + + d_block_tile = load_tile(d_dram_window); + move_tile_window(d_dram_window, {kM0}); + } + if constexpr(is_epilogue) + { + // STAGE 1, Q@K Gemm0 + s_acc = gemm_0(q_reg_tensor, k_reg_tensor); + + dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr); + 
dot_reg_tensor = load_tile_transpose(dot_lds_read_window); + } + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm0(); + __builtin_amdgcn_sched_barrier(0); + if constexpr(is_epilogue) + { + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + async_load_tile(bias_lds_write_window, bias_dram_window); + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + auto bias_s_tile = load_tile(bias_s_lds_read_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { + x = scale * x + log2e_v * type_convert(y); + }, + s_acc, + bias_s_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = seqlen_q_step + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + + { + bool need_perpixel_check = mask.IsEdgeTile( + seqlen_q_step, k_origin.at(number<0>{}), number{}, number{}); + if(need_perpixel_check) + { + set_tile_if(s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = seqlen_q_step + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + constexpr auto p_spans = decltype(p)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto row_lse = log2e_v * 
get_validated_lse(lse[i_idx]); + + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse); + else + p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse); + }); + }); + + if constexpr(FmhaDropout::IsDropout) + { + dropout.template Run( + seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window); + } + const auto p_gemm = [&]() { // dropout / type conversion + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [](const auto& x) { + return type_convert(x > 0.f ? x : 0.f); + }, + p); + } + else + { + return cast_tile(p); + } + }(); + + // STAGE 4, OGrad@V Gemm2 + dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); + + qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr); + qt_reg_tensor = load_tile_transpose(qt_lds_read_window); + + // STAGE 3, P^T@OGrad^T Gemm1 + auto pt_reg_tensor = make_static_distributed_tensor( + Policy::template MakePTRegSliceBlockDescriptor()); + pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer(); + gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); + } + block_sync_lds(); + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm12(); + __builtin_amdgcn_sched_barrier(0); + if constexpr(is_prologue) + { + store_tile(lse_lds_write_window, lse_block_tile); + store_tile(d_lds_write_window, d_block_tile); + } + if constexpr(is_epilogue) + { + // STAGE 5, P^T(PGrad^T - D) + constexpr auto ds_spans = decltype(ds)::get_distributed_spans(); + sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = p[i_j_idx] >= 0; + ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag + ? 
(dp_acc[i_j_idx] - d[i_idx]) + : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbias = [&]() { + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + ds); + } + else + { + return cast_tile(ds); + } + }(); + store_tile(bias_lds_write_window, dbias); + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); + auto dbias_tile = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbias_tile, shuffled_dbias_tile); + store_tile(dbias_dram_window, dbias_tile); + move_tile_window(dbias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); + } + } + if constexpr(is_epilogue) + { + // STAGE 6, SGrad^T@Q^T Gemm3 + const auto ds_gemm = cast_tile(ds); + auto dst_reg_tensor = make_static_distributed_tensor( + Policy::template MakeSGradTRegSliceBlockDescriptor()); + dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); + gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + + store_tile(ds_lds_window, ds_gemm); + } + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + if constexpr(is_prologue) + { + q_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next); + q_reg_tensor = load_tile(q_lds_read_window); + lse = load_tile(lse_lds_read_window); + } + if constexpr(is_epilogue) + { + ds_reg_tensor = load_tile_transpose(ds_lds_read_window); + move_tile_window(ds_lds_read_window, {kK4, 0}); + } + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm3(); + __builtin_amdgcn_sched_barrier(0); + if constexpr(is_epilogue) + { + // STAGE7 SGrad@K^T Gemm4 + clear_tile(dq_acc); + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); + move_tile_window(ds_lds_read_window, {kK4, 0}); + } + auto kt_reg_tensor_slice = get_slice_tile( 
// + kt_reg_tensor, + sequence<0, i_k4 * kK4>{}, + sequence{}); + gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice); + + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer(); + } + }); + move_tile_window(ds_lds_read_window, {-kN0, 0}); + } + block_sync_lds(); + if constexpr(is_prologue) + { + do_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next); + do_reg_tensor = load_tile(do_lds_read_window); + d = load_tile(d_lds_read_window); + } + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm4(); + if constexpr(is_epilogue) + { + // QGrad Scale + if constexpr(FmhaDropout::IsDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + if constexpr(kIsDeterministic) + { + store_tile(dq_dram_window, dq_acc); + } + else + { + update_tile(dq_dram_window, dq_acc); + } + move_tile_window(dq_dram_window, {kM0, 0}); + } + i_total_bodys += 1; + }; + + main_body(std::true_type{}, std::false_type{}); + // Hot loop + if(num_total_loop > 1) + { + do + { + main_body(std::true_type{}, std::true_type{}); + i_total_loops += 1; + seqlen_q_step += kM0; + } while(i_total_loops < num_total_loop - 1); + } + main_body(std::false_type{}, std::true_type{}); + + // Results Scale + if constexpr(FmhaDropout::IsDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + + return make_tuple(dk_acc, dv_acc); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 
521968a43b..aa2ec99590 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -64,7 +64,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() + CK_TILE_DEVICE static constexpr auto GetPTOGradTBlockGemm() { using GemmProblem = BlockGemmProblem{}), Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), - true>; + true, + false, // SwizzleAccess + false, // UseStructuredSparsity + (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) + ? WGAttrNumAccessEnum ::Double + : WGAttrNumAccessEnum ::Single>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy{}), Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}), Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}), - true>; + true, + false, // SwizzleAccess + false, // UseStructuredSparsity + (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) + ? 
WGAttrNumAccessEnum ::Double + : WGAttrNumAccessEnum ::Single>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy struct BlockFmhaBwdPipelineProblem { @@ -53,6 +54,7 @@ struct BlockFmhaBwdPipelineProblem static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsDeterministic = kIsDeterministic_; + static constexpr bool kUseTrLoad = kUseTrLoad_; // attributes from traits static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp new file mode 100644 index 0000000000..6cef1db730 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp @@ -0,0 +1,1220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +#include "ck_tile/core/utility/debug.hpp" + +namespace ck_tile { + +struct BlockFmhaBwdPipelineTrLoadDefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + + constexpr auto SwizzleA = false; + using WarpGemm = WarpGemmMfmaDispatcher< // + typename Problem::QDataType, + typename Problem::KDataType, + typename Problem::AccDataType, + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + false, + SwizzleA>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() + { + return BlockFmhaBwdPipelineDefaultPolicy::GetPTOGradTBlockGemm(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm2BlockWarps, + typename Problem::BlockFmhaShape::Gemm2WarpTile>>; + + using WarpGemm = WarpGemmMfmaDispatcher< + typename Problem::OGradDataType, + typename Problem::VDataType, + typename Problem::AccDataType, + Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}), + false, + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? 
false : true>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() + { + return BlockFmhaBwdPipelineDefaultPolicy::GetSGradTQTBlockGemm(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() + { + using BlockFmhaShape = typename Problem::BlockFmhaShape; + using GemmProblem = BlockGemmProblem< + typename Problem::GemmDataType, + typename Problem::KDataType, + typename Problem::AccDataType, + Problem::kBlockSize, + TileGemmShape< + sequence, + typename BlockFmhaShape::Gemm4BlockWarps, + typename BlockFmhaShape::Gemm4WarpTile>>; + + using WarpGemm = WarpGemmMfmaDispatcher{}), + BlockFmhaShape::Gemm4WarpTile::at(number<1>{}), + BlockFmhaShape::Gemm4WarpTile::at(number<2>{}), + false, + false, + false, + WGAttrNumAccessEnum::Double>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + // these are for global load + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentX() noexcept + { + return 16 / sizeof(T); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + return GetAlignmentX(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + return GetAlignmentX(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + return GetAlignmentX(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + return GetAlignmentX(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOGrad() + { + return GetAlignmentX(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias() + { + return GetAlignmentX(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentKGrad() + { + return GetAlignmentX(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentVGrad() + { + return 
GetAlignmentX(); + } + + // these are for load_tr_b64 + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentX() noexcept + { + return 8 / sizeof(T); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentQ() noexcept + { + return GetTransposedAlignmentX(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + return total_pixels / GetAlignmentOGrad(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentBias() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + + return total_pixels / GetAlignmentBias(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentPostQGradAcc() + { + using AccDataType = remove_cvref_t; + return 16 / sizeof(AccDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentPostQGrad() + { + return GetAlignmentPostQGradAcc(); + } + + // It is found that alignment of 8x dwordx4 can avoid bank conflicts for both transposed and + // non-transposed load + static constexpr index_t WarpAlignmentBytes = 128; + + // As load_lds requires contiguous LDS write, we need to transform the distribution of DRAM for + // reading + template + CK_TILE_HOST_DEVICE static constexpr auto TransformXDramTensorView(const TensorView& naive_view) + { + if constexpr(std::is_same_v) + { + return naive_view; + } + else + { + const auto transformed_desc = + TransformXDramDescriptor(naive_view.get_tensor_descriptor()); + return tensor_view, + 
TensorView::DstInMemOp>{naive_view.buf_, transformed_desc}; + } + } + template + CK_TILE_HOST_DEVICE static constexpr auto + TransformXDramDescriptor(const tensor_descriptor& from_desc) + { + using from_desc_t = tensor_descriptor; + + constexpr auto ndims = from_desc_t::get_num_of_dimension(); + static_assert(ndims == 2, "XDram descriptor must have 2 dimensions"); + const auto Rows = from_desc.get_length(number<0>{}); + // constexpr auto Cols = 128; + // assert(from_desc.get_length(number<1>{}) == 128); + const auto Cols = from_desc.get_length(number<1>{}); + + constexpr index_t Dwordx4Bytes = 16; + constexpr index_t K2 = Dwordx4Bytes / sizeof(T); + constexpr index_t K1 = WarpAlignmentBytes / Dwordx4Bytes; + const index_t K0 = Cols / K1; + const auto ColLens = make_tuple(K0, number{}, number{}); + + const auto desc_tmp1 = transform_tensor_descriptor( + from_desc, + make_tuple(make_pass_through_transform(Rows), make_unmerge_transform(ColLens)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2, 3>{})); + + const auto desc_tmp2 = transform_tensor_descriptor( + desc_tmp1, + make_tuple(make_xor_transform(make_tuple(Rows, number{})), + make_pass_through_transform(K0), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + return transform_tensor_descriptor( + desc_tmp2, + make_tuple(make_pass_through_transform(Rows), + make_merge_transform_v3_division_mod(ColLens)), + make_tuple(sequence<0>{}, sequence<1, 2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kWarps = kBlockSize / get_warp_size(); + + constexpr index_t K2 = GetAlignmentK(); + constexpr index_t K1 = WarpAlignmentBytes / sizeof(T) / K2; + constexpr index_t K0 = ColsPerBlock / K1 / K2; + 
static_assert((K0 * K1 * K2 == ColsPerBlock) && K1 * K2 * sizeof(T) == WarpAlignmentBytes, + "ColsPerBlock notdivisible"); + + constexpr index_t N2 = get_warp_size() / K1; + constexpr index_t N1 = kWarps / K0; + constexpr index_t N0 = RowsPerBlock / N1 / N2; + static_assert((N0 * N1 * N2 == RowsPerBlock) && (K0 * N1 == kWarps) && + (K1 * N2 == get_warp_size()), + "RowsPerBlock not divisible"); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, // K0 N1, N2 K1 + tuple, sequence<2, 1>>, + sequence<1, 2>, // N0 K2 + sequence<0, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() + { + return MakeXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution() + { + return MakeXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + { + return MakeXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution() + { + return MakeXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution() + { + return BlockFmhaBwdPipelineDefaultPolicy::MakeLSEDDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTileDistribution() + { + return BlockFmhaBwdPipelineDefaultPolicy::MakeBiasTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution() + { + constexpr index_t K1 = 16 / sizeof(DataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = 1; + constexpr index_t M1 = get_warp_size(); + constexpr index_t M0 = MPerBlock / M1; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1>>, + tuple, sequence<1>>, + sequence<1, 2, 2>, + sequence<2, 0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static 
constexpr auto MakePreODramTileDistribution() + { + using ODataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kKPerBlock = Problem::kVHeaddim; + + return MakePreXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreOGradDramTileDistribution() + { + using OGradDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kKPerBlock = Problem::kVHeaddim; + + return MakePreXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDramTileDistribution() + { + using AccDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kKPerBlock = Problem::kQKHeaddim; + + constexpr index_t K1 = 16 / sizeof(AccDataType); + constexpr index_t K0 = kKPerBlock / K1; + + constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M1 * M2); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence, sequence>, + tuple, sequence<2, 3>>, + tuple, sequence<2, 0>>, + sequence<1, 2, 3>, + sequence<0, 0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradDramTileDistribution() + { + using AccDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kKPerBlock = Problem::kQKHeaddim; + + constexpr index_t K1 = 16 / sizeof(AccDataType); + constexpr index_t K0 = kKPerBlock / K1; + + constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M1 * M2); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 
1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor() + { + return BlockFmhaBwdPipelineDefaultPolicy::MakeKRegBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor() + { + return BlockFmhaBwdPipelineDefaultPolicy::MakeVRegBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKTRegBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto kt_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, // 2 4, 4 + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + auto output = + make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(kt_block_dstr_encode), + typename Problem::KDataType>::TransposedDstrEncode{}); + return output; + } + + // lds write descriptor used together with block_sync_lds (transformed dram descriptor) + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsWriteBlockDescriptor() + { + constexpr index_t KPack = WarpAlignmentBytes / sizeof(T); + + constexpr auto desc_0 = make_naive_tensor_descriptor_packed( + make_tuple(number{}, number{}, number{})); + return transform_tensor_descriptor( + desc_0, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod( + 
make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor() + { + return MakeXLdsWriteBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor() + { + return MakeXLdsWriteBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsWriteBlockDescriptor() + { + return MakeXLdsWriteBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsWriteBlockDescriptor() + { + return MakeXLdsWriteBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsWriteBlockDescriptor() + { + return MakeXLdsWriteBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor() + { + // SGrad should be of the same distr as Gemm2 OGradV's output (i.e. PGrad) + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t M2 = WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane; + constexpr index_t M1 = WarpGemm::WarpGemmAttribute::Impl::kCMLane; + static_assert(WarpGemm::WarpGemmAttribute::Impl::kCM0PerLane == 1, "kCM0PerLane must be 1"); + constexpr index_t M0 = kMPerBlock / (M1 * M2); + + constexpr index_t N1 = WarpGemm::WarpGemmAttribute::Impl::kCNLane; + constexpr index_t N0 = kNPerBlock / N1; + + constexpr auto desc_0 = make_naive_tensor_descriptor_packed( + make_tuple(number{}, number{}, number{}, number{}, number{})); + + constexpr index_t M1_0 = 2, M1_1 = 2; + constexpr index_t N1_0 = 2, N1_1 = 8; + static_assert(M1_0 * M1_1 == M1, "M1_0 * M1_1 must equal M1"); + static_assert(N1_0 * N1_1 == N1, "N1_0 * N1_1 must equal N1"); + + constexpr auto desc_1 = transform_tensor_descriptor( + 
desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4, 5>{}, sequence<6>{})); + constexpr auto desc_2 = transform_tensor_descriptor( + desc_1, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2, 4>{}, + sequence<3>{}, + sequence<5>{}, + sequence<6>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2, 4>{}, + sequence<3>{}, + sequence<5>{}, + sequence<6>{})); + + constexpr auto top_dims = []() { + if constexpr(Transposed) + return make_tuple(sequence<1>{}, sequence<0>{}); + else + return make_tuple(sequence<0>{}, sequence<1>{}); + }(); + return transform_tensor_descriptor( + desc_2, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 2, 3, 6>{}, sequence<1, 4, 5>{}), + top_dims); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsReadBlockDescriptor() + { + const auto Dwordx4Bytes = 16; + const auto K2 = Dwordx4Bytes / sizeof(T); + const auto K1 = WarpAlignmentBytes / Dwordx4Bytes; + const auto K0 = KPerBlock / (K1 * K2); + + constexpr auto desc_0 = make_naive_tensor_descriptor_packed( + make_tuple(number{}, number{}, number{}, number{})); + constexpr auto desc_1 = transform_tensor_descriptor( + desc_0, + 
make_tuple(make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + return transform_tensor_descriptor( + desc_1, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsReadBlockDescriptor() + { + return MakeXLdsReadBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsReadBlockDescriptor() + { + return MakeXLdsReadBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsReadBlockDescriptor() + { + return MakeXLdsReadBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsReadBlockDescriptor() + { + return MakeXLdsReadBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsReadBlockDescriptor() + { + return MakeXLdsReadBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto q_block_outer_dstr_encoding = + tile_distribution_encoding, + 
tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); + + return q_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQTRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto qt_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + qt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(qt_block_dstr_encode), + typename Problem::QDataType>::TransposedDstrEncode{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSGradTRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3; + + constexpr index_t 
MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto dst_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + dst_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto dst_block_dstr = make_static_tile_distribution(dst_block_dstr_encode); + + return dst_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsWriteBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + using LSEDType = remove_cvref_t; + constexpr index_t kMPack = 16 / sizeof(LSEDType); + + constexpr auto lsed_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(number{}), + make_tuple(number<1>{}), + number{}, + number<1>{}); + + return lsed_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsReadBlockDescriptor() + { + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + + constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane; + constexpr index_t N0 = NWarp; + + // M4 *2 and M2 /2 when swizzle mode enabled + constexpr index_t SwizzleConfig = WG::kM == 16 ? 
1 : 2; + // constexpr index_t SwizzleConfig = 1; + constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * SwizzleConfig; + constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane; + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / SwizzleConfig; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple>, + tuple, sequence<1, 0>>, + tuple, sequence<3, 1>>, + sequence<1, 1, 1>, + sequence<0, 2, 4>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto do_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + do_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto do_block_dstr = make_static_tile_distribution(do_block_dstr_encode); + + return do_block_dstr; + } + + template + CK_TILE_DEVICE static constexpr auto MakeOGradTRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t 
NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + // constexpr index_t kNPerBlock = 32; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto dot_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + dot_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + // CK_PRINT(); + // CK_PRINT(); + + return make_static_tile_distribution( + typename InputTileDistributionTraits< + decltype(dot_block_dstr_encode), + typename Problem::OGradDataType>::TransposedDstrEncode{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakePTRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto pt_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + pt_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto pt_block_dstr = make_static_tile_distribution(pt_block_dstr_encode); + + return pt_block_dstr; + } + + template + 
CK_TILE_HOST_DEVICE static constexpr auto MakeSGradRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK4; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto ds_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + ds_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return make_static_tile_distribution( + typename InputTileDistributionTraits< + decltype(ds_block_dstr_encode), + typename Problem::GemmDataType>::TransposedDstrEncode{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t N1 = GetAlignmentBias(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t M2 = GetTransposedAlignmentBias(); + constexpr index_t M1 = get_warp_size() / N0; + constexpr index_t M0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<2, 1>, + sequence<1, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasSTileDistribution() + { + using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile()); + return c_block_tensor_type::get_tile_distribution(); + } + + template + 
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeQ() + { + return sizeof(typename Problem::QDataType) * + MakeQLdsWriteBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeK() + { + return sizeof(typename Problem::KDataType) * + MakeKLdsWriteBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeLSE() + { + return sizeof(typename Problem::LSEDataType) * + MakeLSEDLdsWriteBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD() + { + return sizeof(typename Problem::DDataType) * + MakeLSEDLdsWriteBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeV() + { + return sizeof(typename Problem::VDataType) * + MakeVLdsWriteBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeOGrad() + { + return sizeof(typename Problem::OGradDataType) * + MakeOGradLdsWriteBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeSGrad() + { + return sizeof(typename Problem::GemmDataType) * + MakeSGradLdsBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeBias() + { + if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return sizeof(typename Problem::BiasDataType) * + MakeBiasLdsWriteBlockDescriptor().get_element_space_size(); + else + return 0; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_q = GetSmemSizeQ(); + constexpr index_t smem_size_lse = GetSmemSizeLSE(); + constexpr index_t smem_size_k = GetSmemSizeK(); + constexpr index_t smem_size_v = GetSmemSizeV(); + constexpr index_t smem_size_do = GetSmemSizeOGrad(); + constexpr index_t smem_size_d 
= GetSmemSizeD(); + constexpr index_t smem_size_ds = GetSmemSizeSGrad(); + constexpr index_t smem_size_bias = GetSmemSizeBias(); + + constexpr index_t smem_size_stage0 = smem_size_k + smem_size_v; + constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 + smem_size_lse + + smem_size_d + max(smem_size_bias, smem_size_ds); + return max(smem_size_stage0, smem_size_stage1); + } + + template + class HotLoopScheduler + { + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0; + static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0; + static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim; + static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0; + static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2; + static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4; + + static constexpr index_t WarpGemmM = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + static constexpr index_t WarpGemmN = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}); + static constexpr index_t WarpGemmK = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}); + static constexpr index_t Gemm4MWarp = + Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); + static constexpr index_t Gemm4NWarp = + Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{}); + + static constexpr index_t blockWarps = kBlockSize / get_warp_size(); + using GemmDataType = typename Problem::GemmDataType; + + // Compute + static constexpr index_t Gemm0MFMA = + kM0 * kN0 * kK0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm1MFMA = + kN0 * kVHeaddim * kM0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm2MFMA = + kM0 * kN0 * kK2 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm3MFMA = + kN0 * kQKHeaddim * 
kM0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm4MFMA = + kM0 * kQKHeaddim * kN0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK); + + // VMEM + static constexpr index_t Q_VMEM_READ = + kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ(); + static constexpr index_t OGrad_VMEM_READ = + kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad(); + static constexpr index_t LSE_VMEM_READ = 1; + static constexpr index_t D_VMEM_READ = 1; + + // LDS Read + static constexpr index_t OGradT_LDS_READ = + kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad(); + static constexpr index_t QT_LDS_READ = + kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ(); + static constexpr index_t SGradT_LDS_READ_P1 = + kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetTransposedAlignmentX(); + static constexpr index_t SGradT_LDS_READ_P2 = + kM0 * kN0 / (get_warp_size() * Gemm4MWarp) / GetTransposedAlignmentX() - + SGradT_LDS_READ_P1; + static constexpr index_t Q_LDS_READ = + kM0 * kK0 / get_warp_size() / GetAlignmentQ(); + static constexpr index_t LSE_LDS_READ = kM0 / (4 * 4); + static constexpr index_t D_LDS_READ = LSE_LDS_READ; + static constexpr index_t OGrad_LDS_READ = + kM0 * kK2 / kBlockSize / GetAlignmentOGrad(); + + // LDS Write + static constexpr index_t Q_LDS_WRITE = + kM0 * kQKHeaddim / Problem::kBlockSize / GetAlignmentQ(); + static constexpr index_t QT_LDS_WRITE = + kM0 * kQKHeaddim / kBlockSize / GetTransposedAlignmentQ(); + static constexpr index_t OGrad_LDS_WRITE = + kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad(); + static constexpr index_t OGradT_LDS_WRITE = + kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad(); + static constexpr index_t LSE_LDS_WRITE = 1; + static constexpr index_t D_LDS_WRITE = 1; + static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize; + + public: + CK_TILE_DEVICE static constexpr void SchedulerGemm0() + { + // Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load + // Comp: Q x K + 
constexpr index_t VMEM_READ_INST = + Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ; + constexpr index_t MFMA_INST = Gemm0MFMA; + constexpr index_t LDS_READ_INST = OGradT_LDS_READ; + + constexpr index_t lcm_inst = lcm(VMEM_READ_INST, MFMA_INST, LDS_READ_INST); + static_for<0, lcm_inst, 1>{}([&](auto i) { + if constexpr(i % (lcm_inst / VMEM_READ_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + if constexpr(i % (lcm_inst / MFMA_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(i % (lcm_inst / LDS_READ_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + } + + CK_TILE_DEVICE static constexpr void SchedulerGemm12() + { + // Mem: Q^T LDS load + // Comp: PT x OGrad + constexpr index_t LDS_READ_INST = QT_LDS_READ; + constexpr index_t MFMA_INST = Gemm1MFMA + Gemm2MFMA; + + constexpr index_t lcm_inst = lcm(MFMA_INST, LDS_READ_INST); + static_for<0, lcm_inst, 1>{}([&](auto i) { + if constexpr(i % (lcm_inst / MFMA_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(i % (lcm_inst / LDS_READ_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + } + + CK_TILE_DEVICE static constexpr void SchedulerGemm3() + { + // Mem: LSE/D LDS store, SGradT LDS store, SGrad, Q, LSE LDS load. 
+ // Comp: SGradT x QT + constexpr index_t LDS_WRITE_INST = LSE_LDS_WRITE + D_LDS_WRITE + SGradT_LDS_WRITE; + constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ; + constexpr index_t MFMA_INST = Gemm3MFMA; + + constexpr index_t lds_rw_inst = LDS_WRITE_INST + LDS_READ_INST; + constexpr index_t lcm_inst = lcm(MFMA_INST, lds_rw_inst); + + static_for<0, lcm_inst, 1>{}([&](auto i) { + if constexpr(i % (lcm_inst / MFMA_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(i % (lcm_inst / lds_rw_inst) == 0) + { + if constexpr(i / (lcm_inst / lds_rw_inst) < LDS_WRITE_INST) + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + else + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS Read + } + }); + } + + CK_TILE_DEVICE static constexpr void SchedulerGemm4() + { + // Mem: SGrad, OGrad, D LDS load. + // Comp: SGrad x KT + constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ; + constexpr index_t MFMA_INST = Gemm4MFMA; + + constexpr index_t lcm_inst = lcm(MFMA_INST, LDS_READ_INST); + static_for<0, lcm_inst, 1>{}([&](auto i) { + if constexpr(i % (lcm_inst / MFMA_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(i % (lcm_inst / LDS_READ_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + } + }; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index 28d8b3eead..4652e5f20f 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -11,7 +11,9 @@ namespace ck_tile { // A is block distributed tensor // B is block distributed tensor // C is block distributed tensor -template +template struct BlockGemmARegBRegCRegV1 { private: @@ -44,8 +46,9 @@ struct BlockGemmARegBRegCRegV1 }; public: - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + static constexpr bool TransposeC = TransposeC_; using Traits = GemmTraits_; @@ -131,6 +134,7 @@ struct BlockGemmARegBRegCRegV1 CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() { + using c_distr_ys_major = std::conditional_t, sequence<1, 2>>; if constexpr(UseDefaultScheduler) { constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< @@ -138,7 +142,7 @@ struct BlockGemmARegBRegCRegV1 tuple, sequence>, tuple<>, tuple<>, - sequence<1, 2>, + c_distr_ys_major, sequence<0, 0>>{}; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); @@ -152,7 +156,7 @@ struct BlockGemmARegBRegCRegV1 tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + c_distr_ys_major, sequence<0, 0>>{}; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); @@ -172,25 +176,19 @@ struct BlockGemmARegBRegCRegV1 std::is_same_v>, "wrong!"); - constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode(); - - constexpr auto b_block_dstr_encode = MakeBBlockDistributionEncode(); - - constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode(); - // check ABC-block-distribution static_assert( - std::is_same_v, + std::is_same_v, remove_cvref_t>, "A distribution is wrong!"); static_assert( - std::is_same_v, + std::is_same_v, remove_cvref_t>, "B distribution is wrong!"); static_assert( - std::is_same_v, + std::is_same_v, remove_cvref_t>, "C distribution is wrong!"); @@ 
-219,7 +217,6 @@ struct BlockGemmARegBRegCRegV1 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A Block window AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); @@ -227,16 +224,16 @@ struct BlockGemmARegBRegCRegV1 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B block tensor BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); // read C warp tensor from C block tensor + using c_iter_idx = std:: + conditional_t, sequence>; CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); // warp GEMM @@ -244,7 +241,7 @@ struct BlockGemmARegBRegCRegV1 // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); }); @@ -254,6 +251,7 @@ struct BlockGemmARegBRegCRegV1 CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { + using c_distr_ys_major = std::conditional_t, sequence<1, 2>>; if constexpr(UseDefaultScheduler) { constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< @@ -261,7 +259,7 @@ struct BlockGemmARegBRegCRegV1 tuple, sequence>, tuple<>, tuple<>, - sequence<1, 2>, + c_distr_ys_major, sequence<0, 0>>{}; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -277,7 +275,7 @@ struct BlockGemmARegBRegCRegV1 tuple, sequence>, 
tuple>, tuple>, - sequence<1, 2>, + c_distr_ys_major, sequence<0, 0>>{}; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(