mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Add transposed lds descriptor
This commit is contained in:
@@ -80,8 +80,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
}
|
||||
}
|
||||
|
||||
// 3d + padding
|
||||
// [kMaxSplits x kM0]
|
||||
// 3d + padding, [kMaxSplits, kM0]
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor()
|
||||
{
|
||||
@@ -107,6 +106,32 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
return lse_acc_lds_block_desc;
|
||||
}
|
||||
|
||||
// 3d + padding, [kM0, kMaxSplits]
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccTLdsBlockDescriptor()
|
||||
{
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::kMaxSplits;
|
||||
constexpr index_t kNPerBlock = Problem::kM0;
|
||||
constexpr index_t NPack = 16 / sizeof(LSEDataType);
|
||||
|
||||
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor(
|
||||
lse_acc_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return lse_acc_t_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccTDramTileDistribution()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user