mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4584 (commit 42efd1d)
[CK_TILE][FMHA] Support gfx11 ## Motivation Add support for gfx11 architectures (RDNA3) to FMHA. ## Technical Details Distributions (matrix elements to lane registers mapping) of gfx11 WMMA are completely different from distributions of gfx9 MFMA and gfx12 WMMA. There are two cases in FMHA where this difference matters: * usage of results (matrix C) of one GEMM as input (matrix A) of another GEMM. * random number generation for dropout (implementation for gfx9 MFMA, gfx12 WMMA and host validation produce the same results). Both cases are solved by a special remapping implemented using `__builtin_amdgcn_permlanex16` and `__builtin_amdgcn_perm`. Additional changes: * FMHA tests are now built and run only for those types for which instances exist (gfx11 supports only fp16 and bf16). * Two fixes for uninitialized values (`mask.sink` and `do_fp8_static_quant`): they may contain garbage resulting in incorrect dispatching logic; sometimes tests report that there are no instances available for the current parameters. * Small fix to remove expcnt(0) from s_waitcnt instruction on gfx11 when it is not requested (i.e. every time); likely has no effect on performance but makes disassembly a bit clearer. ## Test Plan ``` ninja test_ck_tile_fmha bin/test_ck_tile_fmha_fwd_fp16 bin/test_ck_tile_fmha_fwd_bf16 bin/test_ck_tile_fmha_bwd_fp16 bin/test_ck_tile_fmha_bwd_bf16 ``` ## Test Result All tests must pass (some tests may be skipped). ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
1915cdfcc2
commit
0d92fffedb
@@ -33,6 +33,42 @@ namespace ck_tile {
|
||||
namespace detail {
|
||||
// The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values
// (one Philox 4x32 result is 4 x 32 bits = 16 bytes; 32 * 32 bytes / 16 bytes = 64).
|
||||
constexpr index_t philox_per_tile = 64;
|
||||
|
||||
// C distribution of gfx11 WMMA differs from C distribution of gfx9 MFMA and gfx12 WMMA.
|
||||
// This function deinterleaves the generated random values to make them compatible with other
|
||||
// architectures and verification code on host.
|
||||
template <index_t N>
|
||||
CK_TILE_DEVICE void PermuteBlockDropoutRandval(uint8_t (&random_uint8_t)[N])
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
static_for<0, N, 8>{}([&](auto i_offset) {
|
||||
array<uint8_t, 8> rs;
|
||||
static_for<0, 8, 1>{}([&](auto i) { rs.data[i] = random_uint8_t[i_offset + i]; });
|
||||
|
||||
const uint32_t r0 = rs.template get_as<uint32_t>(number<0>{});
|
||||
const uint32_t r1 = rs.template get_as<uint32_t>(number<1>{});
|
||||
|
||||
// Deinterleave values (even and odd indices)
|
||||
const uint32_t v0 = __builtin_amdgcn_perm(r1, r0, 0x06'04'02'00);
|
||||
const uint32_t v1 = __builtin_amdgcn_perm(r1, r0, 0x07'05'03'01);
|
||||
|
||||
// Swap rows (lane <-> lane ^ 16)
|
||||
const uint32_t w0 =
|
||||
__builtin_amdgcn_permlanex16(0, v0, 0x76543210, 0xfedcba98, false, true);
|
||||
const uint32_t w1 =
|
||||
__builtin_amdgcn_permlanex16(0, v1, 0x76543210, 0xfedcba98, false, true);
|
||||
|
||||
rs.template set_as<uint32_t>(number<0>{}, get_lane_id() < 16 ? v0 : w1);
|
||||
rs.template set_as<uint32_t>(number<1>{}, get_lane_id() < 16 ? w0 : v1);
|
||||
|
||||
static_for<0, 8, 1>{}([&](auto i) { random_uint8_t[i_offset + i] = rs.data[i]; });
|
||||
});
|
||||
#else
|
||||
static_assert(false, "PermuteBlockDropoutRandval is only for gfx11");
|
||||
ignore = random_uint8_t;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
struct NullBlockDropout
|
||||
@@ -295,6 +331,9 @@ struct BlockDropout
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
|
||||
ph.get_random_16x8(random_uint8_t, ph_subsequence);
|
||||
}
|
||||
#if defined(__gfx11__)
|
||||
detail::PermuteBlockDropoutRandval(random_uint8_t);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -566,6 +605,9 @@ struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
|
||||
ph.get_random_16x8(random_uint8_t, ph_subsequence);
|
||||
}
|
||||
#if defined(__gfx11__)
|
||||
detail::PermuteBlockDropoutRandval(random_uint8_t);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -1181,7 +1181,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
scale_rp_undrop,
|
||||
dropout);
|
||||
|
||||
#if defined(__gfx12__)
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
// Workaround for a compiler bug (SWDEV-559729): v_wmma instructions can be incorrectly
|
||||
// placed in divergent branches used to store padded tensors (when some lanes are
|
||||
// inactive due to padding). Inline asm with dummy dependencies on VGPRs of the tensors
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
|
||||
@@ -1692,8 +1693,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
auto pt_warp_tensor =
|
||||
auto p_warp_tensor =
|
||||
make_static_distributed_tensor<typename Problem::GemmDataType>(CWarpDstr{});
|
||||
auto pt_warp_tensor =
|
||||
make_static_distributed_tensor<typename Problem::GemmDataType>(AWarpDstr{});
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
@@ -1705,10 +1708,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
pt_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data(
|
||||
p_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
#if defined(__gfx11__)
|
||||
PermuteWarpGemmCToA(pt_warp_tensor, p_warp_tensor);
|
||||
#else
|
||||
pt_warp_tensor.get_thread_buffer() = p_warp_tensor.get_thread_buffer();
|
||||
#endif
|
||||
pt_out.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
|
||||
@@ -1742,8 +1750,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
auto dst_warp_tensor =
|
||||
auto ds_warp_tensor =
|
||||
make_static_distributed_tensor<typename Problem::GemmDataType>(CWarpDstr{});
|
||||
auto dst_warp_tensor =
|
||||
make_static_distributed_tensor<typename Problem::GemmDataType>(AWarpDstr{});
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
@@ -1755,10 +1765,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
dst_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
|
||||
ds_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
#if defined(__gfx11__)
|
||||
PermuteWarpGemmCToA(dst_warp_tensor, ds_warp_tensor);
|
||||
#else
|
||||
dst_warp_tensor.get_thread_buffer() = ds_warp_tensor.get_thread_buffer();
|
||||
#endif
|
||||
dst_out.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -675,8 +676,15 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
auto p = make_static_distributed_tensor<PDataType>(
|
||||
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
|
||||
PermuteWarpGemmCToA(
|
||||
p, cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)));
|
||||
#else
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#endif
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
if constexpr(k1_loops > 1)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -704,8 +705,15 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
auto p = make_static_distributed_tensor<PDataType>(
|
||||
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
|
||||
PermuteWarpGemmCToA(
|
||||
p, cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)));
|
||||
#else
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#endif
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
if constexpr(k1_loops > 1)
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -717,8 +718,15 @@ struct BlockFmhaPipelineQRKSVS
|
||||
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
#if defined(__gfx11__)
|
||||
auto p = make_static_distributed_tensor<PDataType>(
|
||||
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
|
||||
PermuteWarpGemmCToA(
|
||||
p, cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)));
|
||||
#else
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#endif
|
||||
|
||||
float v_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
|
||||
@@ -77,6 +77,7 @@
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
|
||||
45
include/ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp
Normal file
45
include/ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// C distribution of gfx11 WMMA is not compatible with A distribution:
|
||||
// C: 2 lanes per row (lane and lane + 16), 8 values per lane are interleaved.
|
||||
// A: 1 lane per row, 16 values, lane and lane + 16 have the same values.
|
||||
// This function transforms one ditribution to another for GEMM-GEMM scenarios.
|
||||
template <typename OutTensor, typename InTensor>
|
||||
CK_TILE_DEVICE static constexpr void PermuteWarpGemmCToA(OutTensor& out, const InTensor& in)
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
static_assert(sizeof(typename OutTensor::DataType) == 2);
|
||||
static_assert(std::is_same_v<typename OutTensor::DataType, typename InTensor::DataType>);
|
||||
|
||||
constexpr index_t n_out = OutTensor::get_thread_buffer_size();
|
||||
static_assert(n_out == InTensor::get_thread_buffer_size() * 2);
|
||||
|
||||
// Perm byte selectors are swapped for the second row (16 lanes) because it needs to be done
|
||||
// once instead to swapping w and v everytime
|
||||
const uint32_t byte_selector0 = get_lane_id() < 16 ? 0x05'04'01'00 : 0x01'00'05'04;
|
||||
const uint32_t byte_selector1 = get_lane_id() < 16 ? 0x07'06'03'02 : 0x03'02'07'06;
|
||||
static_for<0, n_out, 1>{}([&](auto i) {
|
||||
const auto v = in.get_thread_buffer().template get_as<uint32_t>(i);
|
||||
// Swap rows (lane <-> lane ^ 16)
|
||||
const auto w = __builtin_amdgcn_permlanex16(0, v, 0x76543210, 0xfedcba98, false, true);
|
||||
// Interleave values of lane and lane ^ 16
|
||||
out.get_thread_buffer().template set_as<uint32_t>(
|
||||
number<i * 2 + 0>{}, __builtin_amdgcn_perm(w, v, byte_selector0));
|
||||
out.get_thread_buffer().template set_as<uint32_t>(
|
||||
number<i * 2 + 1>{}, __builtin_amdgcn_perm(w, v, byte_selector1));
|
||||
});
|
||||
#else
|
||||
static_assert(false, "PermuteWarpGemmCToA is only for gfx11");
|
||||
ignore = out;
|
||||
ignore = in;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user