mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
[CK_TILE][FMHA][Feature] Add support for large hdim
* root cause: fmha_bwd does not support hdim > 256 because its LDS usage exceeds the hardware limit. * solution: 1. split the dqdkdv kernel into 2 kernels. * 1) QGrad * 2) KGrad & VGrad * 2. reuse LDS memory. * 1). K and K^T use the same LDS memory in the dq kernel * 2). OGrad and OGrad^T use the same LDS memory in the dq kernel * 3. to avoid or reduce VGPR spills, the calculation order has been readjusted and prefetch has been disabled.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -712,7 +712,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kMPerBlock = Problem::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(AccDataType);
|
||||
constexpr index_t K1 = 32 / sizeof(AccDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
@@ -1930,13 +1930,44 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t smem_size_ds = GetSmemSizeSGrad<Problem>();
|
||||
constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
|
||||
|
||||
constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
|
||||
constexpr index_t smem_size_stage0_1 = smem_size_v;
|
||||
constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot +
|
||||
smem_size_do + smem_size_lse + smem_size_d +
|
||||
max(smem_size_bias, smem_size_ds);
|
||||
if constexpr (Problem::BlockFmhaShape::kQKHeaddim > 256 && Problem::BlockFmhaShape::kVHeaddim > 256)
|
||||
{
|
||||
// kernel0: dq
|
||||
// LDS layout
|
||||
// | leading stage | leading stage | loop stage
|
||||
// | K(K^T) | V | Q
|
||||
// | | | OGrad
|
||||
// | | | LSE
|
||||
// | | | D
|
||||
// | | | Bias
|
||||
// | | | SGrad
|
||||
// kernel1: dk & dv
|
||||
// LDS layout
|
||||
// | leading stage | leading stage | loop stage
|
||||
// | K | V | Q
|
||||
// | | | Q^T
|
||||
// | | | OGrad(OGrad^T)
|
||||
// | | | LSE
|
||||
// | | | D
|
||||
// | | | Bias
|
||||
//
|
||||
// Note:
|
||||
// A(B) means A and B share the same LDS
|
||||
|
||||
return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1);
|
||||
constexpr index_t smem_size_kernel0 = max(max(smem_size_k, smem_size_v), smem_size_q + smem_size_do + smem_size_lse + smem_size_d + smem_size_bias + smem_size_ds);
|
||||
constexpr index_t smem_size_kernel1 = max(max(smem_size_k, smem_size_v), smem_size_q + smem_size_qt + smem_size_do + smem_size_lse + smem_size_d + smem_size_bias);
|
||||
return max(smem_size_kernel0, smem_size_kernel1);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
|
||||
constexpr index_t smem_size_stage0_1 = smem_size_v;
|
||||
constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot +
|
||||
smem_size_do + smem_size_lse + smem_size_d +
|
||||
max(smem_size_bias, smem_size_ds);
|
||||
|
||||
return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem_>
|
||||
|
||||
Reference in New Issue
Block a user