mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Optimize FMHA head-dim padded path on gfx11/gfx12 (#6156)
## Motivation
On gfx11/gfx12, FMHA forward kernels that require head-dim padding show
a large performance drop compared to the exact-head-dim path. In
practice, padded cases such as `HDIM=72` and `HDIM=80` were falling too
far off the fast path.
This PR improves padded-head-dim FMHA performance on gfx11/gfx12 while
keeping the behavior for other GPUs unchanged.
## Technical Details
- Add/scope a dedicated padded-head-dim (`qr_hpad`) FMHA forward path
for gfx11/gfx12.
- For `receipt=0`, keep support conservative and only enable the padded
fast path for vector-safe cases (`head_dim % 8 == 0`), matching the
existing assumption used on other GPUs.
- Move `v_prefetch` later only for the head-dim-padded path on
gfx11/gfx12. This reduces live ranges and removes the register-spill
behavior seen in the earlier scheduling.
- Enable the buffer-load OOB check offset trick for the padded path on
gfx11/gfx12.
## Test Plan
./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16
-d={72/80} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}
## Test Result
Observed padded-head-dim performance improvements for HDIM=72/80:
- gfx11: ~3.5x speedup
- gfx1151: ~2.0x speedup
- gfx12: ~1.3x speedup
## Submission Checklist
- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
@@ -139,6 +139,7 @@ LAYOUT_MAP = {"row": "true", "col": "false"}
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"qr": "ck_tile::BlockFmhaPipelineQRKSVS",
|
||||
"qr_hpad": "ck_tile::BlockFmhaPipelineQRKSVSHpad",
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
|
||||
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
|
||||
@@ -147,6 +148,7 @@ PIPELINE_MAP = {
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_hpad": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_HPAD",
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
"qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
|
||||
|
||||
@@ -60,6 +60,22 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER_QR_HPAD = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#if defined(__HIP_DEVICE_COMPILE__) && \
|
||||
(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
|
||||
defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
|
||||
defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) || \
|
||||
defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__))
|
||||
#if !defined(CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK)
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
#endif
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY_TEMPLATE = """
|
||||
#include <iostream>
|
||||
|
||||
@@ -300,7 +316,7 @@ class FmhaFwdApiTrait:
|
||||
return "true" # always support
|
||||
else:
|
||||
return "true"
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]:
|
||||
if self.spad == "t":
|
||||
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
@@ -323,7 +339,7 @@ class FmhaFwdApiTrait:
|
||||
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
|
||||
else:
|
||||
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]:
|
||||
if self.skpad == "t":
|
||||
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
@@ -344,6 +360,11 @@ class FmhaFwdApiTrait:
|
||||
return f"a.hdim_q % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag == "qr_hpad":
|
||||
if self.dpad == "t":
|
||||
return "a.hdim_q % 8 == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == "t":
|
||||
@@ -361,6 +382,11 @@ class FmhaFwdApiTrait:
|
||||
return f"a.hdim_v % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag == "qr_hpad":
|
||||
if self.dvpad == "t":
|
||||
return "a.hdim_v % 8 == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == "t":
|
||||
@@ -634,6 +660,7 @@ class FmhaFwdKernel:
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
|
||||
_KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER
|
||||
_KERNEL_HEADER_QR_HPAD: ClassVar[str] = FMHA_FWD_KERNEL_HEADER_QR_HPAD
|
||||
_KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE
|
||||
|
||||
@classmethod
|
||||
@@ -643,6 +670,12 @@ class FmhaFwdKernel:
|
||||
else:
|
||||
return "ck_tile::FmhaFwdKernel"
|
||||
|
||||
@classmethod
|
||||
def _get_kernel_header(cls, pipeline_tag):
|
||||
if pipeline_tag == "qr_hpad":
|
||||
return cls._KERNEL_HEADER_QR_HPAD
|
||||
return cls._KERNEL_HEADER
|
||||
|
||||
@classmethod
|
||||
def _get_cpp_kargs_creator_func_name(cls, pipeline_tag):
|
||||
if pipeline_tag == "qr_async_trload_v3":
|
||||
@@ -651,7 +684,9 @@ class FmhaFwdKernel:
|
||||
return "fmha_fwd_create_kargs_and_grids"
|
||||
|
||||
def render(self) -> str:
|
||||
return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format(
|
||||
return type(self)._get_kernel_header(self.F_pipeline.tag) + type(
|
||||
self
|
||||
)._KERNEL_BODY_TEMPLATE.format(
|
||||
F_kname=self.name,
|
||||
F_arch=self.F_arch,
|
||||
F_hdim=self.F_hdim,
|
||||
@@ -1144,6 +1179,37 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
|
||||
def supported_dtypes(cls) -> Tuple[str]:
|
||||
return cls._DT_FP16_BF16
|
||||
|
||||
@classmethod
|
||||
def get_rules(cls) -> List[CompatibilityRule]:
|
||||
rules = super().get_rules()
|
||||
|
||||
# For gfx11 fp16/bf16 d128, use dpad=dvpad=t for the 64x32 tile:
|
||||
# the exact-hdim variant (dpad=dvpad=f) is much slower here.
|
||||
def check_d128_tile_pipeline(
|
||||
problem_ctx: ProblemContext, kernel_ctx: KernelContext
|
||||
) -> bool:
|
||||
if problem_ctx.dtype not in cls._DT_FP16_BF16:
|
||||
return True
|
||||
|
||||
if (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128):
|
||||
return True
|
||||
|
||||
is_64x32_tile = kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 32
|
||||
pads_hdim = (
|
||||
kernel_ctx.pipeline.F_dpad == "t" and kernel_ctx.pipeline.F_dvpad == "t"
|
||||
)
|
||||
exact_hdim = (
|
||||
kernel_ctx.pipeline.F_dpad == "f" and kernel_ctx.pipeline.F_dvpad == "f"
|
||||
)
|
||||
|
||||
if is_64x32_tile:
|
||||
return pads_hdim
|
||||
|
||||
return exact_hdim
|
||||
|
||||
rules.append(check_d128_tile_pipeline)
|
||||
return rules
|
||||
|
||||
@classmethod
|
||||
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
|
||||
if dtype in cls._DT_FP16_BF16:
|
||||
@@ -1152,7 +1218,8 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
|
||||
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
|
||||
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, 6, CppConstraint("a.hdim_q != 128 || a.hdim_v != 128")),
|
||||
FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
|
||||
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)]
|
||||
@@ -1179,7 +1246,9 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
|
||||
# Keep only ttff/tttt for gfx11: ffff path is often similar or worse
|
||||
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
if receipt == 1:
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
return pipelines
|
||||
|
||||
|
||||
@@ -1251,7 +1320,9 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
|
||||
):
|
||||
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
if receipt == 1:
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
|
||||
@@ -13,6 +13,7 @@ enum class BlockFmhaPipelineEnum
|
||||
QSKSVS,
|
||||
QRKSVS_ASYNC_TRLOAD,
|
||||
QRKSVS_ASYNC_TRLOAD_V3,
|
||||
QRKSVS_HPAD,
|
||||
};
|
||||
|
||||
template <BlockFmhaPipelineEnum>
|
||||
@@ -40,4 +41,10 @@ struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD>
|
||||
static constexpr const char* name = "qr_async_trload";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_HPAD>
|
||||
{
|
||||
static constexpr const char* name = "qr_hpad";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is qkv all located in LDS
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy,
|
||||
bool PaddedVecLoadStore_ = false>
|
||||
struct BlockFmhaPipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
@@ -54,17 +56,18 @@ struct BlockFmhaPipelineQRKSVS
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
static constexpr bool kPaddedVecLoadStore = PaddedVecLoadStore_;
|
||||
|
||||
static constexpr ck_tile::index_t kQKScaleGranularity = Problem::kQKScaleGranularity;
|
||||
static constexpr ck_tile::index_t kVScaleGranularity = Problem::kVScaleGranularity;
|
||||
@@ -80,23 +83,29 @@ struct BlockFmhaPipelineQRKSVS
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
!kHasLogitsSoftCap)) ||
|
||||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
|
||||
static_assert(!kPaddedVecLoadStore || (kPadHeadDimQ && kPadHeadDimV),
|
||||
"padded vector load/store fast path only applies to padded head-dim kernels");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ = kPadHeadDimQ ? numeric_traits<QDataType>::PackedSize
|
||||
: Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK = kPadHeadDimQ ? numeric_traits<KDataType>::PackedSize
|
||||
: Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentQ = (kPadHeadDimQ && !kPaddedVecLoadStore)
|
||||
? numeric_traits<QDataType>::PackedSize
|
||||
: Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK = (kPadHeadDimQ && !kPaddedVecLoadStore)
|
||||
? numeric_traits<KDataType>::PackedSize
|
||||
: Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
return (kPadHeadDimV && !kPaddedVecLoadStore)
|
||||
? 1
|
||||
: Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? numeric_traits<VDataType>::PackedSize
|
||||
: Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
(kPadHeadDimV && !kPaddedVecLoadStore) ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
static constexpr index_t kAlignmentRandVal =
|
||||
@@ -548,8 +557,25 @@ struct BlockFmhaPipelineQRKSVS
|
||||
});
|
||||
}
|
||||
|
||||
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
|
||||
{ // tail
|
||||
auto v_prefetch = decltype(load_tile(v_dram_window)){};
|
||||
enum class VPrefetchPoint
|
||||
{
|
||||
BeforeGemm0Tail,
|
||||
AfterGemm0Tail,
|
||||
AfterSoftmax
|
||||
};
|
||||
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
constexpr auto kVPrefetch =
|
||||
kPadHeadDimV ? VPrefetchPoint::AfterSoftmax : VPrefetchPoint::AfterGemm0Tail;
|
||||
#else
|
||||
constexpr auto kVPrefetch = VPrefetchPoint::BeforeGemm0Tail;
|
||||
#endif
|
||||
if constexpr(kVPrefetch == VPrefetchPoint::BeforeGemm0Tail)
|
||||
{
|
||||
load_tile(v_prefetch, v_dram_window); // prefetch load v tile
|
||||
}
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
run_gemm_0(number<k0_loops - 2>{});
|
||||
block_sync_lds();
|
||||
@@ -562,6 +588,10 @@ struct BlockFmhaPipelineQRKSVS
|
||||
|
||||
run_gemm_0(number<k0_loops - 1>{});
|
||||
}
|
||||
if constexpr(kVPrefetch == VPrefetchPoint::AfterGemm0Tail)
|
||||
{
|
||||
load_tile(v_prefetch, v_dram_window);
|
||||
}
|
||||
// dequant
|
||||
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
@@ -819,6 +849,11 @@ struct BlockFmhaPipelineQRKSVS
|
||||
randval_ptr, seq_offset, p_compute, randval_dram_window);
|
||||
}
|
||||
|
||||
if constexpr(kVPrefetch == VPrefetchPoint::AfterSoftmax)
|
||||
{
|
||||
load_tile(v_prefetch, v_dram_window);
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -1098,4 +1133,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
|
||||
using BlockFmhaPipelineQRKSVSHpad = BlockFmhaPipelineQRKSVS<Problem_, Policy_, true>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user