[CK_TILE] FMHA BWD Optimization For GFX950 (#2628)

* simplify fmha_bwd_kernel MakeKargs & dq_dram_window

* simply duplicate

* trload pipeline

* Try two-stage

* add prefetch

* optimize & iglp
This commit is contained in:
Yi DING
2025-08-12 11:11:55 +08:00
committed by GitHub
parent a7badc6ec5
commit 4fde1646e5
16 changed files with 2216 additions and 586 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -11,7 +11,9 @@ namespace ck_tile {
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
template <typename Problem_,
typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy,
bool TransposeC_ = false>
struct BlockGemmARegBRegCRegV1
{
private:
@@ -44,8 +46,9 @@ struct BlockGemmARegBRegCRegV1
};
public:
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
static constexpr bool TransposeC = TransposeC_;
using Traits = GemmTraits_<Problem, Policy>;
@@ -131,6 +134,7 @@ struct BlockGemmARegBRegCRegV1
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
{
using c_distr_ys_major = std::conditional_t<TransposeC, sequence<2, 1>, sequence<1, 2>>;
if constexpr(UseDefaultScheduler)
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
@@ -138,7 +142,7 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
c_distr_ys_major,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
@@ -152,7 +156,7 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
c_distr_ys_major,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
@@ -172,25 +176,19 @@ struct BlockGemmARegBRegCRegV1
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode();
constexpr auto b_block_dstr_encode = MakeBBlockDistributionEncode();
constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode();
// check ABC-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"A distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"B distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
@@ -219,7 +217,6 @@ struct BlockGemmARegBRegCRegV1
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
@@ -227,16 +224,16 @@ struct BlockGemmARegBRegCRegV1
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
using c_iter_idx = std::
conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
@@ -244,7 +241,7 @@ struct BlockGemmARegBRegCRegV1
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
@@ -254,6 +251,7 @@ struct BlockGemmARegBRegCRegV1
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
using c_distr_ys_major = std::conditional_t<TransposeC, sequence<2, 1>, sequence<1, 2>>;
if constexpr(UseDefaultScheduler)
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
@@ -261,7 +259,7 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
c_distr_ys_major,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -277,7 +275,7 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
c_distr_ys_major,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(