mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Optimize fmha fwd decode & prefill for gfx950 (#2641)
* Fix for fwd/bwd kernel build filter * fix bwd code * save an example for __bf16 type * temp save, waiting for debug * tempsave, fmha_decode * temp save, change all instance to 1wave * fix async copytest bug * Add block_sync_lds_direct_load utility * fix the s_waitcnt_imm calculation * Improve s_waitcnt_imm calculation * fix vmcnt shift * add input validation and bug fix * remove unnecessary output * move test_copy into test * temp save * tempsave * compile pass * tempsave, trload+asyncload done * tempsave. asynccopy+trload sanity checked * remove unnecessary features * fix the lds alignment caused performance regression * enable prefill overload operator(). * remove all lds bankconflict with xor layouts * enable larger tile size; upgrade xor pattern * upgrade prefill pipeline; simple iglp; consistent data produce and consume order * small refactor * Load Q through lds, implement xor; * add vmcnt guard before load ktile * Add v_permlaneb32 for block_reduce. Disable it as it will cause un-coexecutable packed math in FA * Add XOR fold strategy for hdim<128, but perf dropped; disable it by default; wait further perf debug * add __restrict__ to tr load * merge fa_decode pipeline into fmha_fwd api * remove unnecessary files; rename some files * Remove unnecessary changes * bug fix, clang format; * remove non-necessary change * fix clangformat with 18.1.3 * fix bugs * fix bug * fix bug on non-gfx950 * fix bugs in gemm * fix bug in pki4 * tempsave, update the blocksync functions * change the warp setting for hdim32 fmha fwd * clang format * fix conflict. disable all v-col instance for fmha fwd * Fix the bug * clang format --------- Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
This commit is contained in:
@@ -42,6 +42,8 @@ struct BlockGemmARegBRegCRegV1
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr auto BlockGemmLoopOrder = Problem::BlockGemmLoopOrder;
|
||||
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
};
|
||||
|
||||
@@ -52,8 +54,9 @@ struct BlockGemmARegBRegCRegV1
|
||||
|
||||
using Traits = GemmTraits_<Problem, Policy>;
|
||||
|
||||
using WarpGemm = typename Traits::WarpGemm;
|
||||
using BlockGemmShape = typename Traits::BlockGemmShape;
|
||||
using WarpGemm = typename Traits::WarpGemm;
|
||||
using BlockGemmShape = typename Traits::BlockGemmShape;
|
||||
static constexpr auto BlockGemmLoopOrder = Traits::BlockGemmLoopOrder;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||
@@ -86,17 +89,36 @@ struct BlockGemmARegBRegCRegV1
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
return a_block_dstr_encode;
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,17 +140,33 @@ struct BlockGemmARegBRegCRegV1
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,40 +251,82 @@ struct BlockGemmARegBRegCRegV1
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||
{
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
using c_iter_idx = std::
|
||||
conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
// read C warp tensor from C block tensor
|
||||
using c_iter_idx = std::conditional_t<TransposeC,
|
||||
sequence<nIter, mIter>,
|
||||
sequence<mIter, nIter>>;
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||
{
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -13,7 +14,8 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_,
|
||||
index_t NumWaveGroups_ = 1>
|
||||
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN,
|
||||
index_t NumWaveGroups_ = 1>
|
||||
struct BlockGemmProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
@@ -21,8 +23,9 @@ struct BlockGemmProblem
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user