Support hdim48; bm0=32 is not yet correct and performance is not yet good

This commit is contained in:
aska-0096
2025-07-11 06:55:02 +00:00
parent 18669925cc
commit c69d450b22
9 changed files with 248 additions and 45 deletions

View File

@@ -3,7 +3,7 @@
# generate kernel instances to speed up compilation
FWD_DTYPE_MAP = {
"fp16" : "FmhaFwdFp16",
# "fp16" : "FmhaFwdFp16",
"bf16" : "FmhaFwdBf16",
# "fp8" : "FmhaFwdFp8",
# "fp8fp16": "FmhaFwdFp8Fp16",

View File

@@ -31,6 +31,7 @@ DTYPE_BITS = {
K0_MAX_SUBMAX_MAP = {
32 : 32,
48 : 48,
64 : 64,
96 : 128,
128: 128,
@@ -303,18 +304,18 @@ FMHA_FWD_DECODE_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_
// make sure we can reuse the padding flags in combine kernels
static_assert({F_bm0} % kM0 == 0);
static_assert({F_bn1} % 32 == 0);
static_assert({F_bn1} % 16 == 0);
if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
return -1;
}} else {{
using traits2_ = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;
using traits2_ = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/16, true, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_decode_<traits_, traits2_>(s, a);
}}
}} else {{
using traits2_ = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>;
using traits2_ = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/16, false, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_decode_<traits_, traits2_>(s, a);
}}
@@ -649,17 +650,19 @@ class FmhaFwdSplitKVCombineKernel:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'64': {
'48': {
# # Specialize for different SeqQ
'16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
'32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '64': FmhaFwdTileSize(64, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
},
'128': {
'16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
'32': FmhaFwdTileSize(32, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '64': FmhaFwdTileSize(64, 64, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
'16': FmhaFwdTileSize(16, 32, 48, 48, 32, 48, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 32, -1),
'32': FmhaFwdTileSize(32, 32, 48, 48, 32, 48, 1, 1, 1, 1, 1, 1, 32, 32, 16, 16, 16, 32, -1),
},
# '64': {
# '16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
# '32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# },
# '128': {
# '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
# '32': FmhaFwdTileSize(32, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# },
}
else:
return None
@@ -668,9 +671,10 @@ def get_fmha_fwd_decode_combine_tile_dict_from_dtype(dtype : str) -> Optional[di
if dtype == 'fp16' or dtype == 'bf16':
return {
# '32' : FmhaFwdSplitKVCombineTileSize(32, -1),
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
'48' : FmhaFwdSplitKVCombineTileSize(16, -1),
# '64' : FmhaFwdSplitKVCombineTileSize(16, -1),
### '96' : FmhaFwdSplitKVCombineTileSize(32, -1),
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
# '128' : FmhaFwdSplitKVCombineTileSize(16, -1),
# '256' : FmhaFwdSplitKVCombineTileSize(32, -1),
}
else:
@@ -692,7 +696,8 @@ def get_fwd_decode_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> T
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, pagedkv in itertools.product(["f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["f"]):
for lse in ['t', 'f']:
if hdim in [64, 128]: ### [32, 64, 96, 128]:
# for lse in ['f']:
if hdim in [48, 64, 128]: ### [32, 64, 96, 128]:
pipelines.append(Pipeline('decode_qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('decode_qr', 'row', 'f', 'f', 't', 't', logits, bias, lse, squant, pagedkv, mask))
else:

View File

@@ -76,3 +76,7 @@
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/unary_element_function.hpp"
// Use `CK_PRINT<T1, T2, ...>()` to inspect the types T1, T2, ...
// Use `CK_PRINT<v1, v2, ...>()` to inspect the constexpr values v1, v2, ... (each value may have its own type)
// In an unevaluated context, you can use `using _dummy = decltype(CK_PRINT<...>());`
// Set BUILD_DEV to OFF so that -Werror is not enabled and the resulting warning does not fail the build

View File

@@ -46,23 +46,22 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
return rh_major_minor_to_y_;
};
constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});
constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
constexpr auto y_dim_out_to_in = [&] {
map<index_t, index_t> y_dim_out_to_in_;
for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
{
y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
}
return y_dim_out_to_in_;
}();
//
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
using OutDstrEncode = typename decltype(out_tensor.get_tile_distribution())::DstrEncode;
// using InDstrEncode = typename decltype(in_tensor.get_tile_distribution())::DstrEncode;
constexpr auto y_dim_out_to_in = generate_sequence_v2(
[&](auto i) constexpr {
constexpr index_t rh_major_out = OutDstrEncode::ys_to_rhs_major_[i];
constexpr index_t rh_minor_out = OutDstrEncode::ys_to_rhs_minor_[i];
return number<rh_major_minor_to_y_in[{rh_major_out, rh_minor_out}]>{};
},
number<NDimY>{});
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
// input and output vector dim in the order of input Y dims
@@ -128,7 +127,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
// set output vectors
static_for<0, num_vec_out, 1>{}([&](auto i) {
constexpr auto idx_y_out_tmp = generate_array(
constexpr auto idx_y_out_tmp = generate_tuple(
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
number<NDimY>{});

View File

@@ -229,4 +229,13 @@ constexpr auto conditional_expr(X&& x, Y&& y)
}
}
/// Compile-time debugging aid for constant values.
///
/// The function does nothing when called. It is marked [[deprecated]] on
/// purpose so that every use makes the compiler emit a deprecation warning;
/// that diagnostic is expected to spell out the instantiated specialization,
/// i.e. the values passed as template arguments.
template <auto... Values>
[[deprecated("Help function to print value")]] constexpr void CK_PRINT()
{
    // Intentionally empty: only the compiler diagnostic is of interest.
}
/// Compile-time debugging aid for types.
///
/// Same mechanism as the value overload above, but takes a pack of types, so
/// the deprecation warning is expected to show the types passed as template
/// arguments.
template <typename... Types>
[[deprecated("Help function to print value")]] constexpr void CK_PRINT()
{
    // Intentionally empty: only the compiler diagnostic is of interest.
}
} // namespace ck_tile

View File

@@ -676,6 +676,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_prefetch);
// CK_PRINT<decltype(v_shuffle_tmp), decltype(v_prefetch)>();
store_tile(
v_lds_window,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch

View File

@@ -48,18 +48,49 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t kMaxVecLoad = [&](){
// Try dwordx4
if constexpr (ElemPerThread % MaxVectorSize == 0){
return MaxVectorSize;
}
// Try dwordx2
else if constexpr (ElemPerThread % (MaxVectorSize / 2) == 0){
return MaxVectorSize / 2;
}
// Try dword
else if constexpr (ElemPerThread % (MaxVectorSize / 4) == 0){
return MaxVectorSize / 4;
}
else{
return 1;
}
}();
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KThreads = kKPerBlock / KPerThread;
// if false, we can not distribute the Ps thread evenly over Hs.
constexpr bool KThreadEven = (ElemPerThread/KPerThread) % 2 == 0;
constexpr index_t KThreads = [&](){
if constexpr (KThreadEven){
return kKPerBlock / KPerThread;
}
else{
// We move the odd factor to multiple instruction issue
return kKPerBlock/ElemPerThread;
}
}();
constexpr index_t KRepeat = kKPerBlock / KThreads / KPerThread;
static_assert(KRepeat == ElemPerThread/kMaxVecLoad);
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
@@ -67,11 +98,154 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
sequence<KRepeat, KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
tuple<sequence<1>, sequence<2, 1>>,
sequence<2, 1, 2>,
sequence<0, 0, 2>>{});
}
// Tile distribution used to load one K tile (kN0 x kK0 elements) from DRAM.
//
// Every thread owns ElemPerThread elements and reads them with the widest
// vector (dwordx4, dwordx2, dword, else a single element) whose element count
// divides ElemPerThread. Whatever does not fit into one vector is repeated
// either along N (NPerThread rows) when ElemPerThread / KPerThread is even, or
// along K (KRepeat vectors) when it is odd.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{
    using KDataType = remove_cvref_t<typename Problem::KDataType>;

    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;

    // Number of elements in a 16-byte (dwordx4) access.
    constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
    constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
    // Same guard as MakeQDramTileDistribution: the tile must give every thread work.
    static_assert(0 < ElemPerThread);

    constexpr index_t kMaxVecLoad = [&]() {
        // Try dwordx4
        if constexpr(ElemPerThread % MaxVectorSize == 0)
        {
            return MaxVectorSize;
        }
        // Try dwordx2
        else if constexpr(ElemPerThread % (MaxVectorSize / 2) == 0)
        {
            return MaxVectorSize / 2;
        }
        // Try dword
        else if constexpr(ElemPerThread % (MaxVectorSize / 4) == 0)
        {
            return MaxVectorSize / 4;
        }
        else
        {
            return 1;
        }
    }();

    constexpr index_t KPerThread = kMaxVecLoad;

    // If false, the P (thread) dimensions cannot be spread evenly over the H
    // dimensions, so the leftover factor has to become a repeat along K.
    constexpr bool KThreadEven = (ElemPerThread / KPerThread) % 2 == 0;
    constexpr index_t KThreads = [&]() {
        if constexpr(KThreadEven)
        {
            // One vector per thread along K; the remaining per-thread work
            // ends up in NPerThread below.
            return kKPerBlock / KPerThread;
        }
        else
        {
            // We move the odd factor to multiple instruction issue
            return kKPerBlock / ElemPerThread;
        }
    }();
    constexpr index_t KRepeat = kKPerBlock / KThreads / KPerThread;

    constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
    constexpr index_t NumWarps       = kBlockSize / get_warp_size();
    constexpr index_t NPerThread     = kNPerBlock / (NThreadPerWarp * NumWarps);

    // The Y dimensions (NPerThread, KRepeat, KPerThread) must account for all
    // elements a thread owns. The previous check, KRepeat == ElemPerThread /
    // kMaxVecLoad, could never hold when KThreadEven is true: KRepeat is 1 on
    // that path while the right-hand side is even.
    static_assert(NPerThread * KRepeat * KPerThread == ElemPerThread,
                  "per-thread N/K repeats and vector width must cover ElemPerThread");

    // H0 = <NPerThread, NumWarps, NThreadPerWarp>, H1 = <KRepeat, KThreads, KPerThread>
    // P0 -> NumWarps, P1 -> (NThreadPerWarp, KThreads)
    // Ys -> (KRepeat, NPerThread, KPerThread)
    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
                                         sequence<KRepeat, KThreads, KPerThread>>,
                                   tuple<sequence<1>, sequence<1, 2>>,
                                   tuple<sequence<1>, sequence<2, 1>>,
                                   sequence<2, 1, 2>,
                                   sequence<0, 0, 2>>{});
}
// Tile distribution used to load one V tile from DRAM.
//
// NOTE: the layout below is hand-written for a single shape: N = N0*N1*N2 = 48,
// K = K0*K1*K2 = 32, spread over K0*K1*N1 = 64 threads. The static_asserts make
// any other Problem fail at compile time instead of silently using a
// distribution that does not match the tile. The generic derivation this
// replaces is kept under `#if 0` for reference.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
{
    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;

    // Role of each length in the encoding below: Y = per-thread, P = across threads.
    constexpr index_t N2 = 4;  // Y
    constexpr index_t N1 = 4;  // P
    constexpr index_t N0 = 3;  // Y
    constexpr index_t K2 = 2;  // Y
    constexpr index_t K1 = 16; // P
    constexpr index_t K0 = 1;  // P

    static_assert(kNPerBlock == N0 * N1 * N2,
                  "hard-coded V distribution only covers kN1 == 48");
    static_assert(kKPerBlock == K0 * K1 * K2,
                  "hard-coded V distribution only covers kK1 == 32");
    static_assert(kBlockSize == K0 * K1 * N1,
                  "hard-coded V distribution assumes 64 threads per block");

    // P0 -> K0, P1 -> (K1, N1), Ys -> (N0, K2, N2)
    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
                                   tuple<sequence<2>, sequence<2, 1>>,
                                   tuple<sequence<0>, sequence<1, 1>>,
                                   sequence<1, 2, 1>,
                                   sequence<0, 2, 2>>{});
#if 0
    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
    constexpr index_t N1 = GetAlignmentV<Problem>(); // 8
    constexpr index_t N0 = kNPerBlock / N1; // P // 2
    constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
    static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
    constexpr index_t K3 = total_pixels / N1;
    constexpr index_t kKPack = GetSmemKPackV<Problem>();
    static_assert(kKPack % K3 == 0);
    constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside single wave
    if constexpr(get_warp_size() % (K2 * N0) == 0)
    {
        constexpr index_t K1 = get_warp_size() / (K2 * N0);
        constexpr index_t K0 = kBlockSize / get_warp_size();
        static_assert(kKPerBlock == K0 * K1 * K2 * K3);
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
                                       tuple<sequence<2>, sequence<2, 1, 2>>,
                                       tuple<sequence<0>, sequence<1, 0, 2>>,
                                       sequence<2, 1>,
                                       sequence<3, 1>>{});
    }
    else
    {
        constexpr index_t K1 = (K2 * N0) / get_warp_size();
        constexpr index_t K2_m = K2 / K1;
        constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
        static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
                                       tuple<sequence<2, 2>, sequence<1, 2>>,
                                       tuple<sequence<0, 1>, sequence<0, 2>>,
                                       sequence<2, 1>,
                                       sequence<3, 1>>{});
    }
#endif
}
// Register-tile distribution that holds a V tile after shuffle_tile().
//
// It uses the same hierarchical lengths and thread (P) mapping as the V DRAM
// distribution, but a different order of the per-thread (Y) dimensions:
// (N0, N2, K2) here.
//
// NOTE: the lengths are hand-written for a 48 x 32 (N x K) tile on 64 threads.
// The static_asserts reject any other Problem at compile time instead of
// silently producing a mismatched layout.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegBlockDescriptor()
{
    // Role of each length in the encoding below: Y = per-thread, P = across threads.
    constexpr index_t N2 = 4;  // Y
    constexpr index_t N1 = 4;  // P
    constexpr index_t N0 = 3;  // Y
    constexpr index_t K2 = 2;  // Y
    constexpr index_t K1 = 16; // P
    constexpr index_t K0 = 1;  // P

    static_assert(Problem::BlockFmhaShape::kN1 == N0 * N1 * N2,
                  "hard-coded shuffled V layout only covers kN1 == 48");
    static_assert(Problem::BlockFmhaShape::kK1 == K0 * K1 * K2,
                  "hard-coded shuffled V layout only covers kK1 == 32");
    static_assert(Problem::kBlockSize == K0 * K1 * N1,
                  "hard-coded shuffled V layout assumes 64 threads per block");

    // P0 -> K0, P1 -> (K1, N1), Ys -> (N0, N2, K2)
    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
                                   tuple<sequence<2>, sequence<2, 1>>,
                                   tuple<sequence<0>, sequence<1, 1>>,
                                   sequence<1, 1, 2>,
                                   sequence<0, 2, 2>>{});
}
template <typename Problem>

View File

@@ -73,6 +73,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
constexpr index_t WarpGemmK = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{});
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
@@ -81,8 +82,12 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
else if constexpr(WarpGemmM == 16){
if constexpr(WarpGemmK == 32)
return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{};
else
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
}
else // WarpGemmM == 4
return WarpGemmMfmaF16F16F32M4N64K16{};
}
@@ -92,8 +97,12 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
else if constexpr(WarpGemmM == 16){
if constexpr(WarpGemmK == 32)
return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{};
else
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
}
else // WarpGemmM == 4
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
}
@@ -239,7 +248,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaF16F16F32M4N64K16{};
}
@@ -250,7 +259,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
}

View File

@@ -9,6 +9,8 @@ namespace ck_tile {
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
{
if(len == 48)
return 48;
if(len == 96)
return 128;
if(len == 160)