From a592107cb97edfa59866a376e2fe20aea86f3e06 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Tue, 29 Jul 2025 09:31:14 +0800 Subject: [PATCH] [CK_TILE] FMHA bwd Support hdim as a Multiple of 32 (#2130) * Fix shuffle_tile * Add fmha bwd d160 * CHANGELOG * Use static_cast * Update --------- Co-authored-by: asleepzzz [ROCm/composable_kernel commit: 1926cd0cb8bfb0139f29a518ebfb5368920d5e4b] --- CHANGELOG.md | 1 + .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 2 + .../tensor/tile_distribution_encoding.hpp | 22 +- ...a_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 6 +- ...block_fmha_bwd_pipeline_default_policy.hpp | 556 +++++++++++++----- 5 files changed, 446 insertions(+), 141 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fa3ba71143..4c054b822a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for Split K for grouped convolution backward data. * Added logit soft-capping support for fMHA forward kernels. * Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv) +* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv/bwd) * Added benchmarking support for tile engine GEMM. * Added Ping-pong scheduler support for GEMM operation along the K dimension. * Added rotating buffer feature for CK_Tile GEMM. diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index e4f46b502a..77b63a0c83 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -357,6 +357,8 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict "kr_ktr_vr_iglp", "kr_ktr_vr"], '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"], + # '160' : [FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + # "kr_ktr_vr_iglp", "kr_ktr_vr"], '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"] } diff --git a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp index 52a16f32bd..b380e7c9d8 100644 --- a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp +++ b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -533,6 +533,26 @@ struct tile_distribution_encoding } }; +template +class tile_distribution_encoding_shuffle; +template +class tile_distribution_encoding_shuffle> +{ + template + using shuffled = sequence<(Ys2RHs::template get())...>; + + public: + using type = tile_distribution_encoding, + shuffled>; +}; +template +using tile_distribution_encoding_shuffle_t = + typename tile_distribution_encoding_shuffle::type; + namespace detail { template diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 420ae03b7e..c88b058d32 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -182,7 +182,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto k_lds_read_window = make_tile_window(k_lds_write_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), k_lds_write_window.get_window_origin(), Policy::template MakeKRegBlockDescriptor()); @@ -208,7 +208,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto v_lds_read_window = make_tile_window(v_lds_write_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), v_lds_write_window.get_window_origin(), Policy::template MakeVRegBlockDescriptor()); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index d353203e0e..bc0dc592f0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,6 +22,13 @@ namespace ck_tile { struct BlockFmhaBwdPipelineDefaultPolicy { + template + static constexpr auto swap_last2 = generate_sequence_v2( + [](auto i) { + return number < i == ndim - 2 ? ndim - 1 : i == ndim - 1 ? ndim - 2 : i > {}; + }, + number{}); + template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { @@ -384,13 +391,40 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t N0 = kBlockSize / get_warp_size(); constexpr index_t N2 = kNPerBlock / (N1 * N0); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, tuple, sequence<1, 2>>, tuple, sequence<1, 0>>, sequence<1, 2>, sequence<2, 1>>{}); + + if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock) + { + return dstr; + } + else + { + constexpr index_t kKPerIter = 32; + static_assert(kKPerBlock % kKPerIter == 0); + constexpr index_t K0_m = kKPerBlock / kKPerIter; + constexpr index_t K2 = 2; + constexpr index_t K1_m = kKPerIter / K2; + constexpr index_t N1_m = get_warp_size() / K1_m; + constexpr index_t N2_m = kNPerBlock / (N1_m * N0); + constexpr auto dstr_m = make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple, sequence<1, 2>>, // N0, N1 K1 + tuple, sequence<1, 1>>, + sequence<2, 1, 2>, // K0 N2 K2 + sequence<0, 2, 2>>{}); + static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock); + return dstr_m; + } } template @@ -407,13 +441,39 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t N1 = kBlockSize / get_warp_size(); constexpr index_t N0 = kNPerBlock / (N2 * N1); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, - tuple, sequence<1, 2>>, + tuple, sequence<1, 2>>, // N1, N2 K0 tuple, sequence<2, 0>>, - sequence<1, 2>, + sequence<1, 2>, // N0 K1 sequence<0, 1>>{}); + if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock) + { + return dstr; + } + else + { + constexpr index_t kKPerIter = 32; + static_assert(kKPerBlock % kKPerIter == 0); + constexpr index_t K0_m = kKPerBlock / kKPerIter; + constexpr index_t K2 = 2; + constexpr index_t K1_m = kKPerIter / K2; + constexpr index_t N2_m = get_warp_size() / K1_m; + constexpr index_t N0_m = kNPerBlock / (N2_m * N1); + constexpr auto dstr_m = make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple, sequence<1, 2>>, // N1, N2 K1 + tuple, sequence<2, 1>>, + sequence<2, 1, 2>, // K0 N0 K2 + sequence<0, 0, 2>>{}); + static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock); + return dstr_m; + } } template @@ -430,13 +490,41 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t M0 = kBlockSize / get_warp_size(); constexpr index_t M2 = kMPerBlock / (M1 * M0); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, tuple, sequence<1, 2>>, tuple, sequence<1, 0>>, sequence<1, 2>, sequence<2, 1>>{}); + + if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kKPerBlock) + { + return dstr; + } + else + { + // something not divisible, try a more flexible distribution + constexpr index_t kKPerIter = 32; + static_assert(kKPerBlock % kKPerIter == 0); + constexpr index_t K0_m = kKPerBlock / kKPerIter; + constexpr index_t K2 = 2; + constexpr index_t K1_m = kKPerIter / K2; + constexpr index_t M1_m = get_warp_size() / K1_m; + constexpr index_t M2_m = kMPerBlock / (M1_m * M0); + constexpr auto dstr_m = make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple, sequence<1, 2>>, // M0, M1 K1 + tuple, sequence<1, 1>>, + sequence<2, 1, 2>, // K0 M2 K2 + sequence<0, 2, 2>>{}); + static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kKPerBlock); + return dstr_m; + } } template @@ -453,13 +541,41 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t M0 = kBlockSize / get_warp_size(); constexpr index_t M2 = kMPerBlock / (M1 * M0); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, tuple, sequence<1, 2>>, tuple, sequence<1, 0>>, sequence<1, 2>, sequence<2, 1>>{}); + + if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kKPerBlock) + { + return dstr; + } + else + { + // something not divisible, try a more flexible distribution + constexpr index_t kKPerIter = 32; + static_assert(kKPerBlock % kKPerIter == 0); + constexpr index_t K0_m = kKPerBlock / kKPerIter; + constexpr index_t K2 = 2; + constexpr index_t K1_m = kKPerIter / K2; + constexpr index_t M1_m = get_warp_size() / K1_m; + constexpr index_t M2_m = kMPerBlock / (M1_m * M0); + constexpr auto dstr_m = make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple, sequence<1, 2>>, // M0, M1 K1 + tuple, sequence<1, 1>>, + sequence<2, 1, 2>, // K0 M2 K2 + sequence<0, 2, 2>>{}); + static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kKPerBlock); + return dstr_m; + } } template @@ -504,13 +620,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t M0 = kBlockSize / get_warp_size(); constexpr index_t M2 = kMPerBlock / (M1 * M0); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, tuple, sequence<1, 2>>, tuple, sequence<1, 0>>, sequence<1, 2>, sequence<2, 1>>{}); + static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kNPerBlock); + return dstr; } template @@ -522,13 +641,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t M1 = get_warp_size(); constexpr index_t M0 = MPerBlock / M1; - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, tuple, sequence<1>>, tuple, sequence<1>>, sequence<1, 2, 2>, sequence<2, 0, 1>>{}); + static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + MPerBlock * KPerBlock); + return dstr; } template @@ -569,13 +691,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t M0 = kMPerBlock / (M1 * M2); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence, sequence>, tuple, sequence<2, 3>>, tuple, sequence<2, 0>>, sequence<1, 2, 3>, sequence<0, 0, 1>>{}); + static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kKPerBlock); + return dstr; } template @@ -594,13 +719,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t M0 = kMPerBlock / (M1 * M2); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, tuple, sequence<1, 2>>, tuple, sequence<2, 0>>, sequence<1, 2>, sequence<0, 1>>{}); + + static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kKPerBlock); + return dstr; } // these are for lds @@ -666,56 +795,80 @@ struct BlockFmhaBwdPipelineDefaultPolicy return 16 / sizeof(GemmDataType); } - template + template CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor() { constexpr auto DataTypeSize = 2; // sizeof(F16/BF16) constexpr auto MNLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + (32 * 4 / KPerSubBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerSubBlock / DataTypeSize); - constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); + constexpr auto x_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor( x_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), + make_tuple(make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); + make_tuple(sequence<0>{}, sequence<2, 1>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<2, 1>{}, sequence<3>{})); constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( x_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), + make_tuple(make_pass_through_transform(number{}), + make_unmerge_transform( + make_tuple(number{}, number{})), make_pass_through_transform(number{}), make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{})); constexpr auto x_lds_block_desc = transform_tensor_descriptor( x_lds_block_desc_xk0_mnldslayer_mn_xk1, make_tuple(make_merge_transform_v3_division_mod( make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<2, 3>{}, sequence<0, 1, 4>{}), make_tuple(sequence<0>{}, sequence<1>{})); + static_assert(container_reduce(x_lds_block_desc.get_lengths(), + std::multiplies{}, + 1) == KIter * MNPerBlock * KPerSubBlock); return x_lds_block_desc; } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor() + { + return MakeXLdsBlockDescriptor<1, MNPerBlock, KPerBlock, KPack>(); + } template CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor() + { + return MakeXTLdsBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor() { // kfold and mpair dimension is not always required. // more dimension in merge_transform increase the difficulty of generating immarg offset @@ -723,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr auto MNPerXDL = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); constexpr auto kBlockSize = Problem::kBlockSize; - constexpr auto MN0 = MNPerBlock / KPack; + constexpr auto MN0 = MNPerSubBlock / KPack; constexpr auto MN1 = KPack; constexpr auto KThreadWrite = kBlockSize / MN0; @@ -745,13 +898,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy : ((128 / (KPackT * MNPerXDL * 2)) > MN0 ? MN0 : 128 / (KPackT * MNPerXDL * 2)); constexpr auto xt_lds_block_desc_raw = make_naive_tensor_descriptor( - make_tuple(number{}, + make_tuple(number{}, + number{}, number{}, number{}, number{}, number{}, KPackT), - make_tuple(number{}, + make_tuple(number{}, + number{}, number{}, number{}, number{}, @@ -763,20 +918,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr auto xt_lds_block_desc_permuted = transform_tensor_descriptor( xt_lds_block_desc_raw, make_tuple( + make_pass_through_transform(number{}), make_pass_through_transform(number{}), make_pass_through_transform(number{}), make_xor_transform( make_tuple(number{}, number{})), make_pass_through_transform(number{}), make_pass_through_transform(KPackT)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3, 4>{}, + sequence<5>{}, + sequence<6>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3, 4>{}, + sequence<5>{}, + sequence<6>{})); constexpr auto xt_lds_block_desc_unmerged = transform_tensor_descriptor( xt_lds_block_desc_permuted, make_tuple( + make_pass_through_transform(number{}), make_pass_through_transform(number{}), make_pass_through_transform(number{}), make_unmerge_transform(make_tuple(number{}, number{})), @@ -788,27 +953,32 @@ struct BlockFmhaBwdPipelineDefaultPolicy sequence<2>{}, sequence<3>{}, sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<1>{}, + sequence<5>{}, + sequence<6>{}), + make_tuple(sequence<0>{}, sequence<2>{}, - sequence<0, 3>{}, - sequence<4, 5>{}, - sequence<6>{}, - sequence<7>{})); + sequence<3>{}, + sequence<1, 4>{}, + sequence<5, 6>{}, + sequence<7>{}, + sequence<8>{})); constexpr auto xt_lds_block_desc = transform_tensor_descriptor( xt_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple( + make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{})), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}, number{}, number{}))), + make_tuple(sequence<1, 2, 5, 3, 8>{}, sequence<0, 6, 7, 4>{}), make_tuple(sequence<0>{}, sequence<1>{})); - + static_assert(container_reduce(xt_lds_block_desc.get_lengths(), + std::multiplies{}, + 1) == MNPerSubBlock * MNIter * KPerBlock); return xt_lds_block_desc; } @@ -817,9 +987,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t kKPack = GetSmemKPackK(); - return MakeXLdsBlockDescriptor(); + using dram_encoding = typename decltype(MakeKDramTileDistribution())::DstrEncode; + constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + if constexpr(dram_y_ndim == 2) + { + constexpr index_t kKPack = GetSmemKPackK(); + return MakeXLdsBlockDescriptor(); + } + else if constexpr(dram_y_ndim == 3) + { + constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0); + constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2); + return MakeXLdsBlockDescriptor(); + } + else + { + static_assert(false, "Unexpected dram y dimension"); + } } template @@ -850,7 +1035,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode); - + static_assert(container_reduce(k_block_dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock); return k_block_dstr; } @@ -860,9 +1046,23 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; - constexpr index_t kVPack = GetSmemKPackV(); - - return MakeXLdsBlockDescriptor(); + using dram_encoding = typename decltype(MakeVDramTileDistribution())::DstrEncode; + constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + if constexpr(dram_y_ndim == 2) + { + constexpr index_t kVPack = GetSmemKPackV(); + return MakeXLdsBlockDescriptor(); + } + else if constexpr(dram_y_ndim == 3) + { + constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0); + constexpr index_t kVPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2); + return MakeXLdsBlockDescriptor(); + } + else + { + static_assert(false, "Unexpected dram y dimension"); + } } template @@ -893,30 +1093,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); - + static_assert(container_reduce(v_block_dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock); return v_block_dstr; } template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKRegWriteBlockDescriptor() { - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - - constexpr index_t K1 = GetAlignmentK(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = GetTransposedAlignmentK(); - constexpr index_t N1 = get_warp_size() / K0; - constexpr index_t N0 = kBlockSize / get_warp_size(); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<1, 2>>{}); + using dram_encoding = typename decltype(MakeKDramTileDistribution())::DstrEncode; + constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + static_assert(y_ndim >= 2); + using shuffled_encoding_t = + tile_distribution_encoding_shuffle_t)>>; + return make_static_tile_distribution(shuffled_encoding_t{}); } template @@ -926,10 +1117,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPack = GetSmemKPackK(); - constexpr index_t kKPackT = GetSmemKPackKT(); - - return MakeXTLdsBlockDescriptor(); + using dram_encoding = typename decltype(MakeKDramTileDistribution())::DstrEncode; + constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + if constexpr(dram_y_ndim == 2) + { + constexpr index_t kKPack = GetSmemKPackK(); + constexpr index_t kKPackT = GetSmemKPackKT(); + return MakeXTLdsBlockDescriptor(); + } + else if constexpr(dram_y_ndim == 3) + { + constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0); + constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2); + constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2); + return MakeXTLdsBlockDescriptor(); + } + else + { + static_assert(false, "Unexpected dram y dimension"); + } } template @@ -976,7 +1187,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); constexpr auto kt_block_dstr = make_static_tile_distribution(kt_block_dstr_encode); - + static_assert(container_reduce(kt_block_dstr.get_lengths(), + std::multiplies{}, + 1) == kNPerBlock * kKPerBlock); return kt_block_dstr; } @@ -986,9 +1199,23 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t kKPack = GetSmemKPackQ(); - - return MakeXLdsBlockDescriptor(); + using dram_encoding = typename decltype(MakeQDramTileDistribution())::DstrEncode; + constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + if constexpr(dram_y_ndim == 2) + { + constexpr index_t kKPack = GetSmemKPackQ(); + return MakeXLdsBlockDescriptor(); + } + else if constexpr(dram_y_ndim == 3) + { + constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0); + constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2); + return MakeXLdsBlockDescriptor(); + } + else + { + static_assert(false, "Unexpected dram y dimension"); + } } template @@ -1019,30 +1246,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); - + static_assert(container_reduce(q_block_dstr.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kKPerBlock); return q_block_dstr; } template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQRegWriteBlockDescriptor() { - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - - constexpr index_t K1 = GetAlignmentQ(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = GetTransposedAlignmentQ(); - constexpr index_t N1 = get_warp_size() / K0; - constexpr index_t N0 = kBlockSize / get_warp_size(); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<1, 2>>{}); + using dram_encoding = typename decltype(MakeQDramTileDistribution())::DstrEncode; + constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + static_assert(y_ndim >= 2); + using shuffled_encoding_t = + tile_distribution_encoding_shuffle_t)>>; + return make_static_tile_distribution(shuffled_encoding_t{}); } template @@ -1052,10 +1270,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPack = GetSmemKPackQ(); - constexpr index_t kKPackT = GetSmemKPackQT(); - - return MakeXTLdsBlockDescriptor(); + using dram_encoding = typename decltype(MakeQDramTileDistribution())::DstrEncode; + constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + if constexpr(dram_y_ndim == 2) + { + constexpr index_t kKPack = GetSmemKPackQ(); + constexpr index_t kKPackT = GetSmemKPackQT(); + return MakeXTLdsBlockDescriptor(); + } + else if constexpr(dram_y_ndim == 3) + { + constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0); + constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2); + constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2); + return MakeXTLdsBlockDescriptor(); + } + else + { + static_assert(false, "Unexpected dram y dimension"); + } } template @@ -1103,6 +1341,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy qt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); constexpr auto qt_block_dstr = make_static_tile_distribution(qt_block_dstr_encode); + static_assert(container_reduce(qt_block_dstr.get_lengths(), + std::multiplies{}, + 1) == kNPerBlock * kKPerBlock); return qt_block_dstr; } @@ -1135,7 +1376,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy dst_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); constexpr auto dst_block_dstr = make_static_tile_distribution(dst_block_dstr_encode); - + static_assert(container_reduce(dst_block_dstr.get_lengths(), + std::multiplies{}, + 1) == kMPerBlock * kKPerBlock); return dst_block_dstr; } @@ -1177,13 +1420,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t M1 = MWarp; constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple>, tuple, sequence<1, 0>>, tuple, sequence<3, 1>>, sequence<1, 1, 1>, sequence<0, 2, 4>>{}); + static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kMPerBlock); + return dstr; } template @@ -1193,9 +1439,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; - constexpr index_t kKPack = GetSmemKPackOGrad(); - - return MakeXLdsBlockDescriptor(); + using dram_encoding = + typename decltype(MakeOGradDramTileDistribution())::DstrEncode; + constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + if constexpr(dram_y_ndim == 2) + { + constexpr index_t kKPack = GetSmemKPackOGrad(); + return MakeXLdsBlockDescriptor(); + } + else if constexpr(dram_y_ndim == 3) + { + constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0); + constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2); + return MakeXLdsBlockDescriptor(); + } + else + { + static_assert(false, "Unexpected dram y dimension"); + } } template @@ -1226,30 +1487,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy do_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); constexpr auto do_block_dstr = make_static_tile_distribution(do_block_dstr_encode); - + static_assert(container_reduce(do_block_dstr.get_lengths(), + std::multiplies{}, + 1) == kMPerBlock * kKPerBlock); return do_block_dstr; } template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradRegWriteBlockDescriptor() { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; - - constexpr index_t K1 = GetAlignmentOGrad(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = GetTransposedAlignmentOGrad(); - constexpr index_t N1 = get_warp_size() / K0; - constexpr index_t N0 = kBlockSize / get_warp_size(); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<1, 2>>{}); + using dram_encoding = + typename decltype(MakeOGradDramTileDistribution())::DstrEncode; + constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + static_assert(y_ndim >= 2); + using shuffled_encoding_t = + tile_distribution_encoding_shuffle_t)>>; + return make_static_tile_distribution(shuffled_encoding_t{}); } template @@ -1259,10 +1514,31 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPack = GetSmemKPackOGrad(); - constexpr index_t kKPackT = GetSmemKPackOGradT(); - - return MakeXTLdsBlockDescriptor(); + using dram_encoding = + typename decltype(MakeOGradDramTileDistribution())::DstrEncode; + constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + if constexpr(dram_y_ndim == 2) + { + constexpr index_t kKPack = GetSmemKPackOGrad(); + constexpr index_t kKPackT = GetSmemKPackOGradT(); + return MakeXTLdsBlockDescriptor(); + } + else if constexpr(dram_y_ndim == 3) + { + constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0); + constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2); + constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2); + return MakeXTLdsBlockDescriptor(); + } + else + { + static_assert(false, "Unexpected dram y dimension"); + } } template @@ -1310,7 +1586,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy dot_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); constexpr auto dot_block_dstr = make_static_tile_distribution(dot_block_dstr_encode); - + static_assert(container_reduce(dot_block_dstr.get_lengths(), + std::multiplies{}, + 1) == kNPerBlock * kKPerBlock); return dot_block_dstr; } @@ -1342,7 +1620,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy pt_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); constexpr auto pt_block_dstr = make_static_tile_distribution(pt_block_dstr_encode); - + static_assert(container_reduce(pt_block_dstr.get_lengths(), + std::multiplies{}, + 1) == kMPerBlock * kKPerBlock); return pt_block_dstr; } @@ -1384,7 +1664,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ds_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); constexpr auto ds_block_dstr = make_static_tile_distribution(ds_block_dstr_encode); - + static_assert(container_reduce(ds_block_dstr.get_lengths(), + std::multiplies{}, + 1) == kMPerBlock * kKPerBlock); return ds_block_dstr; }