[CK_TILE] Fix Int32 Overflow in Deterministic FMHA BWD (#3615)

This commit is contained in:
Yi DING
2026-01-21 09:54:46 +08:00
committed by GitHub
parent d5ae81b292
commit fcc9372c00
3 changed files with 20 additions and 20 deletions

View File

@@ -171,7 +171,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::long_index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
};
@@ -294,7 +294,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::long_index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
};
@@ -377,7 +377,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::long_index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
@@ -388,7 +388,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::long_index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
@@ -549,7 +549,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::long_index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
@@ -1574,7 +1574,7 @@ struct FmhaBwdConvertQGradKernel
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::long_index_t nhead_stride_dq_acc;
};
struct FmhaBwdConvertQGradDeterministicKargs
@@ -1589,7 +1589,7 @@ struct FmhaBwdConvertQGradKernel
FmhaBwdConvertQGradEmptyKargs<0>>
{
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::long_index_t batch_stride_dq_acc;
};
struct FmhaBwdConvertQGradGroupModeKargs
@@ -1620,9 +1620,9 @@ struct FmhaBwdConvertQGradKernel
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t nhead_stride_dq,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::long_index_t nhead_stride_dq_acc,
ck_tile::index_t batch_stride_dq,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::long_index_t batch_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc)
{
Kargs kargs{{dq_acc_ptr,
@@ -1660,7 +1660,7 @@ struct FmhaBwdConvertQGradKernel
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t nhead_stride_dq,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::long_index_t nhead_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc)
{
Kargs kargs{{dq_acc_ptr,