mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
[CK_TILE] Fix compiler related FA bwd issues (#1530)
* add barriers
* tail bias barriers
* adjust bf16/hd256 tol
* continue adjust bf16/hd256 tol
[ROCm/composable_kernel commit: 9d69a099a4]
This commit is contained in:
@@ -99,13 +99,26 @@ auto create_args(int argc, char* argv[])
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(int /*init_method*/)
|
||||
auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
if(hdim_q > 128 && hdim_v > 128) // 3.2 for RTZ/1.5 for RTN
|
||||
{
|
||||
rtol = 3.2e-2;
|
||||
atol = 3.2e-2;
|
||||
}
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
@@ -899,7 +912,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>(init_method);
|
||||
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
|
||||
bool dq_cur_pass = ck_tile::check_err(dq_host_result,
|
||||
dq_host_ref,
|
||||
std::string("Error: QGrad Incorrect results!"),
|
||||
|
||||
@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
},
|
||||
s_acc,
|
||||
bias_s_tile);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<1>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dp_acc = SPGradBlockTileType{};
|
||||
@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<2>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
auto ds = SPGradBlockTileType{};
|
||||
@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbias_tile, shuffled_dbias_tile);
|
||||
store_tile(dbias_dram_window, dbias_tile);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
move_tile_window(ds_lds_read_window, {0, kK4});
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<3>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc);
|
||||
@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
});
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<4>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Results Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
|
||||
Reference in New Issue
Block a user