From a3fad6aae5c6d6a8dc5b8e318522aff8d8b3fbe9 Mon Sep 17 00:00:00 2001
From: "PoYen, Chen"
Date: Wed, 12 Jun 2024 03:46:41 +0000
Subject: [PATCH] Add transposed lds descriptor

---
 ...plitkv_combine_pipeline_default_policy.hpp | 29 +++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
index 89e4fa29ef..b57c694678 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
@@ -80,8 +80,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
         }
     }
 
-    // 3d + padding
-    // [kMaxSplits x kM0]
+    // 3d + padding, [kMaxSplits, kM0]
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor()
     {
@@ -107,6 +106,32 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
         return lse_acc_lds_block_desc;
     }
 
+    // 3d + padding, [kM0, kMaxSplits]
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccTLdsBlockDescriptor()
+    {
+        using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
+
+        constexpr index_t kMPerBlock = Problem::kMaxSplits;
+        constexpr index_t kNPerBlock = Problem::kM0;
+        constexpr index_t NPack = 16 / sizeof(LSEDataType);
+
+        constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
+            make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
+            make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
+            number<8>{},
+            number<1>{});
+
+        constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor(
+            lse_acc_lds_block_desc_0,
+            make_tuple(make_pass_through_transform(kMPerBlock),
+                       make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
+            make_tuple(sequence<1>{}, sequence<0, 2>{}),
+            make_tuple(sequence<1>{}, sequence<0>{}));
+
+        return lse_acc_t_lds_block_desc;
+    }
+
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccTDramTileDistribution()
     {