From 36f241814e246b61cfeee798df4053f3fea9bd7d Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sat, 12 Oct 2024 15:30:54 +0000 Subject: [PATCH] Fix in bwd_pipeline_default_policy --- .../block_fmha_bwd_pipeline_default_policy.hpp | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index b62dc2def4..c620753817 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -394,7 +394,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN; constexpr index_t K1 = GetAlignmentV(); constexpr index_t K0 = kKPerBlock / K1; @@ -853,7 +853,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor() { constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN; constexpr index_t kVPack = GetSmemKPackV(); @@ -1808,14 +1808,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy // Compute static constexpr index_t Gemm0MFMA = - kM0 * kN0 * kK0 / - (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); + kM0 * kN0 * kK0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); static constexpr index_t Gemm1MFMA = kN0 * kVHeaddimForGemmN * kM0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); static constexpr index_t Gemm2MFMA = - kM0 * kN0 * kK2 / - (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); 
+ kM0 * kN0 * kK2 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); static constexpr index_t Gemm3MFMA = kN0 * kQKHeaddimForGemmN * kM0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); @@ -1838,13 +1836,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy kM0 * kQKHeaddimForGemmN / get_warp_size() / GetTransposedAlignmentQ(); static constexpr index_t SGradT_LDS_READ_P1 = kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); - static constexpr index_t Q_LDS_READ = - kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ(); + static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ(); static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t SGradT_LDS_READ_P2 = kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); static constexpr index_t OGrad_LDS_READ = - kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad(); + kM0 * kK2 / kBlockSize / GetAlignmentOGrad(); static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); // LDS Write