mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK_TILE] Improve headdim96 performance for fmha-bwd (#1573)
* Add kQKHeaddimForGemmN and kVHeaddimForGemmN in order to support headdim 96
* Remove the using of MakeKRegBlockDescriptor and MakeVRegBlockDescriptor
* Fix in bwd_piple_default_policy
* Remove kQKHeaddim and rename kQKHeaddimForGemmN to kQKHeaddim in the bwd kernel and pipelines
* Replace kVHeaddimForGemmN by kVHeaddim and kDoDvHeaddim
* Update to hd96 tile settings
* Add smoke test scripts for fmha-bwd hd96
* Revert "Add smoke test scripts for fmha-bwd hd96"
This reverts commit 7ca7e1a93d.
* Remove hd96 tile settings in fmha_bwd codegen to save compiling
* Fix lost code line in bwd_pipeline_default_policy
* Merge kDoDvHeaddim/kPadHeadDimDoDv to kVHeaddim/kPadHeadDimV and remove TileFmhaBwdTraits
* Rename KRegSliceBlockDescriptor/VRegSliceBlockDescriptor to KRegBlockDescriptor/VRegBlockDescriptor
* tiny adjustments
---------
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: danyao12 <Dan.Yao@amd.com>
This commit is contained in:
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
k_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
|
||||
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK2>{}),
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// KT, Reg ->LDS ->Reg
|
||||
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_k_lds_write_window = make_tile_window(
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
auto v_reg_tensor = load_tile(v_lds_read_window);
|
||||
block_sync_lds();
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM ->Reg ->LDS
|
||||
@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds_window.get_bottom_tensor_view(),
|
||||
@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_q_lds_write_window = make_tile_window(
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto do_lds_read_window =
|
||||
make_tile_window(do_lds_window.get_bottom_tensor_view(),
|
||||
@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_do_lds_write_window = make_tile_window(
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t seqlen_q_step = seqlen_q_start;
|
||||
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
|
||||
static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
|
||||
static_assert(kM0 == kK1, "kM0 should equal to kK1");
|
||||
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
|
||||
static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
|
||||
static_assert(kM0 == kK3, "kM0 should equal to kK3");
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
|
||||
|
||||
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
k_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
|
||||
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK2>{}),
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// KT, Reg ->LDS ->Reg
|
||||
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_k_lds_write_window = make_tile_window(
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
auto v_reg_tensor = load_tile(v_lds_read_window);
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM ->Reg ->LDS
|
||||
auto q_dram_window =
|
||||
@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds_window.get_bottom_tensor_view(),
|
||||
@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_q_lds_write_window = make_tile_window(
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto do_lds_read_window =
|
||||
make_tile_window(do_lds_window.get_bottom_tensor_view(),
|
||||
@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_do_lds_write_window = make_tile_window(
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t seqlen_q_step = seqlen_q_start;
|
||||
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
|
||||
static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
|
||||
static_assert(kM0 == kK1, "kM0 should equal to kK1");
|
||||
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
|
||||
static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
|
||||
static_assert(kM0 == kK3, "kM0 should equal to kK3");
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
|
||||
|
||||
@@ -215,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
|
||||
|
||||
@@ -234,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
|
||||
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
@@ -254,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
|
||||
|
||||
@@ -315,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
@@ -327,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
return total_pixels / GetAlignmentK<Problem>();
|
||||
@@ -338,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
@@ -376,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentK<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -399,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -422,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -445,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -816,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto k_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
|
||||
|
||||
return k_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor()
|
||||
{
|
||||
@@ -865,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
@@ -890,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t kVPack = GetSmemKPackV<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto v_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
|
||||
|
||||
return v_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor()
|
||||
{
|
||||
@@ -940,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
@@ -966,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentK<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -1048,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
|
||||
@@ -1092,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -1255,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
// Hold full block data
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
||||
|
||||
@@ -1299,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -1859,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
|
||||
static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
|
||||
static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
|
||||
|
||||
static constexpr index_t WarpGemmM =
|
||||
@@ -1873,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
|
||||
// Compute
|
||||
static constexpr index_t Gemm0MFMA =
|
||||
kM0 * kN0 * kQKHeaddim /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
kM0 * kN0 * kK0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm1MFMA =
|
||||
kM0 * kN0 * kVHeaddim /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm2MFMA =
|
||||
kN0 * kVHeaddim * kM0 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm2MFMA =
|
||||
kM0 * kN0 * kK2 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm3MFMA =
|
||||
kN0 * kQKHeaddim * kM0 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
@@ -1903,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
|
||||
static constexpr index_t SGradT_LDS_READ_P1 =
|
||||
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
|
||||
static constexpr index_t Q_LDS_READ =
|
||||
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
|
||||
static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ<Problem>();
|
||||
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
|
||||
static constexpr index_t SGradT_LDS_READ_P2 =
|
||||
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
|
||||
static constexpr index_t OGrad_LDS_READ =
|
||||
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
|
||||
|
||||
// LDS Write
|
||||
|
||||
Reference in New Issue
Block a user