// SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" #include "ck_tile/host/device_prop.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/fmha.hpp" #include "ck_tile/ops/epilogue.hpp" #include "mask.hpp" #include "bias.hpp" #include #include #include struct FmhaBwdFp16 { }; struct FmhaBwdBf16 { }; template struct FmhaBwdTypeConfig; template <> struct FmhaBwdTypeConfig { using QDataType = ck_tile::half_t; using KDataType = ck_tile::half_t; using VDataType = ck_tile::half_t; using GemmDataType = ck_tile::half_t; using BiasDataType = ck_tile::half_t; using LSEDataType = float; using AccDataType = float; // data type for gemm accumulation using DDataType = float; using RandValOutputDataType = uint8_t; using ODataType = ck_tile::half_t; using OGradDataType = ck_tile::half_t; using QGradDataType = ck_tile::half_t; using KGradDataType = ck_tile::half_t; using VGradDataType = ck_tile::half_t; using BiasGradDataType = ck_tile::half_t; }; template <> struct FmhaBwdTypeConfig { using QDataType = ck_tile::bf16_t; using KDataType = ck_tile::bf16_t; using VDataType = ck_tile::bf16_t; using GemmDataType = ck_tile::bf16_t; using BiasDataType = ck_tile::bf16_t; using LSEDataType = float; using AccDataType = float; // data type for gemm accumulation using DDataType = float; using RandValOutputDataType = uint8_t; using ODataType = ck_tile::bf16_t; using OGradDataType = ck_tile::bf16_t; using QGradDataType = ck_tile::bf16_t; using KGradDataType = ck_tile::bf16_t; using VGradDataType = ck_tile::bf16_t; using BiasGradDataType = ck_tile::bf16_t; }; struct FmhaMasks { using NoMask = ck_tile::GenericAttentionMask; using GenericMask = ck_tile::GenericAttentionMask; using CausalMask = ck_tile::GenericAttentionMask; }; // runtime args, some will passed to karg, some will used to compute grids/blocks struct fmha_bwd_args { const void* q_ptr; const void* k_ptr; const void* v_ptr; const void* bias_ptr; // bias or alibi_slope pointer const void* o_ptr; const void* lse_ptr; const void* do_ptr; void* d_ptr; void* rand_val_ptr; void* dq_ptr; void* dk_ptr; void* dv_ptr; void* dbias_ptr; void* dq_acc_ptr; const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; ck_tile::index_t max_seqlen_q; ck_tile::index_t max_seqlen_k; ck_tile::index_t hdim_q; ck_tile::index_t hdim_v; ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; float scale; ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 ck_tile::index_t stride_o; ck_tile::index_t stride_randval; ck_tile::index_t stride_do; ck_tile::index_t stride_dq_acc; ck_tile::index_t stride_dq; ck_tile::index_t stride_dk; ck_tile::index_t stride_dv; ck_tile::index_t stride_dbias; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_bias; ck_tile::index_t nhead_stride_o; ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_lsed; ck_tile::index_t nhead_stride_dq_acc; ck_tile::index_t nhead_stride_dq; ck_tile::index_t nhead_stride_dk; ck_tile::index_t nhead_stride_dv; ck_tile::index_t nhead_stride_dbias; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_bias; ck_tile::index_t batch_stride_o; ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_lsed; ck_tile::index_t batch_stride_dq_acc; ck_tile::index_t batch_stride_dq; ck_tile::index_t batch_stride_dk; ck_tile::index_t batch_stride_dv; ck_tile::index_t batch_stride_dbias; ck_tile::index_t split_stride_dq_acc; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; float p_drop; float p_undrop; std::variant, std::pair> drop_seed_offset; }; template auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { constexpr bool dq_uss_acc = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0; const auto dq_ptr = dq_uss_acc ? args.dq_acc_ptr : args.dq_ptr; const auto stride_dq = dq_uss_acc ? args.stride_dq_acc : args.stride_dq; const auto nhead_stride_dq = dq_uss_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq; const auto batch_stride_dq = dq_uss_acc ? args.batch_stride_dq_acc : args.batch_stride_dq; // create group mode kernel arguments if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) { return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, args.k_ptr, args.v_ptr, args.bias_ptr, args.lse_ptr, args.do_ptr, args.d_ptr, args.rand_val_ptr, args.dk_ptr, args.dv_ptr, args.dbias_ptr, dq_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, args.seqlen_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q, args.nhead_q / args.nhead_k, args.scale, args.stride_q, args.stride_k, args.stride_v, args.stride_bias, args.stride_randval, args.stride_do, stride_dq, args.stride_dk, args.stride_dv, args.stride_dbias, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, args.nhead_stride_randval, args.nhead_stride_do, args.nhead_stride_lsed, nhead_stride_dq, args.nhead_stride_dk, args.nhead_stride_dv, args.nhead_stride_dbias, args.split_stride_dq_acc, args.window_size_left, args.window_size_right, args.mask_type, args.p_drop, args.drop_seed_offset); } else { // create batch mode kernel arguments return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, args.k_ptr, args.v_ptr, args.bias_ptr, args.lse_ptr, args.do_ptr, args.d_ptr, args.rand_val_ptr, args.dk_ptr, args.dv_ptr, args.dbias_ptr, dq_ptr, args.seqlen_q, args.seqlen_k, args.hdim_q, args.hdim_v, args.nhead_q, args.nhead_q / args.nhead_k, args.scale, args.stride_q, args.stride_k, args.stride_v, args.stride_bias, args.stride_randval, args.stride_do, stride_dq, args.stride_dk, args.stride_dv, args.stride_dbias, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, args.nhead_stride_randval, args.nhead_stride_do, args.nhead_stride_lsed, nhead_stride_dq, args.nhead_stride_dk, args.nhead_stride_dv, args.nhead_stride_dbias, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, args.batch_stride_bias, args.batch_stride_randval, args.batch_stride_do, args.batch_stride_lsed, batch_stride_dq, args.batch_stride_dk, args.batch_stride_dv, args.batch_stride_dbias, args.split_stride_dq_acc, args.window_size_left, args.window_size_right, args.mask_type, args.p_drop, args.drop_seed_offset); } }(); dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k); return ck_tile::make_tuple(kargs, grids); } template auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) { auto kargs = [&] { // create group mode kernel arguments if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode) { return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, args.do_ptr, args.d_ptr, args.p_undrop, args.seqstart_q_ptr, args.hdim_v, args.stride_do, args.stride_o, args.nhead_stride_do, args.nhead_stride_o, args.nhead_stride_lsed); } else { // create batch mode kernel arguments return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, args.do_ptr, args.d_ptr, args.p_undrop, args.seqlen_q, args.hdim_v, args.stride_do, args.stride_o, args.nhead_stride_do, args.nhead_stride_o, args.nhead_stride_lsed, args.batch_stride_do, args.batch_stride_o, args.batch_stride_lsed); } }(); dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); return ck_tile::make_tuple(kargs, grids); } template auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) { auto kargs = [&] { // create group mode kernel arguments if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode) { return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, args.dq_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, args.hdim_q, args.stride_dq, args.stride_dq_acc, args.nhead_stride_dq, args.nhead_stride_dq_acc, args.split_stride_dq_acc); } else { // create batch mode kernel arguments return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, args.dq_ptr, args.seqlen_q, args.seqlen_k, args.hdim_q, args.stride_dq, args.stride_dq_acc, args.nhead_stride_dq, args.nhead_stride_dq_acc, args.batch_stride_dq, args.batch_stride_dq_acc, args.split_stride_dq_acc); } }(); dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); return ck_tile::make_tuple(kargs, grids); } // this is used to pattern-match internl kernel implementation, not to instantiate kernel template struct fmha_bwd_dq_dk_dv_traits_ { }; template float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args); template void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); template std::string fmha_bwd_dq_dk_dv_get_name_(); template int fmha_bwd_dq_dk_dv_maxq_(); template struct fmha_bwd_dot_do_o_traits_ { static constexpr ck_tile::index_t HDim = HDim_; using DataType = ck_tile::remove_cvref_t; static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kPadS = kPadS_; static constexpr bool kPadDv = kPadDv_; }; template float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args); template void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); template std::string fmha_bwd_dot_do_o_get_name_(); template struct fmha_bwd_convert_dq_traits_ { }; template float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args); template void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); template std::string fmha_bwd_convert_dq_get_name_(); // This is the public API, will be generated by script struct fmha_bwd_traits { int hdim_q; int hdim_v; std::string data_type; bool is_group_mode; mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_dbias; bool has_dropout; bool is_store_randval; bool is_deterministic; // TODO: padding check is inside this api }; template float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);