mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
Update in K-Lds laying-out to consider for both WarpGemm-32x32x16 and WarpGemm-16x16x16
This commit is contained in:
@@ -51,11 +51,23 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
Problem::BlockFmhaShape::kQKHeaddim>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetQKWarpGemmKPerThreadSize()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
return WG::WarpGemmAttribute::kKPerThread;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
|
||||
{
|
||||
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
return 8 / sizeof(QKVDataType);
|
||||
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
return 8;
|
||||
else
|
||||
return 4;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -66,9 +78,18 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
static_assert(kKVector % kKPack == 0);
|
||||
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
|
||||
return kKPerBlock * kNPerBlock + kKPerBlock;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKVector % kKPack == 0);
|
||||
|
||||
return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
@@ -105,40 +126,76 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
static_assert(kKVector % kKPack == 0);
|
||||
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
constexpr index_t KSingleSmemElementSpaceSize =
|
||||
kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
|
||||
constexpr index_t KSingleSmemElementSpaceSize = kKPerBlock * kNPerBlock + kKPerBlock;
|
||||
|
||||
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
|
||||
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto k_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
|
||||
number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kNPerBlock>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<kNPerBlock * kKVector + kKPack>{},
|
||||
number<kNPerBlock * kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
constexpr auto k_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
|
||||
number<kKPerBlock / kKPack>{},
|
||||
number<kNPerBlock>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<kNPerBlock * kKPack + kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
|
||||
make_merge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return k_lds_block_desc;
|
||||
return k_lds_block_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKVector % kKPack == 0);
|
||||
|
||||
constexpr index_t KSingleSmemElementSpaceSize =
|
||||
kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
|
||||
|
||||
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto k_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
|
||||
number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kNPerBlock>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<kNPerBlock * kKVector + kKPack>{},
|
||||
number<kNPerBlock * kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return k_lds_block_desc;
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user