[CK_TILE][FMHA][Feature] Add support for large hdim

* root cause: fhma_bwd not support if hdim > 256 due to the use of LDS goes beyond the hardware limitations.

* solution: 1. split dqdkdv kernel into 2 kernels.
*              1) QGrad
*              2) KGrad & VGrad
*           2. reuse LDS memory.
*              1). K and K^T use same LDS memory in dq kernel
*              2). OGrad and OGrad^T use same LDS memory in dq kernel
*           3. to avoid or reduce the number of VGPR spills, the calculation order has been readjusted, and prefetch has been disabled.
This commit is contained in:
jian.wu
2025-08-12 10:53:44 +08:00
parent 1824d65758
commit 2bbff45dcb
6 changed files with 1236 additions and 47 deletions

View File

@@ -712,7 +712,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kKPerBlock = Problem::kQKHeaddim;
constexpr index_t K1 = 16 / sizeof(AccDataType);
constexpr index_t K1 = 32 / sizeof(AccDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
@@ -1930,13 +1930,44 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t smem_size_ds = GetSmemSizeSGrad<Problem>();
constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
constexpr index_t smem_size_stage0_1 = smem_size_v;
constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot +
smem_size_do + smem_size_lse + smem_size_d +
max(smem_size_bias, smem_size_ds);
if constexpr (Problem::BlockFmhaShape::kQKHeaddim > 256 && Problem::BlockFmhaShape::kVHeaddim > 256)
{
// kernel0: dq
// LDS layout
// | leading stage | leading stage | loop stage
// | K(K^T) | V | Q
// | | | OGrad
// | | | LSE
// | | | D
// | | | Bias
// | | | SGrad
// kernel1: dk & dv
// LDS layout
// | leading stage | leading stage | loop stage
// | K | V | Q
// | | | Q^T
// | | | OGrad(OGrad^T)
// | | | LSE
// | | | D
// | | | Bias
//
// Note:
// A(B) mean A and B use same LDS
return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1);
constexpr index_t smem_size_kernel0 = max(max(smem_size_k, smem_size_v), smem_size_q + smem_size_do + smem_size_lse + smem_size_d + smem_size_bias + smem_size_ds);
constexpr index_t smem_size_kernel1 = max(max(smem_size_k, smem_size_v), smem_size_q + smem_size_qt + smem_size_do + smem_size_lse + smem_size_d + smem_size_bias);
return max(smem_size_kernel0, smem_size_kernel1);
}
else
{
constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
constexpr index_t smem_size_stage0_1 = smem_size_v;
constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot +
smem_size_do + smem_size_lse + smem_size_d +
max(smem_size_bias, smem_size_ds);
return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1);
}
}
template <typename Problem_>