[CK_TILE][FMHA] Support gfx11 (#4584)

## Motivation

Add support of gfx11 architectures (RDNA3) to FMHA.

## Technical Details

Distributions (matrix elements to lane registers mapping) of gfx11 WMMA
are completely different from distributions of gfx9 MFMA and gfx12 WMMA.
There are two cases in FMHA where this difference matters:
* usage of results (matrix C) of one GEMM as input (matrix A) of another
GEMM.
* random number generation for dropout (implementation for gfx9 MFMA,
gfx12 WMMA and host validation produce the same results).

Both cases are solved by a special remapping implemented using
`__builtin_amdgcn_permlanex16` and `__builtin_amdgcn_perm`.

Additional changes:
* FMHA tests are now built and run only for those types for which
instances exist (gfx11 supports only fp16 and bf16).
* Two fixes for uninitialized values (`mask.sink` and
`do_fp8_static_quant`): they may contain garbage resulting in incorrect
dispatching logic; sometimes tests report that there are no instances
available for the current parameters.
* Small fix to remove expcnt(0) from s_waitcnt instructions on gfx11 when
it is not requested (i.e. every time); likely has no effect on
performance but makes the disassembly a bit clearer.

## Test Plan

```
ninja test_ck_tile_fmha

bin/test_ck_tile_fmha_fwd_fp16
bin/test_ck_tile_fmha_fwd_bf16
bin/test_ck_tile_fmha_bwd_fp16
bin/test_ck_tile_fmha_bwd_bf16
```

## Test Result

All tests must pass (some tests may be skipped).

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Anton Gorenko
2026-02-21 06:15:10 +05:00
committed by GitHub
parent c98c68cd2d
commit ce6acc5f66
19 changed files with 296 additions and 21 deletions

View File

@@ -17,6 +17,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines.
* Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming.
* Added FP8 block scale quantization for FMHA forward kernel.
* Added gfx11 support for FMHA.
### Changed

View File

@@ -3,9 +3,9 @@
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# Currently only gfx9, gfx11 and gfx12 archs are supported by FMHA
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12")
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx1[12]")
if(NOT INST_TARGETS)
message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx11, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
return()
endif()

View File

@@ -457,6 +457,24 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
return results
class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
arch = ArchTrait("gfx11")
@staticmethod
def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
if tr_load == "t":
return []
if dtype in ["fp16", "bf16"]:
return [
# bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv,
FmhaBwdDQDKDVTileSize( 32, 64, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaBwdDQDKDVTileSize( 32, 64, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
] # fmt: skip
return []
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
arch = ArchTrait("gfx12")
@@ -483,6 +501,8 @@ def get_factory(target: str):
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx11"):
return KernelComponentFactoryGfx11
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12

View File

@@ -1094,6 +1094,50 @@ class KernelComponentFactoryGfx950(
return pipelines
class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
arch = ArchTrait("gfx11")
_DT_FP16_BF16 = ("fp16", "bf16")
@classmethod
def supported_dtypes(cls) -> Tuple[str]:
return cls._DT_FP16_BF16
@classmethod
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
if dtype in cls._DT_FP16_BF16:
return {
# bm0, bn0, bk0, bn1, bk1,
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
} # fmt: skip
else:
raise ValueError(f"unsupported dtype={dtype}")
@classmethod
def get_pipelines(
cls, dtype, hdim, hdim_v, receipt, mask_impl
) -> List[FmhaFwdPipeline]:
pipelines = []
if dtype in cls._DT_FP16_BF16:
qscale = "no"
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
["t", "f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
return pipelines
class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
arch = ArchTrait("gfx12")
@@ -1183,6 +1227,8 @@ def get_factory(target: str):
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx11"):
return KernelComponentFactoryGfx11
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12

View File

@@ -388,6 +388,22 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
arch = ArchTrait("gfx9")
class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
arch = ArchTrait("gfx11")
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return KernelComponentFactoryBase.get_hdim_tile_size_dict(dtype)
return None
@staticmethod
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
if dtype in ["fp16", "bf16"]:
return KernelComponentFactoryBase.get_pipelines(dtype, hdim)
return []
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
arch = ArchTrait("gfx12")
@@ -398,6 +414,8 @@ def get_factory(target: str):
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx11"):
return KernelComponentFactoryGfx11
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12

View File

@@ -836,6 +836,23 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
return None
class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
arch = ArchTrait("gfx11")
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return {
# bm0, bn0, bk0, bn1, bk1,
"32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
} # fmt: skip
else:
return None
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
arch = ArchTrait("gfx12")
@@ -865,6 +882,8 @@ def get_factory(target: str):
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx11"):
return KernelComponentFactoryGfx11
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12

View File

@@ -597,6 +597,24 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
return None
class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
arch = ArchTrait("gfx11")
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return {
# bm0, bn0, bk0, bn1, bk1,
# "32": FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# "64": FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# "192": FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
} # fmt: skip
else:
return None
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
arch = ArchTrait("gfx12")
@@ -628,6 +646,8 @@ def get_factory(target: str):
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx11"):
return KernelComponentFactoryGfx11
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12

View File

@@ -1635,8 +1635,8 @@ struct fmha_fwd_splitkv_traits
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool do_fp8_static_quant;
bool has_sink = false;
bool do_fp8_static_quant = false;
bool has_sink = false;
// TODO: padding check is inside this api
};
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,

View File

@@ -73,6 +73,7 @@ struct mask_info
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = left_size;
tmp.right = right_size;
tmp.sink = 0;
}
else if(t == "t" || t == "b" || t == "g")
{
@@ -147,7 +148,10 @@ struct mask_info
}
else if(str == "0")
{
tmp.type = mask_enum::no_mask;
tmp.type = mask_enum::no_mask;
tmp.left = -1;
tmp.right = -1;
tmp.sink = 0;
}
else if(str == "1" || str == "t")
{

View File

@@ -911,14 +911,15 @@ struct WaitcntLayoutGfx12
};
struct WaitcntLayoutGfx11
{ // vm[15:10] (6), lgkm[9:4] (6), exp unused
{ // vm[15:10] (6), lgkm[9:4] (6), exp [2:0] (3)
CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F;
CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F;
CK_TILE_DEVICE static constexpr bool HAS_EXP = false;
CK_TILE_DEVICE static constexpr index_t EXP_MASK = 0x07;
CK_TILE_DEVICE static constexpr bool HAS_EXP = true;
CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 10); }
CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 4); }
CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; }
CK_TILE_DEVICE static constexpr index_t pack_exp(index_t c) { return (c & EXP_MASK); }
};
struct WaitcntLayoutLegacy
@@ -952,10 +953,14 @@ using Waitcnt = WaitcntLayoutLegacy;
struct waitcnt_arg
{
// kMax* exposed for callers; match field widths per-arch
#if defined(__gfx12__) || defined(__gfx11__)
#if defined(__gfx12__)
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x0; // none
#elif defined(__gfx11__)
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x07; // 3 bits
#else
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits (split)
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x0F; // 4 bits
@@ -981,8 +986,8 @@ struct waitcnt_arg
{
if constexpr(Waitcnt::HAS_EXP)
{
// EXP_MASK only exists on legacy
#if !defined(__gfx12__) && !defined(__gfx11__)
// EXP_MASK only exists on pre-gfx12
#if !defined(__gfx12__)
static_assert((cnt & ~Waitcnt::EXP_MASK) == 0, "expcnt out of range");
return Waitcnt::pack_exp(cnt);
#else

View File

@@ -33,6 +33,42 @@ namespace ck_tile {
namespace detail {
// The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values
constexpr index_t philox_per_tile = 64;
// C distribution of gfx11 WMMA differs from C distribution of gfx9 MFMA and gfx12 WMMA.
// This function deinterleaves the generated random values to make them compatible with other
// architectures and with the verification code on the host.
//
// N is the number of 8-bit random values held by this lane; values are processed in
// groups of 8 (i.e. two 32-bit words at a time), so N is assumed to be a multiple of 8.
template <index_t N>
CK_TILE_DEVICE void PermuteBlockDropoutRandval(uint8_t (&random_uint8_t)[N])
{
#if defined(__gfx11__)
static_for<0, N, 8>{}([&](auto i_offset) {
// Copy the next 8 bytes into a small array so they can be viewed as two 32-bit words.
array<uint8_t, 8> rs;
static_for<0, 8, 1>{}([&](auto i) { rs.data[i] = random_uint8_t[i_offset + i]; });
const uint32_t r0 = rs.template get_as<uint32_t>(number<0>{});
const uint32_t r1 = rs.template get_as<uint32_t>(number<1>{});
// Deinterleave values (even and odd byte indices) across the two words.
const uint32_t v0 = __builtin_amdgcn_perm(r1, r0, 0x06'04'02'00);
const uint32_t v1 = __builtin_amdgcn_perm(r1, r0, 0x07'05'03'01);
// Swap rows (lane <-> lane ^ 16); permlanex16 exchanges data between the two
// 16-lane halves of the wave.
const uint32_t w0 =
__builtin_amdgcn_permlanex16(0, v0, 0x76543210, 0xfedcba98, false, true);
const uint32_t w1 =
__builtin_amdgcn_permlanex16(0, v1, 0x76543210, 0xfedcba98, false, true);
// Lower half of the wave keeps v0/w0, upper half keeps w1/v1, so that each lane
// ends up with the values matching the gfx9/gfx12 layout.
rs.template set_as<uint32_t>(number<0>{}, get_lane_id() < 16 ? v0 : w1);
rs.template set_as<uint32_t>(number<1>{}, get_lane_id() < 16 ? w0 : v1);
static_for<0, 8, 1>{}([&](auto i) { random_uint8_t[i_offset + i] = rs.data[i]; });
});
#else
static_assert(false, "PermuteBlockDropoutRandval is only for gfx11");
ignore = random_uint8_t;
#endif
}
} // namespace detail
struct NullBlockDropout
@@ -295,6 +331,9 @@ struct BlockDropout
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
ph.get_random_16x8(random_uint8_t, ph_subsequence);
}
#if defined(__gfx11__)
detail::PermuteBlockDropoutRandval(random_uint8_t);
#endif
}
else
{
@@ -566,6 +605,9 @@ struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
ph.get_random_16x8(random_uint8_t, ph_subsequence);
}
#if defined(__gfx11__)
detail::PermuteBlockDropoutRandval(random_uint8_t);
#endif
}
else
{

View File

@@ -1181,7 +1181,7 @@ struct FmhaBwdDQDKDVKernel
scale_rp_undrop,
dropout);
#if defined(__gfx12__)
#if defined(__gfx11__) || defined(__gfx12__)
// Workaround for a compiler bug (SWDEV-559729): v_wmma instructions can be incorrectly
// placed in divergent branches used to store padded tensors (when some lanes are
// inactive due to padding). Inline asm with dummy dependencies on VGPRs of the tensors

View File

@@ -8,6 +8,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
@@ -1692,8 +1693,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using AWarpDstr = typename WarpGemm::AWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
auto pt_warp_tensor =
auto p_warp_tensor =
make_static_distributed_tensor<typename Problem::GemmDataType>(CWarpDstr{});
auto pt_warp_tensor =
make_static_distributed_tensor<typename Problem::GemmDataType>(AWarpDstr{});
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
@@ -1705,10 +1708,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
pt_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data(
p_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
#if defined(__gfx11__)
PermuteWarpGemmCToA(pt_warp_tensor, p_warp_tensor);
#else
pt_warp_tensor.get_thread_buffer() = p_warp_tensor.get_thread_buffer();
#endif
pt_out.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
@@ -1742,8 +1750,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using AWarpDstr = typename WarpGemm::AWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
auto dst_warp_tensor =
auto ds_warp_tensor =
make_static_distributed_tensor<typename Problem::GemmDataType>(CWarpDstr{});
auto dst_warp_tensor =
make_static_distributed_tensor<typename Problem::GemmDataType>(AWarpDstr{});
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
@@ -1755,10 +1765,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
dst_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
ds_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
#if defined(__gfx11__)
PermuteWarpGemmCToA(dst_warp_tensor, ds_warp_tensor);
#else
dst_warp_tensor.get_thread_buffer() = ds_warp_tensor.get_thread_buffer();
#endif
dst_out.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -675,8 +676,15 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v);
#if defined(__gfx11__)
auto p = make_static_distributed_tensor<PDataType>(
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
PermuteWarpGemmCToA(
p, cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)));
#else
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
#endif
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -704,8 +705,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v);
#if defined(__gfx11__)
auto p = make_static_distributed_tensor<PDataType>(
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
PermuteWarpGemmCToA(
p, cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)));
#else
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
#endif
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)

View File

@@ -7,6 +7,7 @@
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -717,8 +718,15 @@ struct BlockFmhaPipelineQRKSVS
move_tile_window(v_dram_window, {0, kK1});
#if defined(__gfx11__)
auto p = make_static_distributed_tensor<PDataType>(
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
PermuteWarpGemmCToA(
p, cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)));
#else
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
#endif
float v_descale = 1.0f;
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)

View File

@@ -77,6 +77,7 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"

View File

@@ -0,0 +1,45 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// C distribution of gfx11 WMMA is not compatible with A distribution:
// C: 2 lanes per row (lane and lane + 16), 8 values per lane are interleaved.
// A: 1 lane per row, 16 values, lane and lane + 16 have the same values.
// This function transforms one distribution to the other for GEMM-GEMM scenarios
// (using the C result of one GEMM as the A input of the next one).
template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE static constexpr void PermuteWarpGemmCToA(OutTensor& out, const InTensor& in)
{
#if defined(__gfx11__)
// Only 16-bit element types are supported: two elements are packed per 32-bit word
// processed by the perm/permlanex16 intrinsics below.
static_assert(sizeof(typename OutTensor::DataType) == 2);
static_assert(std::is_same_v<typename OutTensor::DataType, typename InTensor::DataType>);
constexpr index_t n_out = OutTensor::get_thread_buffer_size();
// A holds twice as many values per lane as C (see the distribution note above).
static_assert(n_out == InTensor::get_thread_buffer_size() * 2);
// Perm byte selectors are swapped for the second row (upper 16 lanes) because it needs
// to be done once here instead of swapping w and v every iteration below.
const uint32_t byte_selector0 = get_lane_id() < 16 ? 0x05'04'01'00 : 0x01'00'05'04;
const uint32_t byte_selector1 = get_lane_id() < 16 ? 0x07'06'03'02 : 0x03'02'07'06;
static_for<0, n_out, 1>{}([&](auto i) {
const auto v = in.get_thread_buffer().template get_as<uint32_t>(i);
// Swap rows (lane <-> lane ^ 16)
const auto w = __builtin_amdgcn_permlanex16(0, v, 0x76543210, 0xfedcba98, false, true);
// Interleave values of lane and lane ^ 16: each input word expands into two
// output words carrying the combined row.
out.get_thread_buffer().template set_as<uint32_t>(
number<i * 2 + 0>{}, __builtin_amdgcn_perm(w, v, byte_selector0));
out.get_thread_buffer().template set_as<uint32_t>(
number<i * 2 + 1>{}, __builtin_amdgcn_perm(w, v, byte_selector1));
});
#else
static_assert(false, "PermuteWarpGemmCToA is only for gfx11");
ignore = out;
ignore = in;
#endif
}
} // namespace ck_tile

View File

@@ -1,11 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Keep in sync with example/ck_tile/01_fmha/CMakeLists.txt
if(NOT SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx12")
return()
endif()
set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
@@ -18,9 +13,19 @@ function(add_gtest_fwd test_group)
set(CPP_TYPE_fp8bf16 "FmhaFwdFp8Bf16")
set(CPP_TYPE_fp32 "FmhaFwdFp32")
set(sources)
if(TARGET ${FMHA_FWD_INSTANCES})
get_target_property(sources ${FMHA_FWD_INSTANCES} SOURCES)
message(VERBOSE "${FMHA_FWD_INSTANCES} SOURCES ${sources}")
endif()
set(all_tests)
foreach(type ${V_TYPES})
set(name "${test_group}_${type}")
if(NOT sources MATCHES "_${type}_")
message(STATUS "No FMHA FWD instances for ${type}, skip ${name}")
continue()
endif()
add_gtest_executable(${name} test_fmha_fwd.cpp)
get_test_property(${name} LABELS COMMON_LABELS)
set_tests_properties(${name} PROPERTIES LABELS "${COMMON_LABELS};${TEST_NAME};${test_group}")
@@ -38,9 +43,19 @@ function(add_gtest_bwd test_group)
set(CPP_TYPE_bf16 "FmhaBwdBf16")
set(CPP_TYPE_fp32 "FmhaBwdFp32")
set(sources)
if(TARGET ${FMHA_BWD_INSTANCES})
get_target_property(sources ${FMHA_BWD_INSTANCES} SOURCES)
message(VERBOSE "${FMHA_BWD_INSTANCES} SOURCES ${sources}")
endif()
set(all_tests)
foreach(type ${V_TYPES})
set(name "${test_group}_${type}")
if(NOT sources MATCHES "_${type}_")
message(STATUS "No FMHA BWD instances for ${type}, skip ${name}")
continue()
endif()
add_gtest_executable(${name} test_fmha_bwd.cpp)
get_test_property(${name} LABELS COMMON_LABELS)
set_tests_properties(${name} PROPERTIES LABELS "${COMMON_LABELS};${TEST_NAME};${test_group}")