mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[rocm-libraries] ROCm/rocm-libraries#6156 (commit 367565a)
[CK_TILE] Optimize FMHA head-dim padded path on gfx11/gfx12 (#6156) ## Motivation On gfx11/gfx12, FMHA forward kernels that require head-dim padding show a large performance drop compared to the exact-head-dim path. In practice, padded cases such as `HDIM=72` and `HDIM=80` were falling too far off the fast path. This PR improves padded-head-dim FMHA performance on gfx11/gfx12 while keeping the behavior for other GPUs unchanged. ## Technical Details - Add/scope a dedicated padded-head-dim (`qr_hpad`) FMHA forward path for gfx11/gfx12. - For `receipt=0`, keep support conservative and only enable the padded fast path for vector-safe cases (`head_dim % 8 == 0`), matching the existing assumption used on other GPUs. - Move `v_prefetch` later only for the head-dim-padded path on gfx11/gfx12. This reduces live ranges and removes the register-spill behavior seen in the earlier scheduling. - Enable the buffer-load OOB check offset trick for the padded path on gfx11/gfx12. ## Test Plan ./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16 -d={72/80} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1} ## Test Result Observed padded-head-dim performance improvements for HDIM=72/80: - gfx11: about ~3.5x - gfx1151: about ~2.0x - gfx12: about ~1.3x ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
7d6c8e5afa
commit
4c0e73ab12
@@ -139,6 +139,7 @@ LAYOUT_MAP = {"row": "true", "col": "false"}
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"qr": "ck_tile::BlockFmhaPipelineQRKSVS",
|
||||
"qr_hpad": "ck_tile::BlockFmhaPipelineQRKSVSHpad",
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
|
||||
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
|
||||
@@ -147,6 +148,7 @@ PIPELINE_MAP = {
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_hpad": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_HPAD",
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
"qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
|
||||
|
||||
@@ -60,6 +60,22 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER_QR_HPAD = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#if defined(__HIP_DEVICE_COMPILE__) && \
|
||||
(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
|
||||
defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
|
||||
defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) || \
|
||||
defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__))
|
||||
#if !defined(CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK)
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
#endif
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY_TEMPLATE = """
|
||||
#include <iostream>
|
||||
|
||||
@@ -300,7 +316,7 @@ class FmhaFwdApiTrait:
|
||||
return "true" # always support
|
||||
else:
|
||||
return "true"
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]:
|
||||
if self.spad == "t":
|
||||
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
@@ -323,7 +339,7 @@ class FmhaFwdApiTrait:
|
||||
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
|
||||
else:
|
||||
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]:
|
||||
if self.skpad == "t":
|
||||
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
@@ -344,6 +360,11 @@ class FmhaFwdApiTrait:
|
||||
return f"a.hdim_q % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag == "qr_hpad":
|
||||
if self.dpad == "t":
|
||||
return "a.hdim_q % 8 == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == "t":
|
||||
@@ -361,6 +382,11 @@ class FmhaFwdApiTrait:
|
||||
return f"a.hdim_v % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag == "qr_hpad":
|
||||
if self.dvpad == "t":
|
||||
return "a.hdim_v % 8 == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == "t":
|
||||
@@ -634,6 +660,7 @@ class FmhaFwdKernel:
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
|
||||
_KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER
|
||||
_KERNEL_HEADER_QR_HPAD: ClassVar[str] = FMHA_FWD_KERNEL_HEADER_QR_HPAD
|
||||
_KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE
|
||||
|
||||
@classmethod
|
||||
@@ -643,6 +670,12 @@ class FmhaFwdKernel:
|
||||
else:
|
||||
return "ck_tile::FmhaFwdKernel"
|
||||
|
||||
@classmethod
|
||||
def _get_kernel_header(cls, pipeline_tag):
|
||||
if pipeline_tag == "qr_hpad":
|
||||
return cls._KERNEL_HEADER_QR_HPAD
|
||||
return cls._KERNEL_HEADER
|
||||
|
||||
@classmethod
|
||||
def _get_cpp_kargs_creator_func_name(cls, pipeline_tag):
|
||||
if pipeline_tag == "qr_async_trload_v3":
|
||||
@@ -651,7 +684,9 @@ class FmhaFwdKernel:
|
||||
return "fmha_fwd_create_kargs_and_grids"
|
||||
|
||||
def render(self) -> str:
|
||||
return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format(
|
||||
return type(self)._get_kernel_header(self.F_pipeline.tag) + type(
|
||||
self
|
||||
)._KERNEL_BODY_TEMPLATE.format(
|
||||
F_kname=self.name,
|
||||
F_arch=self.F_arch,
|
||||
F_hdim=self.F_hdim,
|
||||
@@ -1144,6 +1179,37 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
|
||||
def supported_dtypes(cls) -> Tuple[str]:
|
||||
return cls._DT_FP16_BF16
|
||||
|
||||
@classmethod
|
||||
def get_rules(cls) -> List[CompatibilityRule]:
|
||||
rules = super().get_rules()
|
||||
|
||||
# For gfx11 fp16/bf16 d128, use dpad=dvpad=t for the 64x32 tile:
|
||||
# the exact-hdim variant (dpad=dvpad=f) is much slower here.
|
||||
def check_d128_tile_pipeline(
|
||||
problem_ctx: ProblemContext, kernel_ctx: KernelContext
|
||||
) -> bool:
|
||||
if problem_ctx.dtype not in cls._DT_FP16_BF16:
|
||||
return True
|
||||
|
||||
if (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128):
|
||||
return True
|
||||
|
||||
is_64x32_tile = kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 32
|
||||
pads_hdim = (
|
||||
kernel_ctx.pipeline.F_dpad == "t" and kernel_ctx.pipeline.F_dvpad == "t"
|
||||
)
|
||||
exact_hdim = (
|
||||
kernel_ctx.pipeline.F_dpad == "f" and kernel_ctx.pipeline.F_dvpad == "f"
|
||||
)
|
||||
|
||||
if is_64x32_tile:
|
||||
return pads_hdim
|
||||
|
||||
return exact_hdim
|
||||
|
||||
rules.append(check_d128_tile_pipeline)
|
||||
return rules
|
||||
|
||||
@classmethod
|
||||
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
|
||||
if dtype in cls._DT_FP16_BF16:
|
||||
@@ -1152,7 +1218,8 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
|
||||
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
|
||||
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, 6, CppConstraint("a.hdim_q != 128 || a.hdim_v != 128")),
|
||||
FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
|
||||
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)]
|
||||
@@ -1179,7 +1246,9 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
|
||||
# Keep only ttff/tttt for gfx11: ffff path is often similar or worse
|
||||
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
if receipt == 1:
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
return pipelines
|
||||
|
||||
|
||||
@@ -1251,7 +1320,9 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
|
||||
):
|
||||
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
if receipt == 1:
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
|
||||
Reference in New Issue
Block a user