Add transposed lds descriptor

This commit is contained in:
PoYen, Chen
2024-06-12 03:46:41 +00:00
parent ba0bc1507c
commit a3fad6aae5

View File

@@ -80,8 +80,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
}
}
// 3d + padding
// [kMaxSplits x kM0]
// 3d + padding, [kMaxSplits, kM0]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor()
{
@@ -107,6 +106,32 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return lse_acc_lds_block_desc;
}
// 3d + padding, [kM0, kMaxSplits]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccTLdsBlockDescriptor()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = 16 / sizeof(LSEDataType);
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor(
lse_acc_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return lse_acc_t_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccTDramTileDistribution()
{