mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Use xor transform to implement Q/K Lds descriptor for kKpack == 8 cases
This commit is contained in:
@@ -178,7 +178,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
return kKPerBlock * kNPerBlock + kKPerBlock;
|
||||
return kKPerBlock * kNPerBlock;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -220,18 +220,47 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
using QDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t DataTypeSize = sizeof(QDataType);
|
||||
|
||||
// 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes
|
||||
constexpr auto MLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<kMPerBlock * kKPack + kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
make_tuple(number<kMPerBlock / MLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * MLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock * MLdsLayer>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
|
||||
constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
q_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
|
||||
make_merge_transform(
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<kMPerBlock / MLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto q_lds_block_desc_k0_mldslayer_m_k1 = transform_tensor_descriptor(
|
||||
q_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<MLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
|
||||
q_lds_block_desc_k0_mldslayer_m_k1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return q_lds_block_desc;
|
||||
@@ -310,31 +339,55 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
constexpr index_t KSingleSmemElementSpaceSize = kKPerBlock * kNPerBlock + kKPerBlock;
|
||||
using KDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
|
||||
constexpr index_t DataTypeSize = sizeof(KDataType);
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
// 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto k_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
|
||||
number<kKPerBlock / kKPack>{},
|
||||
number<kNPerBlock>{},
|
||||
number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<kNPerBlock * kKPack + kKPack>{},
|
||||
make_tuple(number<kKPerBlock * kNPerBlock>{},
|
||||
number<kKPerBlock * NLdsLayer>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
|
||||
constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
|
||||
make_merge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<NumKLdsBuffers>{}),
|
||||
make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto k_lds_block_desc_k0_nldslayer_n_k1 = transform_tensor_descriptor(
|
||||
k_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<NumKLdsBuffers>{}),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}));
|
||||
|
||||
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
|
||||
k_lds_block_desc_k0_nldslayer_n_k1,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(number<NumKLdsBuffers>{},
|
||||
number<kKPerBlock / kKPack>{},
|
||||
number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 3>{}, sequence<0, 2, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return k_lds_block_desc;
|
||||
|
||||
Reference in New Issue
Block a user