mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
more strides for fa integration
This commit is contained in:
@@ -496,6 +496,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_q, // stride_dq_acc
|
||||
stride_q, // stride_dq
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
@@ -508,6 +509,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_q, // nhead_stride_dq_acc
|
||||
nhead_stride_q, // nhead_stride_dq
|
||||
nhead_stride_k, // nhead_stride_dk
|
||||
nhead_stride_v, // nhead_stride_dv
|
||||
nhead_stride_dbias,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
@@ -518,6 +522,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
batch_stride_q, // batch_stride_dq_acc
|
||||
batch_stride_q, // batch_stride_dq
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
batch_stride_dbias,
|
||||
|
||||
@@ -99,6 +99,7 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_do;
|
||||
ck_tile::index_t stride_dq_acc;
|
||||
ck_tile::index_t stride_dq;
|
||||
ck_tile::index_t stride_dk;
|
||||
ck_tile::index_t stride_dv;
|
||||
ck_tile::index_t stride_dbias;
|
||||
@@ -111,6 +112,9 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
ck_tile::index_t nhead_stride_dq_acc;
|
||||
ck_tile::index_t nhead_stride_dq;
|
||||
ck_tile::index_t nhead_stride_dk;
|
||||
ck_tile::index_t nhead_stride_dv;
|
||||
ck_tile::index_t nhead_stride_dbias;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
@@ -121,6 +125,7 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
ck_tile::index_t batch_stride_dq_acc;
|
||||
ck_tile::index_t batch_stride_dq;
|
||||
ck_tile::index_t batch_stride_dk;
|
||||
ck_tile::index_t batch_stride_dv;
|
||||
ck_tile::index_t batch_stride_dbias;
|
||||
@@ -179,6 +184,8 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.batch_stride_lsed,
|
||||
args.split_stride_dq_acc,
|
||||
@@ -227,6 +234,8 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
@@ -307,9 +316,9 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.hdim_q,
|
||||
- args.stride_q,
|
||||
+ args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
- args.nhead_stride_q,
|
||||
+ args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
}
|
||||
@@ -320,11 +329,11 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
- args.stride_q,
|
||||
+ args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
- args.nhead_stride_q,
|
||||
+ args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
- args.batch_stride_q,
|
||||
+ args.batch_stride_dq,
|
||||
args.batch_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
}
|
||||
|
||||
@@ -147,6 +147,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
ck_tile::index_t nhead_stride_dq_acc;
|
||||
ck_tile::index_t nhead_stride_dk;
|
||||
ck_tile::index_t nhead_stride_dv;
|
||||
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
};
|
||||
@@ -301,6 +303,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
@@ -350,6 +354,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_dk,
|
||||
nhead_stride_dv,
|
||||
batch_stride_lsed}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dbias
|
||||
@@ -452,6 +458,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
ck_tile::index_t batch_stride_lsed,
|
||||
ck_tile::index_t split_stride_dq_acc,
|
||||
@@ -491,6 +499,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_dk,
|
||||
nhead_stride_dv,
|
||||
batch_stride_lsed}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dbias
|
||||
@@ -687,10 +697,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
|
||||
batch_offset_do;
|
||||
KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
|
||||
- static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_k +
|
||||
+ static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk +
|
||||
batch_offset_dk;
|
||||
VGradDataType* dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
|
||||
- static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_v +
|
||||
+ static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv +
|
||||
batch_offset_dv;
|
||||
|
||||
// Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
|
||||
|
||||
Reference in New Issue
Block a user