mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-10 16:28:38 +00:00
[rocm-libraries] ROCm/rocm-libraries#7500 (commit f5cd4fd)
[CK_TILE][FMHA] Optimize long-context decoding on gfx11/12 (#7500) ## Motivation Relevant issue: ROCM-22065 FMHA has less-than-optimal performance of long-context decoding (i.e. when seqlen_q = 1) on gfx11/12. This PR optimizes the splitkv pipeline and configs for such scenarios. ## Technical Details Optimizations applied in this PR: 1. use tiles with smaller M0 (16 vs 64), these tiles are used when seqlen_q <= 16 2. adapt qr_nwarp_sshuffle pipeline for gfx11, it allows to use more warps even for M0 = 16 (the qr pipeline parallelizes work between warps in M dim so with M0 = 16 it allows to use only 1 warp) 3. enable kMergeNumHeadGroupsSeqLenQ (an optimization that merges one group of heads in GQA) for all hdim values, not only 128 4. increase the number of splits (multiply by the number of head groups) if (3) is used 5. increase the number of splits for RDNAs (`multiProcessorCount` is the number of WGPs on RDNAs, not CUs, so it should be doubled to have meaning similar to CDNAs) Performance on gfx1151: | Case | develop (GB/s) | This PR (GB/s) | |:-------|-------:|-------:| | [fp16\|group\|bshd] b:1, h:32/32, s:1/45056, d:64/64 | 127.58 | 183.11 | | [fp16\|group\|bhsd] b:1, h:32/32, s:1/45056, d:64/64 | 153.64 | 215.02 | | [fp16\|group\|bshd] b:1, h:16/8, s:1/77184, d:128/128 | 120.51 | 225.76 | | [fp16\|group\|bhsd] b:1, h:16/8, s:1/77184, d:128/128 | 130.62 | 223.84 | | [fp16\|group\|bshd] b:1, h:32/32, s:1/9600, d:128/128 | 82.65 | 138.44 | | [fp16\|group\|bhsd] b:1, h:32/32, s:1/9600, d:128/128 | 105.75 | 220.45 | | [fp16\|group\|bshd] b:1, h:8/1, s:1/401024, d:256/256 | 16.27 | 187.89 | | [fp16\|group\|bhsd] b:1, h:8/1, s:1/401024, d:256/256 | 16.28 | 188.19 | ## Test Plan An additional test case is added to the exiting test. It uses seqlen_q = 1, GQA, no mask to trigger the changes ``` ninja test_ck_tile_fmha_fwd_fp16 && bin/test_ck_tile_fmha_fwd_fp16 --gtest_filter="*SplitKV* ninja test_ck_tile_fmha_fwd_bf16 && bin/test_ck_tile_fmha_fwd_bf16 --gtest_filter="*SplitKV* ``` Manual testing can be done with these commands: ``` bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=32 -h_k=32 -d=64 -s=1 -s_k=$((352 * 128)) -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1 bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=16 -h_k=8 -d=128 -s=1 -s_k=$((603 * 128)) -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1 bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=32 -h_k=32 -d=128 -s=1 -s_k=$((75 * 128)) -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1 bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=8 -h_k=1 -d=256 -s=1 -s_k=$((3133 * 128)) -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1 ``` ## Test Result All the tests must pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
01bd52bdb5
commit
7ecbf82708
@@ -128,7 +128,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
|
||||
namespace {{
|
||||
template <bool kHasUnevenSplits>
|
||||
void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{
|
||||
if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
|
||||
if constexpr ({F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
|
||||
&& (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|
||||
|| std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
|
||||
if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
|
||||
@@ -283,7 +283,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
// get combine kernel tile sizes
|
||||
@@ -364,6 +364,14 @@ class FmhaFwdSplitKVApiTrait:
|
||||
else:
|
||||
assert False
|
||||
|
||||
def seqtune(self, max_bm0: int) -> str:
|
||||
if self.bm0 == max_bm0:
|
||||
return "true/*fall back to largest tile*/"
|
||||
else:
|
||||
if self.mode == "group":
|
||||
return f"a.max_seqlen_q <= {self.bm0}"
|
||||
return f"a.seqlen_q <= {self.bm0}"
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
@@ -561,6 +569,7 @@ class FmhaFwdSplitKVApiPool:
|
||||
for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
|
||||
per_hdim_case = str()
|
||||
for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()):
|
||||
max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0)
|
||||
inners = str()
|
||||
for i_trait, trait in enumerate(pool_by_hdim):
|
||||
inners += FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(
|
||||
@@ -579,6 +588,7 @@ class FmhaFwdSplitKVApiPool:
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_sink=BOOL_MAP[trait.sink],
|
||||
F_scheck=trait.scheck,
|
||||
F_seqtune=trait.seqtune(max_bm0),
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
@@ -763,6 +773,7 @@ class KernelComponentFactoryBase:
|
||||
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline("qr_nwarp_sshuffle", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
for logits, mask, bias in itertools.product(
|
||||
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||
@@ -846,11 +857,15 @@ class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
"32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
"32" : [FmhaFwdTileSize( 16, 64, 16, 32, 32, 32, 1, 2, 1, 1, 2, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
"64" : [FmhaFwdTileSize( 16, 64, 32, 64, 32, 64, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
"128": [FmhaFwdTileSize( 16, 64, 32, 128, 32, 128, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
"256": [FmhaFwdTileSize( 16, 64, 32, 256, 32, 256, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
} # fmt: skip
|
||||
else:
|
||||
return None
|
||||
@@ -863,11 +878,15 @@ class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
"32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
"32" : [FmhaFwdTileSize( 16, 64, 16, 32, 32, 32, 1, 2, 1, 1, 2, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
"64" : [FmhaFwdTileSize( 16, 64, 32, 64, 32, 64, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
"128": [FmhaFwdTileSize( 16, 128, 32, 128, 32, 128, 1, 8, 1, 1, 8, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
"256": [FmhaFwdTileSize( 16, 128, 32, 256, 32, 256, 1, 8, 1, 1, 8, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
} # fmt: skip
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
return {
|
||||
@@ -930,11 +949,17 @@ def get_fwd_splitkv_blobs(
|
||||
d = factory.get_hdim_tile_size_dict(dtype)
|
||||
if d is None:
|
||||
continue
|
||||
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
tiles = d[hdim_str]
|
||||
if not isinstance(tiles, list):
|
||||
tiles = [tiles]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, mask_impl)
|
||||
):
|
||||
# Use qr_nwarp_sshuffle with multiple N warps and qr otherwise
|
||||
if (tile.F_rn0 != 1) != (pipeline.tag == "qr_nwarp_sshuffle"):
|
||||
continue
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
|
||||
@@ -165,8 +165,10 @@ int override_num_splits_if_necessary(
|
||||
|
||||
if(num_splits < 1 && p_drop == 0.0f)
|
||||
{
|
||||
// props.multiProcessorCount for >=gfx10 is the number of WGPs (each has 2 CUs)
|
||||
const int num_blocks_per_SM = props.warpSize == 32 ? 4 : 2;
|
||||
return num_splits_heuristic(
|
||||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128);
|
||||
batch * nhead * num_m_blocks, props.multiProcessorCount * num_blocks_per_SM, 128);
|
||||
}
|
||||
|
||||
return num_splits;
|
||||
@@ -648,8 +650,18 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
// legalize num_splits according to other options
|
||||
if(num_splits < 1)
|
||||
{
|
||||
int nhead_merged = nhead;
|
||||
int max_seqlen_q_merged = max_seqlen_q;
|
||||
// When max_seqlen_q == 1 and multiple head groups are merged (kMergeNumHeadGroupsSeqLenQ)
|
||||
// then more splits are required
|
||||
if(bias.type == bias_enum::no_bias && mask.type == mask_enum::no_mask &&
|
||||
max_seqlen_q == 1 && nhead_k < nhead)
|
||||
{
|
||||
nhead_merged = nhead_k;
|
||||
max_seqlen_q_merged = max_seqlen_q * (nhead / nhead_k);
|
||||
}
|
||||
num_splits = override_num_splits_if_necessary(
|
||||
batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
|
||||
batch, nhead_merged, max_seqlen_q_merged, hdim_v, p_drop, num_splits);
|
||||
}
|
||||
if(128 < num_splits)
|
||||
{
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -257,7 +258,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
clear_tile(o_acc);
|
||||
if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
|
||||
{
|
||||
set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E});
|
||||
set_tile(m, SMPLComputeDataType{sink_v * static_cast<float>(C_LOG2E)});
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
@@ -698,8 +699,15 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
|
||||
#if defined(__gfx11__)
|
||||
auto p = make_static_distributed_tensor<PDataType>(
|
||||
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
|
||||
PermuteWarpGemmCToA(
|
||||
p, cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)));
|
||||
#else
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#endif
|
||||
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
@@ -5,8 +5,6 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -163,6 +161,25 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kTileK = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
#if defined(__gfx11__)
|
||||
// Keep C distribution and replicate data for NWarp to prevent doubling registers
|
||||
// PermuteWarpGemmCToA will convert C distribution to A for matrix P later
|
||||
constexpr index_t K1 = kKPerBlock / WG::kM;
|
||||
constexpr index_t K0 = kTileK / kKPerBlock;
|
||||
constexpr index_t M1 = MWarp;
|
||||
constexpr index_t M0 = kMPerBlock / WG::kN;
|
||||
|
||||
constexpr auto s2_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 1>>{};
|
||||
|
||||
constexpr auto s2_block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
|
||||
s2_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
#else
|
||||
// K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm
|
||||
constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane;
|
||||
@@ -179,7 +196,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
|
||||
tuple<sequence<1, 0>, sequence<2, 2>>,
|
||||
sequence<1, 2, 2, 2>,
|
||||
sequence<0, 0, 1, 3>>{};
|
||||
|
||||
#endif
|
||||
constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding);
|
||||
|
||||
return s2_block_dstr;
|
||||
|
||||
@@ -735,6 +735,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
|
||||
Values(3, 4),
|
||||
Values(std::tuple{4, 3, 1, 200, 1024, "0"},
|
||||
std::tuple{2, 2, -1, 512, 2000, "0"},
|
||||
std::tuple{2, 8, 2, 1, 1024, "0"},
|
||||
std::tuple{3, 2, -1, 230, 899, "t:128,128"})));
|
||||
|
||||
TEST_P(SplitKV, DataTypeConfig)
|
||||
|
||||
Reference in New Issue
Block a user