[CK_TILE] FMHA BWD Fix Decode Accuracy (#2881)

* [CK_TILE] FMHA BWD Fix Decode Accuracy

* use s_waitcnt utils
This commit is contained in:
Yi DING
2025-09-19 21:45:02 +08:00
committed by GitHub
parent 86dd59cd01
commit 6cf3fdd21c

View File

@@ -489,7 +489,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
move_tile_window(k_dram_window, {kN0, 0});
async_load_tile(v_lds_write_window, v_dram_window);
move_tile_window(v_dram_window, {kN0, 0});
// __builtin_amdgcn_s_waitcnt(0);
s_waitcnt</*vmcnt=*/0>();
k_reg_tensor = load_tile(k_lds_read_window);
v_reg_tensor = load_tile(v_lds_read_window);
kt_reg_tensor = load_tile_transpose(kt_lds_read_window);
@@ -636,7 +636,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
}
}();
store_tile(bias_lds_write_window, dbias);
__builtin_amdgcn_s_waitcnt(3952);
s_waitcnt</*vmcnt=*/0>();
block_sync_lds();
auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
@@ -664,7 +664,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
}
store_tile(ds_lds_window, ds_gemm);
}
__builtin_amdgcn_s_waitcnt(3952);
s_waitcnt</*vmcnt=*/0>();
block_sync_lds();
if constexpr(is_epilogue)
{