mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 05:55:39 +00:00
Fix in bwd_piple_default_policy
This commit is contained in:
@@ -394,7 +394,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -853,7 +853,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
|
||||
|
||||
constexpr index_t kVPack = GetSmemKPackV<Problem>();
|
||||
|
||||
@@ -1808,14 +1808,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
|
||||
// Compute
|
||||
static constexpr index_t Gemm0MFMA =
|
||||
kM0 * kN0 * kK0 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
kM0 * kN0 * kK0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm1MFMA =
|
||||
kN0 * kVHeaddimForGemmN * kM0 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm2MFMA =
|
||||
kM0 * kN0 * kK2 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
kM0 * kN0 * kK2 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm3MFMA =
|
||||
kN0 * kQKHeaddimForGemmN * kM0 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
@@ -1838,13 +1836,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
kM0 * kQKHeaddimForGemmN / get_warp_size() / GetTransposedAlignmentQ<Problem>();
|
||||
static constexpr index_t SGradT_LDS_READ_P1 =
|
||||
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
|
||||
static constexpr index_t Q_LDS_READ =
|
||||
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
|
||||
static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ<Problem>();
|
||||
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
|
||||
static constexpr index_t SGradT_LDS_READ_P2 =
|
||||
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
|
||||
static constexpr index_t OGrad_LDS_READ =
|
||||
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
|
||||
|
||||
// LDS Write
|
||||
|
||||
Reference in New Issue
Block a user