mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Fix Int32 Overflow in Deterministic FMHA BWD (#3615)
This commit is contained in:
@@ -171,7 +171,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
ck_tile::index_t nhead_stride_dq_acc;
|
||||
ck_tile::long_index_t nhead_stride_dq_acc;
|
||||
ck_tile::index_t nhead_stride_dk;
|
||||
ck_tile::index_t nhead_stride_dv;
|
||||
};
|
||||
@@ -294,7 +294,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
ck_tile::index_t batch_stride_dq_acc;
|
||||
ck_tile::long_index_t batch_stride_dq_acc;
|
||||
ck_tile::index_t batch_stride_dk;
|
||||
ck_tile::index_t batch_stride_dv;
|
||||
};
|
||||
@@ -377,7 +377,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::long_index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
@@ -388,7 +388,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_do,
|
||||
ck_tile::index_t batch_stride_lsed,
|
||||
ck_tile::index_t batch_stride_dq_acc,
|
||||
ck_tile::long_index_t batch_stride_dq_acc,
|
||||
ck_tile::index_t batch_stride_dk,
|
||||
ck_tile::index_t batch_stride_dv,
|
||||
ck_tile::index_t batch_stride_dbias,
|
||||
@@ -549,7 +549,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::long_index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
@@ -1574,7 +1574,7 @@ struct FmhaBwdConvertQGradKernel
|
||||
ck_tile::index_t stride_dq;
|
||||
ck_tile::index_t stride_dq_acc;
|
||||
ck_tile::index_t nhead_stride_dq;
|
||||
ck_tile::index_t nhead_stride_dq_acc;
|
||||
ck_tile::long_index_t nhead_stride_dq_acc;
|
||||
};
|
||||
|
||||
struct FmhaBwdConvertQGradDeterministicKargs
|
||||
@@ -1589,7 +1589,7 @@ struct FmhaBwdConvertQGradKernel
|
||||
FmhaBwdConvertQGradEmptyKargs<0>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_dq;
|
||||
ck_tile::index_t batch_stride_dq_acc;
|
||||
ck_tile::long_index_t batch_stride_dq_acc;
|
||||
};
|
||||
|
||||
struct FmhaBwdConvertQGradGroupModeKargs
|
||||
@@ -1620,9 +1620,9 @@ struct FmhaBwdConvertQGradKernel
|
||||
ck_tile::index_t stride_dq,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dq,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::long_index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t batch_stride_dq,
|
||||
ck_tile::index_t batch_stride_dq_acc,
|
||||
ck_tile::long_index_t batch_stride_dq_acc,
|
||||
ck_tile::index_t split_stride_dq_acc)
|
||||
{
|
||||
Kargs kargs{{dq_acc_ptr,
|
||||
@@ -1660,7 +1660,7 @@ struct FmhaBwdConvertQGradKernel
|
||||
ck_tile::index_t stride_dq,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dq,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::long_index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t split_stride_dq_acc)
|
||||
{
|
||||
Kargs kargs{{dq_acc_ptr,
|
||||
|
||||
Reference in New Issue
Block a user