diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 49681015a9..626e20cde2 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -496,6 +496,7 @@ bool run(const ck_tile::ArgParser& arg_parser) stride_randval, stride_do, stride_q, // stride_dq_acc + stride_q, // stride_dq stride_dk, stride_dv, stride_dbias, @@ -508,6 +509,9 @@ bool run(const ck_tile::ArgParser& arg_parser) nhead_stride_do, nhead_stride_lsed, nhead_stride_q, // nhead_stride_dq_acc + nhead_stride_q, // nhead_stride_dq + nhead_stride_k, // nhead_stride_dk + nhead_stride_v, // nhead_stride_dv nhead_stride_dbias, batch_stride_q, batch_stride_k, @@ -518,6 +522,7 @@ bool run(const ck_tile::ArgParser& arg_parser) batch_stride_do, batch_stride_lsed, batch_stride_q, // batch_stride_dq_acc + batch_stride_q, // batch_stride_dq batch_stride_dk, batch_stride_dv, batch_stride_dbias, diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 2f5af175d7..bb28034982 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -99,6 +99,7 @@ struct fmha_bwd_args ck_tile::index_t stride_randval; ck_tile::index_t stride_do; ck_tile::index_t stride_dq_acc; + ck_tile::index_t stride_dq; ck_tile::index_t stride_dk; ck_tile::index_t stride_dv; ck_tile::index_t stride_dbias; @@ -111,6 +112,9 @@ struct fmha_bwd_args ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_lsed; ck_tile::index_t nhead_stride_dq_acc; + ck_tile::index_t nhead_stride_dq; + ck_tile::index_t nhead_stride_dk; + ck_tile::index_t nhead_stride_dv; ck_tile::index_t nhead_stride_dbias; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -121,6 +125,7 @@ struct fmha_bwd_args ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_lsed; ck_tile::index_t batch_stride_dq_acc; + ck_tile::index_t batch_stride_dq; ck_tile::index_t batch_stride_dk; ck_tile::index_t batch_stride_dv; ck_tile::index_t batch_stride_dbias; @@ -179,6 +184,8 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.nhead_stride_do, args.nhead_stride_lsed, args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, args.nhead_stride_dbias, args.batch_stride_lsed, args.split_stride_dq_acc, @@ -227,6 +234,8 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.nhead_stride_do, args.nhead_stride_lsed, args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, args.nhead_stride_dbias, args.batch_stride_q, args.batch_stride_k, @@ -307,9 +316,9 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) args.seqstart_q_ptr, args.seqstart_k_ptr, args.hdim_q, - args.stride_q, + args.stride_dq, args.stride_dq_acc, - args.nhead_stride_q, + args.nhead_stride_dq, args.nhead_stride_dq_acc, args.split_stride_dq_acc); } @@ -320,11 +329,11 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) args.seqlen_q, args.seqlen_k, args.hdim_q, - args.stride_q, + args.stride_dq, args.stride_dq_acc, - args.nhead_stride_q, + args.nhead_stride_dq, args.nhead_stride_dq_acc, - args.batch_stride_q, + args.batch_stride_dq, args.batch_stride_dq_acc, args.split_stride_dq_acc); } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 1c6401d493..6e3983a90f 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -147,6 +147,8 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_lsed; ck_tile::index_t nhead_stride_dq_acc; + ck_tile::index_t nhead_stride_dk; + ck_tile::index_t nhead_stride_dv; ck_tile::index_t batch_stride_lsed; }; @@ -301,6 +303,8 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, + ck_tile::index_t nhead_stride_dk, + ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, @@ -350,6 +354,8 @@ struct FmhaBwdDQDKDVKernel nhead_stride_do, nhead_stride_lsed, nhead_stride_dq_acc, + nhead_stride_dk, + nhead_stride_dv, batch_stride_lsed}, // args for common karg {}, // placeholder for bias {}, // placeholder for dbias @@ -452,6 +458,8 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, + ck_tile::index_t nhead_stride_dk, + ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t batch_stride_lsed, ck_tile::index_t split_stride_dq_acc, @@ -491,6 +499,8 @@ struct FmhaBwdDQDKDVKernel nhead_stride_do, nhead_stride_lsed, nhead_stride_dq_acc, + nhead_stride_dk, + nhead_stride_dv, batch_stride_lsed}, // args for common karg {}, // placeholder for bias {}, // placeholder for dbias @@ -687,10 +697,10 @@ struct FmhaBwdDQDKDVKernel static_cast(i_nhead) * kargs.nhead_stride_do + batch_offset_do; KGradDataType* dk_ptr = reinterpret_cast(kargs.dk_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_k + + static_cast(i_nhead) * kargs.nhead_stride_dk + batch_offset_dk; VGradDataType* dv_ptr = reinterpret_cast(kargs.dv_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_v + + static_cast(i_nhead) * kargs.nhead_stride_dv + batch_offset_dv; // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window