diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index e9ae11fb5f..79fe6492a6 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -139,6 +139,7 @@ LAYOUT_MAP = {"row": "true", "col": "false"} PIPELINE_MAP = { "qr": "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_hpad": "ck_tile::BlockFmhaPipelineQRKSVSHpad", "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync", "qs": "ck_tile::BlockFmhaPipelineQSKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", @@ -147,6 +148,7 @@ PIPELINE_MAP = { PIPELINE_ENUM_MAP = { "qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_hpad": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_HPAD", "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", "qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS", diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 42e2d1f487..c64a19104e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -60,6 +60,22 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT #include "fmha_fwd.hpp" """ +FMHA_FWD_KERNEL_HEADER_QR_HPAD = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#if defined(__HIP_DEVICE_COMPILE__) && \ + (defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ + defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \ + defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) || \ + defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)) +#if !defined(CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK) +#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 +#endif +#endif +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd.hpp" +""" + FMHA_FWD_KERNEL_BODY_TEMPLATE = """ #include @@ -300,7 +316,7 @@ class FmhaFwdApiTrait: return "true" # always support else: return "true" - elif self.pipeline_tag in ["qr", "qs"]: + elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]: if self.spad == "t": return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) else: @@ -323,7 +339,7 @@ class FmhaFwdApiTrait: return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)" else: return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" - elif self.pipeline_tag in ["qr", "qs"]: + elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]: if self.skpad == "t": return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) else: @@ -344,6 +360,11 @@ class FmhaFwdApiTrait: return f"a.hdim_q % {vec} == 0" else: assert False + elif self.pipeline_tag == "qr_hpad": + if self.dpad == "t": + return "a.hdim_q % 8 == 0" + else: + assert False elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == "t": @@ -361,6 +382,11 @@ class FmhaFwdApiTrait: return f"a.hdim_v % {vec} == 0" else: assert False + elif self.pipeline_tag == "qr_hpad": + if self.dvpad == "t": + return "a.hdim_v % 8 == 0" + else: + assert False elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == "t": @@ -634,6 +660,7 @@ class FmhaFwdKernel: F_pipeline: FmhaFwdPipeline _KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER + _KERNEL_HEADER_QR_HPAD: ClassVar[str] = FMHA_FWD_KERNEL_HEADER_QR_HPAD _KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE @classmethod @@ -643,6 +670,12 @@ class FmhaFwdKernel: else: return "ck_tile::FmhaFwdKernel" + @classmethod + def _get_kernel_header(cls, pipeline_tag): + if pipeline_tag == "qr_hpad": + return cls._KERNEL_HEADER_QR_HPAD + return cls._KERNEL_HEADER + @classmethod def _get_cpp_kargs_creator_func_name(cls, pipeline_tag): if pipeline_tag == "qr_async_trload_v3": @@ -651,7 +684,9 @@ class FmhaFwdKernel: return "fmha_fwd_create_kargs_and_grids" def render(self) -> str: - return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format( + return type(self)._get_kernel_header(self.F_pipeline.tag) + type( + self + )._KERNEL_BODY_TEMPLATE.format( F_kname=self.name, F_arch=self.F_arch, F_hdim=self.F_hdim, @@ -1144,6 +1179,37 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): def supported_dtypes(cls) -> Tuple[str]: return cls._DT_FP16_BF16 + @classmethod + def get_rules(cls) -> List[CompatibilityRule]: + rules = super().get_rules() + + # For gfx11 fp16/bf16 d128, use dpad=dvpad=t for the 64x32 tile: + # the exact-hdim variant (dpad=dvpad=f) is much slower here. + def check_d128_tile_pipeline( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + if problem_ctx.dtype not in cls._DT_FP16_BF16: + return True + + if (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128): + return True + + is_64x32_tile = kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 32 + pads_hdim = ( + kernel_ctx.pipeline.F_dpad == "t" and kernel_ctx.pipeline.F_dvpad == "t" + ) + exact_hdim = ( + kernel_ctx.pipeline.F_dpad == "f" and kernel_ctx.pipeline.F_dvpad == "f" + ) + + if is_64x32_tile: + return pads_hdim + + return exact_hdim + + rules.append(check_d128_tile_pipeline) + return rules + @classmethod def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: if dtype in cls._DT_FP16_BF16: @@ -1152,7 +1218,8 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")), FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")), + (128, 128) : [FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, 6, CppConstraint("a.hdim_q != 128 || a.hdim_v != 128")), + FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")), FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)], (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)] @@ -1179,7 +1246,9 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): # Keep only ttff/tttt for gfx11: ffff path is often similar or worse # pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + if receipt == 1: + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip return pipelines @@ -1251,7 +1320,9 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory): ): # pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + if receipt == 1: + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index 659bdd995b..a1a98867c6 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -13,6 +13,7 @@ enum class BlockFmhaPipelineEnum QSKSVS, QRKSVS_ASYNC_TRLOAD, QRKSVS_ASYNC_TRLOAD_V3, + QRKSVS_HPAD, }; template @@ -40,4 +41,10 @@ struct BlockFmhaPipelineEnumToStr static constexpr const char* name = "qr_async_trload"; }; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr_hpad"; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index b207c62181..48c79177d4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -14,7 +14,9 @@ namespace ck_tile { // This pipeline is qkv all located in LDS -template +template struct BlockFmhaPipelineQRKSVS { using Problem = remove_cvref_t; @@ -54,17 +56,18 @@ struct BlockFmhaPipelineQRKSVS static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kHasDropout = Problem::kHasDropout; - static constexpr auto QScaleEnum = Problem::QScaleEnum; - static constexpr bool kHasSink = Problem::kHasSink; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr auto QScaleEnum = Problem::QScaleEnum; + static constexpr bool kHasSink = Problem::kHasSink; + static constexpr bool kPaddedVecLoadStore = PaddedVecLoadStore_; static constexpr ck_tile::index_t kQKScaleGranularity = Problem::kQKScaleGranularity; static constexpr ck_tile::index_t kVScaleGranularity = Problem::kVScaleGranularity; @@ -80,23 +83,29 @@ struct BlockFmhaPipelineQRKSVS (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || !kHasLogitsSoftCap)) || (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); + static_assert(!kPaddedVecLoadStore || (kPadHeadDimQ && kPadHeadDimV), + "padded vector load/store fast path only applies to padded head-dim kernels"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this - static constexpr index_t kAlignmentQ = kPadHeadDimQ ? numeric_traits::PackedSize - : Policy::template GetAlignmentQ(); - static constexpr index_t kAlignmentK = kPadHeadDimQ ? numeric_traits::PackedSize - : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentQ = (kPadHeadDimQ && !kPaddedVecLoadStore) + ? numeric_traits::PackedSize + : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = (kPadHeadDimQ && !kPaddedVecLoadStore) + ? numeric_traits::PackedSize + : Policy::template GetAlignmentK(); static constexpr index_t kAlignmentV = []() { if constexpr(std::is_same_v) - return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + return (kPadHeadDimV && !kPaddedVecLoadStore) + ? 1 + : Policy::template GetAlignmentV(); else return kPadSeqLenK ? numeric_traits::PackedSize : Policy::template GetAlignmentV(); }(); static constexpr index_t kAlignmentO = - kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + (kPadHeadDimV && !kPaddedVecLoadStore) ? 1 : Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); static constexpr index_t kAlignmentRandVal = @@ -548,8 +557,25 @@ struct BlockFmhaPipelineQRKSVS }); } - const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile - { // tail + auto v_prefetch = decltype(load_tile(v_dram_window)){}; + enum class VPrefetchPoint + { + BeforeGemm0Tail, + AfterGemm0Tail, + AfterSoftmax + }; + +#if defined(__gfx11__) || defined(__gfx12__) + constexpr auto kVPrefetch = + kPadHeadDimV ? VPrefetchPoint::AfterSoftmax : VPrefetchPoint::AfterGemm0Tail; +#else + constexpr auto kVPrefetch = VPrefetchPoint::BeforeGemm0Tail; +#endif + if constexpr(kVPrefetch == VPrefetchPoint::BeforeGemm0Tail) + { + load_tile(v_prefetch, v_dram_window); // prefetch load v tile + } + { // tail block_sync_lds(); run_gemm_0(number{}); block_sync_lds(); @@ -562,6 +588,10 @@ struct BlockFmhaPipelineQRKSVS run_gemm_0(number{}); } + if constexpr(kVPrefetch == VPrefetchPoint::AfterGemm0Tail) + { + load_tile(v_prefetch, v_dram_window); + } // dequant auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) @@ -819,6 +849,11 @@ struct BlockFmhaPipelineQRKSVS randval_ptr, seq_offset, p_compute, randval_dram_window); } + if constexpr(kVPrefetch == VPrefetchPoint::AfterSoftmax) + { + load_tile(v_prefetch, v_dram_window); + } + block_sync_lds(); if constexpr(std::is_same_v) { @@ -1098,4 +1133,7 @@ struct BlockFmhaPipelineQRKSVS } }; +template +using BlockFmhaPipelineQRKSVSHpad = BlockFmhaPipelineQRKSVS; + } // namespace ck_tile