mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Fix FA bwd alibi+causal NaN errors (#1352)
* fix bwd alibi nan error
* fix datatype
---------
Co-authored-by: danyao12 <danyao12>
[ROCm/composable_kernel commit: 1da802bdf2]
This commit is contained in:
@@ -372,7 +372,7 @@ struct SimplifiedGenericAttentionMask
|
||||
// index_t x_end = min(i_y + x, x_total);
|
||||
|
||||
bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad
|
||||
bool bottom_left_edge = i_y_end > (i_x + y);
|
||||
bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad
|
||||
// bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
|
||||
|
||||
return top_right_edge || bottom_left_edge;
|
||||
|
||||
@@ -501,9 +501,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsBlockDescriptor()
|
||||
{
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QGradDataType);
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(OGradDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kKPerBlock = [&]() {
|
||||
|
||||
Reference in New Issue
Block a user