diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index d0bf46653c..70d13118a0 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -3,7 +3,7 @@ # generate kernel instances to speed up compilation FWD_DTYPE_MAP = { - "fp16" : "FmhaFwdFp16", + # "fp16" : "FmhaFwdFp16", "bf16" : "FmhaFwdBf16", # "fp8" : "FmhaFwdFp8", # "fp8fp16": "FmhaFwdFp8Fp16", diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py index e2cc10226a..9e50543ed5 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py @@ -31,6 +31,7 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = { 32 : 32, + 48 : 48, 64 : 64, 96 : 128, 128: 128, @@ -303,18 +304,18 @@ FMHA_FWD_DECODE_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_ // make sure we can reuse the padding flags in combine kernels static_assert({F_bm0} % kM0 == 0); - static_assert({F_bn1} % 32 == 0); + static_assert({F_bn1} % 16 == 0); if (t.has_lse) {{ if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ return -1; }} else {{ - using traits2_ = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>; + using traits2_ = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/16, true, {F_squant}, {F_spad}, {F_dvpad}>; return fmha_fwd_decode_(s, a); }} }} else {{ - using traits2_ = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>; + using traits2_ = fmha_fwd_decode_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/16, false, {F_squant}, {F_spad}, {F_dvpad}>; return fmha_fwd_decode_(s, a); }} @@ -649,17 +650,19 @@ class FmhaFwdSplitKVCombineKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' 
or dtype == 'bf16': return { - '64': { + '48': { # # Specialize for different SeqQ - '16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - '32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # '64': FmhaFwdTileSize(64, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1), - }, - '128': { - '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - '32': FmhaFwdTileSize(32, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # '64': FmhaFwdTileSize(64, 64, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1), + '16': FmhaFwdTileSize(16, 32, 48, 48, 32, 48, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 32, -1), + '32': FmhaFwdTileSize(32, 32, 48, 48, 32, 48, 1, 1, 1, 1, 1, 1, 32, 32, 16, 16, 16, 32, -1), }, + # '64': { + # '16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # '32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # }, + # '128': { + # '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # '32': FmhaFwdTileSize(32, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # }, } else: return None @@ -668,9 +671,10 @@ def get_fmha_fwd_decode_combine_tile_dict_from_dtype(dtype : str) -> Optional[di if dtype == 'fp16' or dtype == 'bf16': return { # '32' : FmhaFwdSplitKVCombineTileSize(32, -1), - '64' : FmhaFwdSplitKVCombineTileSize(32, -1), + '48' : FmhaFwdSplitKVCombineTileSize(16, -1), + # '64' : FmhaFwdSplitKVCombineTileSize(16, -1), ### '96' : FmhaFwdSplitKVCombineTileSize(32, -1), - '128' : FmhaFwdSplitKVCombineTileSize(32, -1), + # '128' : FmhaFwdSplitKVCombineTileSize(16, -1), # '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } else: @@ -692,7 +696,8 @@ def get_fwd_decode_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> T if dtype in ['fp16', 
'bf16']: for logits, mask, bias, pagedkv in itertools.product(["f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["f"]): for lse in ['t', 'f']: - if hdim in [64, 128]: ### [32, 64, 96, 128]: + # for lse in ['f']: + if hdim in [48, 64, 128]: ### [32, 64, 96, 128]: pipelines.append(Pipeline('decode_qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('decode_qr', 'row', 'f', 'f', 't', 't', logits, bias, lse, squant, pagedkv, mask)) else: diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index be84842347..abc451b9ca 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -76,3 +76,7 @@ #include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/unary_element_function.hpp" +// Use `CK_PRINT<T1, T2, ...>()` to inspect values of type T1, T2, ... +// Use `CK_PRINT<val1, val2, ...>()` to inspect constexpr values of val1, val2, ... of the same type +// In a non-evaluated context, you can use `using _dummy = decltype(CK_PRINT<...>());` +// Set BUILD_DEV to OFF to avoid enabling Werror diff --git a/include/ck_tile/core/tensor/shuffle_tile.hpp b/include/ck_tile/core/tensor/shuffle_tile.hpp index 55e3274cde..5418375180 100644 --- a/include/ck_tile/core/tensor/shuffle_tile.hpp +++ b/include/ck_tile/core/tensor/shuffle_tile.hpp @@ -46,23 +46,22 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT return rh_major_minor_to_y_; }; - constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{}); - constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{}); + constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{}); - constexpr auto y_dim_out_to_in = [&] { - map y_dim_out_to_in_; - - for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out) - { - y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor]; - } - - return y_dim_out_to_in_; - }(); - - // constexpr 
index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y(); + using OutDstrEncode = typename decltype(out_tensor.get_tile_distribution())::DstrEncode; + // using InDstrEncode = typename decltype(in_tensor.get_tile_distribution())::DstrEncode; + + constexpr auto y_dim_out_to_in = generate_sequence_v2( + [&](auto i) constexpr { + constexpr index_t rh_major_out = OutDstrEncode::ys_to_rhs_major_[i]; + constexpr index_t rh_minor_out = OutDstrEncode::ys_to_rhs_minor_[i]; + + return number{}; + }, + number{}); + constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths()); // input and output vector dim in the order of input Y dims @@ -128,7 +127,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT // set output vectors static_for<0, num_vec_out, 1>{}([&](auto i) { - constexpr auto idx_y_out_tmp = generate_array( + constexpr auto idx_y_out_tmp = generate_tuple( [&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; }, number{}); diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp index fd0252d3ca..acd12b9a01 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -229,4 +229,13 @@ constexpr auto conditional_expr(X&& x, Y&& y) } } +template +[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT() +{ +} +template +[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT() +{ +} + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp index a95277f620..f65225ab43 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp @@ -676,6 +676,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS auto v_shuffle_tmp = 
make_static_distributed_tensor( Policy::template MakeShuffledVRegBlockDescriptor()); shuffle_tile(v_shuffle_tmp, v_prefetch); + // CK_PRINT(); store_tile( v_lds_window, tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp index ea499c4e9d..bff24dd47c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp @@ -48,18 +48,49 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() { + using QDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + constexpr index_t MaxVectorSize = 16 / sizeof(QDataType); constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; static_assert(0 < ElemPerThread); - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + constexpr index_t kMaxVecLoad = [&](){ + // Try dwordx4 + if constexpr (ElemPerThread % MaxVectorSize == 0){ + return MaxVectorSize; + } + // Try dwordx2 + else if constexpr (ElemPerThread % (MaxVectorSize / 2) == 0){ + return MaxVectorSize / 2; + } + // Try dword + else if constexpr (ElemPerThread % (MaxVectorSize / 4) == 0){ + return MaxVectorSize / 4; + } + else{ + return 1; + } + }(); constexpr index_t KPerThread = kMaxVecLoad; - constexpr index_t KThreads = kKPerBlock / KPerThread; + // if false, we can not distribute the Ps thread evenly over Hs. 
+ constexpr bool KThreadEven = (ElemPerThread/KPerThread) % 2 == 0; + constexpr index_t KThreads = [&](){ + if constexpr (KThreadEven){ + return kKPerBlock / KPerThread; + } + else{ + // We move the odd factor to multiple instruction issue + return kKPerBlock/ElemPerThread; + } + }(); + + constexpr index_t KRepeat = kKPerBlock / KThreads / KPerThread; + static_assert(KRepeat == ElemPerThread/kMaxVecLoad); constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; constexpr index_t NumWarps = kBlockSize / get_warp_size(); constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); @@ -67,11 +98,154 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy return make_static_tile_distribution( tile_distribution_encoding, tuple, - sequence>, + sequence>, tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + tuple, sequence<2, 1>>, + sequence<2, 1, 2>, + sequence<0, 0, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() + { + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + constexpr index_t kMaxVecLoad = [&](){ + // Try dwordx4 + if constexpr (ElemPerThread % MaxVectorSize == 0){ + return MaxVectorSize; + } + // Try dwordx2 + else if constexpr (ElemPerThread % (MaxVectorSize / 2) == 0){ + return MaxVectorSize / 2; + } + // Try dword + else if constexpr (ElemPerThread % (MaxVectorSize / 4) == 0){ + return MaxVectorSize / 4; + } + else{ + return 1; + } + }(); + + constexpr index_t KPerThread = kMaxVecLoad; + // if false, we can not distribute the Ps thread evenly over Hs. 
+ constexpr bool KThreadEven = (ElemPerThread/KPerThread) % 2 == 0; + constexpr index_t KThreads = [&](){ + if constexpr (KThreadEven){ + return kKPerBlock / KPerThread; + } + else{ + // We move the odd factor to multiple instruction issue + return kKPerBlock/ElemPerThread; + }}(); + + constexpr index_t KRepeat = kKPerBlock / KThreads / KPerThread; + static_assert(KRepeat == ElemPerThread/kMaxVecLoad); + constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<2, 1, 2>, + sequence<0, 0, 2>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + // constexpr index_t kBlockSize = Problem::kBlockSize; + // constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + // constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t N2 = 4; // Y 4 + constexpr index_t N1 = 4; // P 2 + constexpr index_t N0 = 3; // Y 2 + + constexpr index_t K2 = 2; // Y + constexpr index_t K1 = 16; // P + constexpr index_t K0 = 1; // P + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 1>>, + sequence<1, 2, 1>, + sequence<0, 2, 2>>{}); +#if 0 + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t N1 = GetAlignmentV(); // 8 + constexpr index_t N0 = kNPerBlock / N1; // P // 2 + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
+ constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside single wave + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } +#endif + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegBlockDescriptor() + { + constexpr index_t N2 = 4; // Y 4 + constexpr index_t N1 = 4; // P 2 + constexpr index_t N0 = 3; // Y 2 + + constexpr index_t K2 = 2; // Y + constexpr index_t K1 = 16; // P + constexpr index_t K0 = 1; // P + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 1>>, + sequence<1, 1, 2>, + sequence<0, 2, 2>>{}); } template diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 30d07a4754..b4ce5f0468 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -73,6 +73,7 @@ struct BlockFmhaPipelineQXCustomPolicy constexpr auto 
warp_gemm = []() { constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + constexpr index_t WarpGemmK = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}); static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); if constexpr(std::is_same_v && @@ -81,8 +82,12 @@ struct BlockFmhaPipelineQXCustomPolicy { if constexpr(WarpGemmM == 32) return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; - else if constexpr(WarpGemmM == 16) - return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; + else if constexpr(WarpGemmM == 16){ + if constexpr(WarpGemmK == 32) + return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}; + else + return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; + } else // WarpGemmM == 4 return WarpGemmMfmaF16F16F32M4N64K16{}; } @@ -92,8 +97,12 @@ struct BlockFmhaPipelineQXCustomPolicy { if constexpr(WarpGemmM == 32) return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; - else if constexpr(WarpGemmM == 16) - return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; + else if constexpr(WarpGemmM == 16){ + if constexpr(WarpGemmK == 32) + return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}; + else + return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; + } else // WarpGemmM == 4 return WarpGemmMfmaBf16Bf16F32M4N64K16{}; } @@ -239,7 +248,7 @@ struct BlockFmhaPipelineQXCustomPolicy if constexpr(WarpGemmM == 32) return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) - return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; + return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}; else // WarpGemmM == 4 return WarpGemmMfmaF16F16F32M4N64K16{}; } @@ -250,7 +259,7 @@ struct BlockFmhaPipelineQXCustomPolicy if constexpr(WarpGemmM == 32) return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) - return 
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; + return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}; else // WarpGemmM == 4 return WarpGemmMfmaBf16Bf16F32M4N64K16{}; } diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index 76ba34115f..d54254a2fe 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -9,6 +9,8 @@ namespace ck_tile { static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len) { + if(len == 48) + return 48; if(len == 96) return 128; if(len == 160)