diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 131729992b..8a13c0b060 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); auto k_lds_write_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); auto k_lds_read_window = make_tile_window(k_lds_write_window.get_bottom_tensor_view(), make_tuple(number{}, number{}), k_lds_write_window.get_window_origin(), - Policy::template MakeKRegSliceBlockDescriptor()); + Policy::template MakeKRegBlockDescriptor()); auto k_reg_tensor = make_static_distributed_tensor( Policy::template MakeKRegBlockDescriptor()); @@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); auto v_lds_write_window = - make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); auto v_lds_read_window = make_tile_window(v_lds_write_window.get_bottom_tensor_view(), make_tuple(number{}, number{}), v_lds_write_window.get_window_origin(), - Policy::template MakeVRegSliceBlockDescriptor()); - - auto v_reg_tensor = make_static_distributed_tensor( - Policy::template MakeVRegBlockDescriptor()); + Policy::template MakeVRegBlockDescriptor()); //------------------------------------------------------------------ // KT, Reg ->LDS ->Reg @@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor()); auto shuffled_k_lds_write_window = make_tile_window( - shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); auto kt_lds_read = make_tensor_view( kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor()); @@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR block_sync_lds(); - v_reg_tensor = load_tile(v_lds_read_window); + auto v_reg_tensor = load_tile(v_lds_read_window); block_sync_lds(); //---------------------------- Loop Load in ----------------------------// // Q: HBM ->Reg ->LDS @@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); auto q_lds_window = - make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); auto q_lds_read_window = make_tile_window(q_lds_window.get_bottom_tensor_view(), @@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor()); auto shuffled_q_lds_write_window = make_tile_window( - shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); auto qt_lds_read = make_tensor_view( qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor()); @@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); auto do_lds_window = - make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); auto do_lds_read_window = make_tile_window(do_lds_window.get_bottom_tensor_view(), @@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor()); auto shuffled_do_lds_write_window = make_tile_window( - shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); auto dot_read_lds = make_tensor_view( dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor()); @@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR index_t i_total_loops = 0; index_t seqlen_q_step = seqlen_q_start; - static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0"); + static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0"); static_assert(kM0 == kK1, "kM0 should equal to kK1"); - static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2"); + static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2"); static_assert(kM0 == kK3, "kM0 should equal to kK3"); constexpr index_t k4_loops = kN0 / kK4; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 3156e4a356..d1b6e6f85b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); auto k_lds_write_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); auto k_lds_read_window = make_tile_window(k_lds_write_window.get_bottom_tensor_view(), make_tuple(number{}, number{}), k_lds_write_window.get_window_origin(), - Policy::template MakeKRegSliceBlockDescriptor()); + Policy::template MakeKRegBlockDescriptor()); auto k_reg_tensor = make_static_distributed_tensor( Policy::template MakeKRegBlockDescriptor()); @@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); auto v_lds_write_window = - make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); auto v_lds_read_window = make_tile_window(v_lds_write_window.get_bottom_tensor_view(), make_tuple(number{}, number{}), v_lds_write_window.get_window_origin(), - Policy::template MakeVRegSliceBlockDescriptor()); - - auto v_reg_tensor = make_static_distributed_tensor( - Policy::template MakeVRegBlockDescriptor()); + Policy::template MakeVRegBlockDescriptor()); //------------------------------------------------------------------ // KT, Reg ->LDS ->Reg @@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor()); auto shuffled_k_lds_write_window = make_tile_window( - shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); auto kt_lds_read = make_tensor_view( kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor()); @@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP block_sync_lds(); - v_reg_tensor = load_tile(v_lds_read_window); + auto v_reg_tensor = load_tile(v_lds_read_window); //---------------------------- Loop Load in ----------------------------// // Q: HBM ->Reg ->LDS auto q_dram_window = @@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); auto q_lds_window = - make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); auto q_lds_read_window = make_tile_window(q_lds_window.get_bottom_tensor_view(), @@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor()); auto shuffled_q_lds_write_window = make_tile_window( - shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); auto qt_lds_read = make_tensor_view( qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor()); @@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); auto do_lds_window = - make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); auto do_lds_read_window = make_tile_window(do_lds_window.get_bottom_tensor_view(), @@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor()); auto shuffled_do_lds_write_window = make_tile_window( - shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); auto dot_read_lds = make_tensor_view( dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor()); @@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP index_t i_total_loops = 0; index_t seqlen_q_step = seqlen_q_start; - static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0"); + static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0"); static_assert(kM0 == kK1, "kM0 should equal to kK1"); - static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2"); + static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2"); static_assert(kM0 == kK3, "kM0 should equal to kK3"); constexpr index_t k4_loops = kN0 / kK4; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 0afad0446c..d353203e0e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -196,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using QDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType); constexpr index_t kMinVecLoad = 4 / sizeof(QDataType); @@ -215,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using KDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType); constexpr index_t kMinVecLoad = 4 / sizeof(KDataType); @@ -234,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using VDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType); constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize; @@ -254,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using OGradDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType); constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType); @@ -315,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; @@ -327,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; return total_pixels / GetAlignmentK(); @@ -338,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; @@ -376,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t K1 = GetAlignmentK(); constexpr index_t K0 = kKPerBlock / K1; @@ -399,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t K1 = GetAlignmentV(); constexpr index_t K0 = kKPerBlock / K1; @@ -422,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t K1 = GetAlignmentQ(); constexpr index_t K0 = kKPerBlock / K1; @@ -445,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t K1 = GetAlignmentOGrad(); constexpr index_t K0 = kKPerBlock / K1; @@ -816,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor() { constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); return MakeXLdsBlockDescriptor(); } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - constexpr auto k_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode); - - return k_block_dstr; - } - template CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor() { @@ -865,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; @@ -890,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor() { constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kVPack = GetSmemKPackV(); return MakeXLdsBlockDescriptor(); } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; - - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - constexpr auto v_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); - - return v_block_dstr; - } - template CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor() { @@ -940,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; @@ -966,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t K1 = GetAlignmentK(); constexpr index_t K0 = kKPerBlock / K1; @@ -1048,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() { constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackQ(); @@ -1092,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t K1 = GetAlignmentQ(); constexpr index_t K0 = kKPerBlock / K1; @@ -1255,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { // Hold full block data constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kKPack = GetSmemKPackOGrad(); @@ -1299,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t K1 = GetAlignmentOGrad(); constexpr index_t K0 = kKPerBlock / K1; @@ -1859,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0; static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim; static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim; + static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0; + static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2; static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4; static constexpr index_t WarpGemmM = @@ -1873,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy // Compute static constexpr index_t Gemm0MFMA = - kM0 * kN0 * kQKHeaddim / - (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); + kM0 * kN0 * kK0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); static constexpr index_t Gemm1MFMA = - kM0 * kN0 * kVHeaddim / - (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); - static constexpr index_t Gemm2MFMA = kN0 * kVHeaddim * kM0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm2MFMA = + kM0 * kN0 * kK2 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); static constexpr index_t Gemm3MFMA = kN0 * kQKHeaddim * kM0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); @@ -1903,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ(); static constexpr index_t SGradT_LDS_READ_P1 = kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); - static constexpr index_t Q_LDS_READ = - kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ(); + static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ(); static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t SGradT_LDS_READ_P2 = kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); static constexpr index_t OGrad_LDS_READ = - kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad(); + kM0 * kK2 / kBlockSize / GetAlignmentOGrad(); static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); // LDS Write