mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK_TILE] Fix compiler related FA bwd issues (#1530)
* add barriers * tail bias barriers * adjust bf16/hd256 tol * continue adjust bf16/hd256 tol
This commit is contained in:
@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
},
|
||||
s_acc,
|
||||
bias_s_tile);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<1>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dp_acc = SPGradBlockTileType{};
|
||||
@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<2>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
auto ds = SPGradBlockTileType{};
|
||||
@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbias_tile, shuffled_dbias_tile);
|
||||
store_tile(dbias_dram_window, dbias_tile);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
move_tile_window(ds_lds_read_window, {0, kK4});
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<3>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc);
|
||||
@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
});
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<4>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Results Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
|
||||
Reference in New Issue
Block a user