// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

// Host-side declarations for the FMHA backward pass: data-type configurations, the runtime
// argument struct (fmha_bwd_args), helpers that translate that struct into per-kernel argument
// objects plus a launch grid, and declarations of the kernel entry points and public API.
//
// NOTE(review): in this copy of the file every angle-bracketed span is missing (header names
// after #include, template parameter lists, specialization arguments, and the arguments of
// GenericAttentionMask / remove_cvref_t / std::variant / std::pair). The affected places are
// marked below; restore them from the authoritative source before building.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "bias.hpp"

// NOTE(review): the three header names below are missing. This file uses std::string,
// std::variant, std::pair and assert, so the missing headers presumably provide some of those
// — confirm against the authoritative source.
#include
#include
#include

// Empty tag types, one per supported element precision (fp32 / fp16 / bf16).
struct FmhaBwdFp32
{
};

struct FmhaBwdFp16
{
};

struct FmhaBwdBf16
{
};

// Primary template of the type-configuration trait; only the specializations below are defined.
// NOTE(review): template parameter list missing — presumably a single type parameter that
// receives one of the tag types above; confirm.
template struct FmhaBwdTypeConfig;

// Configuration in which every tensor element type is float.
// NOTE(review): specialization argument missing — by declaration order and element type this
// is presumably the FmhaBwdFp32 specialization; confirm.
template <>
struct FmhaBwdTypeConfig
{
    using QDataType             = float;
    using KDataType             = float;
    using VDataType             = float;
    using GemmDataType          = float;
    using BiasDataType          = float;
    using LSEDataType           = float;
    using AccDataType           = float; // data type for gemm accumulation
    using DDataType             = float;
    using RandValOutputDataType = uint8_t;
    using ODataType             = float;
    using OGradDataType         = float;
    using QGradDataType         = float;
    using KGradDataType         = float;
    using VGradDataType         = float;
    using BiasGradDataType      = float;
};

// Configuration using ck_tile::half_t for the tensor and gradient element types; LSE, Acc and D
// remain float and RandValOutputDataType remains uint8_t.
// NOTE(review): specialization argument missing — presumably FmhaBwdFp16; confirm.
template <>
struct FmhaBwdTypeConfig
{
    using QDataType             = ck_tile::half_t;
    using KDataType             = ck_tile::half_t;
    using VDataType             = ck_tile::half_t;
    using GemmDataType          = ck_tile::half_t;
    using BiasDataType          = ck_tile::half_t;
    using LSEDataType           = float;
    using AccDataType           = float; // data type for gemm accumulation
    using DDataType             = float;
    using RandValOutputDataType = uint8_t;
    using ODataType             = ck_tile::half_t;
    using OGradDataType         = ck_tile::half_t;
    using QGradDataType         = ck_tile::half_t;
    using KGradDataType         = ck_tile::half_t;
    using VGradDataType         = ck_tile::half_t;
    using BiasGradDataType      = ck_tile::half_t;
};

// Configuration using ck_tile::bf16_t for the tensor and gradient element types; LSE, Acc and D
// remain float and RandValOutputDataType remains uint8_t.
// NOTE(review): specialization argument missing — presumably FmhaBwdBf16; confirm.
template <>
struct FmhaBwdTypeConfig
{
    using QDataType             = ck_tile::bf16_t;
    using KDataType             = ck_tile::bf16_t;
    using VDataType             = ck_tile::bf16_t;
    using GemmDataType          = ck_tile::bf16_t;
    using BiasDataType          = ck_tile::bf16_t;
    using LSEDataType           = float;
    using AccDataType           = float; // data type for gemm accumulation
    using DDataType             = float;
    using RandValOutputDataType = uint8_t;
    using ODataType             = ck_tile::bf16_t;
    using OGradDataType         = ck_tile::bf16_t;
    using QGradDataType         = ck_tile::bf16_t;
    using KGradDataType         = ck_tile::bf16_t;
    using VGradDataType         = ck_tile::bf16_t;
    using BiasGradDataType      = ck_tile::bf16_t;
};

// Named aliases for the attention-mask variants.
// NOTE(review): the template arguments of GenericAttentionMask are missing, so the three aliases
// read as identical here; each presumably carries different arguments — restore from upstream.
struct FmhaMasks
{
    using NoMask      = ck_tile::GenericAttentionMask;
    using GenericMask = ck_tile::GenericAttentionMask;
    using CausalMask  = ck_tile::GenericAttentionMask;
};

// Runtime arguments. Some are passed through to the kernel arguments (kargs); others are used
// to compute the launch grids/blocks.
struct fmha_bwd_args
{
    // Read-only (const) buffers.
    const void* q_ptr;
    const void* k_ptr;
    const void* v_ptr;
    const void* bias_ptr; // bias or alibi_slope pointer
    const void* o_ptr;
    const void* lse_ptr;
    const void* do_ptr;
    // Writable (non-const) buffers.
    void* d_ptr;
    void* rand_val_ptr;
    void* dq_ptr;
    void* dk_ptr;
    void* dv_ptr;
    void* dbias_ptr;
    void* dq_acc_ptr;

    // Usage notes for sequence length pointer parameters:
    //
    // "Group mode" vs "Batch mode", as used by the *_create_kargs_and_grids helpers in this
    // file: kernels whose kIsGroupMode is true are given the seqstart_* / seqlen_*_ptr /
    // cu_seqlen_* pointer arrays and no batch strides; kernels whose kIsGroupMode is false are
    // given the scalar seqlen_q / seqlen_k and the batch_stride_* values instead.
    //
    // With padding:
    //   Group mode:
    //     - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding)
    //       sequence lengths. [array size: batch + 1]
    //     - seqlen_q_ptr/seqlen_k_ptr: Record the logical (excluding padding) length of each
    //       sequence. [array size: batch]
    //     - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Record cumulative logical (excluding padding)
    //       sequence lengths. [array size: batch + 1]
    //     - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
    //       exclusive. Use one set, not both.
    //
    //   Batch mode:
    //     - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Record cumulative logical (excluding padding)
    //       sequence lengths. [array size: batch + 1]
    //     - seqstart_* and seqlen_* pointers must be nullptr.
    //     NOTE(review): the batch-mode branches of the *_create_kargs_and_grids helpers in this
    //     file do not forward cu_seqlen_q_ptr/cu_seqlen_k_ptr to the kernels — confirm whether
    //     batch mode with padding is actually honoured by the backward path.
    //
    // Without padding:
    // (Note: Physical length equals logical length)
    //
    //   Group mode:
    //     - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths.
    //       [array size: batch + 1]
    //     - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
    //
    //   Batch mode:
    //     - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
    //

    // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
    const void* seqstart_q_ptr = nullptr;
    // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
    const void* seqstart_k_ptr = nullptr;
    // Per-sequence logical (excluding padding) length array [batch].
    // (Used in Group mode with padding)
    const void* seqlen_q_ptr = nullptr;
    // Per-sequence logical (excluding padding) length array [batch].
    // (Used in Group mode with padding)
    const void* seqlen_k_ptr = nullptr;
    // Cumulative logical (excluding padding) sequence length array [batch + 1].
    // (Used with padding)
    const void* cu_seqlen_q_ptr = nullptr;
    // Cumulative logical (excluding padding) sequence length array [batch + 1].
    // (Used with padding)
    const void* cu_seqlen_k_ptr = nullptr;

    // Scalar sequence lengths; forwarded to the kernels only by the batch-mode branches below.
    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_k;
    // batch, nhead_q and max_seqlen_q/max_seqlen_k are the inputs of the GridSize() calls below.
    ck_tile::index_t batch;
    ck_tile::index_t max_seqlen_q;
    ck_tile::index_t max_seqlen_k;
    ck_tile::index_t hdim_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    // Must divide nhead_q evenly (asserted in fmha_bwd_dq_dk_dv_create_kargs_and_grids).
    ck_tile::index_t nhead_k;
    float scale;

    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
    ck_tile::index_t stride_o;
    ck_tile::index_t stride_randval;
    ck_tile::index_t stride_do;
    ck_tile::index_t stride_dq_acc;
    ck_tile::index_t stride_dq;
    ck_tile::index_t stride_dk;
    ck_tile::index_t stride_dv;
    ck_tile::index_t stride_dbias;

    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_bias;
    ck_tile::index_t nhead_stride_o;
    ck_tile::index_t nhead_stride_randval;
    ck_tile::index_t nhead_stride_do;
    ck_tile::index_t nhead_stride_lsed;
    // Declared as long_index_t, unlike the index_t strides around it.
    ck_tile::long_index_t nhead_stride_dq_acc;
    ck_tile::index_t nhead_stride_dq;
    ck_tile::index_t nhead_stride_dk;
    ck_tile::index_t nhead_stride_dv;
    ck_tile::index_t nhead_stride_dbias;

    // Batch strides; forwarded to the kernels only by the batch-mode branches below.
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_bias;
    ck_tile::index_t batch_stride_o;
    ck_tile::index_t batch_stride_randval;
    ck_tile::index_t batch_stride_do;
    ck_tile::index_t batch_stride_lsed;
    // Declared as long_index_t, unlike the index_t strides around it.
    ck_tile::long_index_t batch_stride_dq_acc;
    ck_tile::index_t batch_stride_dq;
    ck_tile::index_t batch_stride_dk;
    ck_tile::index_t batch_stride_dv;
    ck_tile::index_t batch_stride_dbias;
    ck_tile::index_t split_stride_dq_acc;

    ck_tile::index_t window_size_left;
    ck_tile::index_t window_size_right;
    ck_tile::index_t mask_type;
    // p_drop goes to the dq/dk/dv kernel arguments; p_undrop goes to the dot_do_o kernel
    // arguments (see the helpers below).
    float p_drop;
    float p_undrop;
    // NOTE(review): the template arguments of std::variant and of both std::pair alternatives
    // are missing; restore from upstream.
    std::variant, std::pair> drop_seed_offset;
};

// Builds the kernel-argument object and the launch grid for FmhaBwdDQDKDVKernel.
//
// args: runtime arguments (taken by value).
// Returns ck_tile::make_tuple(kargs, grids), where kargs comes from
// FmhaBwdDQDKDVKernel::MakeKargsImpl(...) and grids from
// FmhaBwdDQDKDVKernel::GridSize(batch, nhead_q, max_seqlen_k).
//
// The MakeKargsImpl arguments are positional; their order must match the kernel's signature.
// NOTE(review): template parameter list missing — presumably a single type parameter named
// FmhaBwdDQDKDVKernel, which the body references; confirm.
template auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
{
    // nhead_q / nhead_k is passed to the kernel below, so the division must be exact.
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = [&] {
        // "dq_uss_acc" (sic): true when the kernel's kMaxSeqLenQ is 0. In that case the dq
        // pointer and all dq strides handed to the kernel are those of the dq_acc buffer
        // instead of the dq buffer.
        constexpr bool dq_uss_acc  = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0;
        const auto dq_ptr          = dq_uss_acc ? args.dq_acc_ptr : args.dq_ptr;
        const auto stride_dq       = dq_uss_acc ? args.stride_dq_acc : args.stride_dq;
        const auto nhead_stride_dq = dq_uss_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq;
        // Only consumed by the batch-mode branch.
        const auto batch_stride_dq = dq_uss_acc ? args.batch_stride_dq_acc : args.batch_stride_dq;
        // create group mode kernel arguments
        if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
        {
            // Passes the sequence pointer arrays; no scalar seqlen and no batch strides.
            return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
                                                      args.k_ptr,
                                                      args.v_ptr,
                                                      args.bias_ptr,
                                                      args.lse_ptr,
                                                      args.do_ptr,
                                                      args.d_ptr,
                                                      args.rand_val_ptr,
                                                      args.dk_ptr,
                                                      args.dv_ptr,
                                                      args.dbias_ptr,
                                                      dq_ptr,
                                                      args.seqstart_q_ptr,
                                                      args.seqstart_k_ptr,
                                                      args.seqlen_q_ptr,
                                                      args.seqlen_k_ptr,
                                                      args.cu_seqlen_q_ptr,
                                                      args.cu_seqlen_k_ptr,
                                                      args.hdim_q,
                                                      args.hdim_v,
                                                      args.nhead_q,
                                                      args.nhead_q / args.nhead_k,
                                                      args.scale,
                                                      args.stride_q,
                                                      args.stride_k,
                                                      args.stride_v,
                                                      args.stride_bias,
                                                      args.stride_randval,
                                                      args.stride_do,
                                                      stride_dq,
                                                      args.stride_dk,
                                                      args.stride_dv,
                                                      args.stride_dbias,
                                                      args.nhead_stride_q,
                                                      args.nhead_stride_k,
                                                      args.nhead_stride_v,
                                                      args.nhead_stride_bias,
                                                      args.nhead_stride_randval,
                                                      args.nhead_stride_do,
                                                      args.nhead_stride_lsed,
                                                      nhead_stride_dq,
                                                      args.nhead_stride_dk,
                                                      args.nhead_stride_dv,
                                                      args.nhead_stride_dbias,
                                                      args.split_stride_dq_acc,
                                                      args.window_size_left,
                                                      args.window_size_right,
                                                      args.mask_type,
                                                      args.p_drop,
                                                      args.drop_seed_offset);
        }
        else
        { // create batch mode kernel arguments
            // Passes the scalar seqlen_q/seqlen_k and the batch strides; no pointer arrays.
            return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
                                                      args.k_ptr,
                                                      args.v_ptr,
                                                      args.bias_ptr,
                                                      args.lse_ptr,
                                                      args.do_ptr,
                                                      args.d_ptr,
                                                      args.rand_val_ptr,
                                                      args.dk_ptr,
                                                      args.dv_ptr,
                                                      args.dbias_ptr,
                                                      dq_ptr,
                                                      args.seqlen_q,
                                                      args.seqlen_k,
                                                      args.hdim_q,
                                                      args.hdim_v,
                                                      args.nhead_q,
                                                      args.nhead_q / args.nhead_k,
                                                      args.scale,
                                                      args.stride_q,
                                                      args.stride_k,
                                                      args.stride_v,
                                                      args.stride_bias,
                                                      args.stride_randval,
                                                      args.stride_do,
                                                      stride_dq,
                                                      args.stride_dk,
                                                      args.stride_dv,
                                                      args.stride_dbias,
                                                      args.nhead_stride_q,
                                                      args.nhead_stride_k,
                                                      args.nhead_stride_v,
                                                      args.nhead_stride_bias,
                                                      args.nhead_stride_randval,
                                                      args.nhead_stride_do,
                                                      args.nhead_stride_lsed,
                                                      nhead_stride_dq,
                                                      args.nhead_stride_dk,
                                                      args.nhead_stride_dv,
                                                      args.nhead_stride_dbias,
                                                      args.batch_stride_q,
                                                      args.batch_stride_k,
                                                      args.batch_stride_v,
                                                      args.batch_stride_bias,
                                                      args.batch_stride_randval,
                                                      args.batch_stride_do,
                                                      args.batch_stride_lsed,
                                                      batch_stride_dq,
                                                      args.batch_stride_dk,
                                                      args.batch_stride_dv,
                                                      args.batch_stride_dbias,
                                                      args.split_stride_dq_acc,
                                                      args.window_size_left,
                                                      args.window_size_right,
                                                      args.mask_type,
                                                      args.p_drop,
                                                      args.drop_seed_offset);
        }
    }();

    // This grid is sized from max_seqlen_k; the other two helpers use max_seqlen_q.
    dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k);
    return ck_tile::make_tuple(kargs, grids);
}

// Builds the kernel-argument object and the launch grid for FmhaBwdOGradDotOKernel.
//
// args: runtime arguments (taken by value).
// Returns ck_tile::make_tuple(kargs, grids), where kargs comes from
// FmhaBwdOGradDotOKernel::MakeKargs(...) and grids from
// FmhaBwdOGradDotOKernel::GridSize(batch, nhead_q, max_seqlen_q).
//
// The kernel is given o_ptr, do_ptr and d_ptr plus p_undrop; nhead_stride_lsed and
// batch_stride_lsed are the only strides passed besides those of do and o.
// NOTE(review): template parameter list missing — presumably a single type parameter named
// FmhaBwdOGradDotOKernel, which the body references; confirm.
template auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
{
    auto kargs = [&] {
        // create group mode kernel arguments
        if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode)
        {
            // Only the q-side sequence pointers are passed here.
            return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
                                                     args.do_ptr,
                                                     args.d_ptr,
                                                     args.p_undrop,
                                                     args.seqstart_q_ptr,
                                                     args.seqlen_q_ptr,
                                                     args.cu_seqlen_q_ptr,
                                                     args.hdim_v,
                                                     args.stride_do,
                                                     args.stride_o,
                                                     args.nhead_stride_do,
                                                     args.nhead_stride_o,
                                                     args.nhead_stride_lsed);
        }
        else
        { // create batch mode kernel arguments
            return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
                                                     args.do_ptr,
                                                     args.d_ptr,
                                                     args.p_undrop,
                                                     args.seqlen_q,
                                                     args.hdim_v,
                                                     args.stride_do,
                                                     args.stride_o,
                                                     args.nhead_stride_do,
                                                     args.nhead_stride_o,
                                                     args.nhead_stride_lsed,
                                                     args.batch_stride_do,
                                                     args.batch_stride_o,
                                                     args.batch_stride_lsed);
        }
    }();

    dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
    return ck_tile::make_tuple(kargs, grids);
}

// Builds the kernel-argument object and the launch grid for FmhaBwdConvertQGradKernel.
//
// args: runtime arguments (taken by value).
// Returns ck_tile::make_tuple(kargs, grids), where kargs comes from
// FmhaBwdConvertQGradKernel::MakeKargs(...) and grids from
// FmhaBwdConvertQGradKernel::GridSize(batch, nhead_q, max_seqlen_q).
//
// The kernel is given both dq_acc_ptr and dq_ptr together with both sets of dq strides.
// NOTE(review): template parameter list missing — presumably a single type parameter named
// FmhaBwdConvertQGradKernel, which the body references; confirm.
template auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
{
    auto kargs = [&] {
        // create group mode kernel arguments
        if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
        {
            return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
                                                        args.dq_ptr,
                                                        args.seqstart_q_ptr,
                                                        args.seqstart_k_ptr,
                                                        args.seqlen_q_ptr,
                                                        args.seqlen_k_ptr,
                                                        args.cu_seqlen_q_ptr,
                                                        args.cu_seqlen_k_ptr,
                                                        args.hdim_q,
                                                        args.stride_dq,
                                                        args.stride_dq_acc,
                                                        args.nhead_stride_dq,
                                                        args.nhead_stride_dq_acc,
                                                        args.split_stride_dq_acc);
        }
        else
        { // create batch mode kernel arguments
            return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
                                                        args.dq_ptr,
                                                        args.seqlen_q,
                                                        args.seqlen_k,
                                                        args.hdim_q,
                                                        args.stride_dq,
                                                        args.stride_dq_acc,
                                                        args.nhead_stride_dq,
                                                        args.nhead_stride_dq_acc,
                                                        args.batch_stride_dq,
                                                        args.batch_stride_dq_acc,
                                                        args.split_stride_dq_acc);
        }
    }();

    dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
    return ck_tile::make_tuple(kargs, grids);
}

// This is used to pattern-match the internal kernel implementation, not to instantiate a kernel.
// NOTE(review): template parameter list missing; the struct body is empty, so the parameters
// cannot be recovered from this file — restore from upstream.
template struct fmha_bwd_dq_dk_dv_traits_
{
};

// Entry points for the dq/dk/dv kernel. Only declarations are visible here; no definition
// appears in this file.
// NOTE(review): template parameter lists missing on all four declarations. The float return of
// fmha_bwd_dq_dk_dv_ is presumably a measured time — confirm against the definitions.
template float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);

template void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);

template std::string fmha_bwd_dq_dk_dv_get_name_();

template int fmha_bwd_dq_dk_dv_maxq_();

// Compile-time traits for the dot_do_o kernel; each member re-exports the corresponding
// template parameter (HDim_, DataType_, kIsGroupMode_, kPadS_, kPadDv_).
// NOTE(review): template parameter list missing, as is the argument of remove_cvref_t
// (presumably DataType_, judging by the naming of the other members — confirm).
template struct fmha_bwd_dot_do_o_traits_
{
    static constexpr ck_tile::index_t HDim = HDim_;
    using DataType                         = ck_tile::remove_cvref_t;
    static constexpr bool kIsGroupMode     = kIsGroupMode_;
    static constexpr bool kPadS            = kPadS_;
    static constexpr bool kPadDv           = kPadDv_;
};

// Entry points for the dot_do_o kernel (declarations only).
// NOTE(review): template parameter lists missing.
template float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);

template void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);

template std::string fmha_bwd_dot_do_o_get_name_();

// Traits tag for the convert_dq kernel.
// NOTE(review): template parameter list missing; the struct body is empty, so the parameters
// cannot be recovered from this file — restore from upstream.
template struct fmha_bwd_convert_dq_traits_
{
};

// Entry points for the convert_dq kernel (declarations only).
// NOTE(review): template parameter lists missing.
template float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);

template void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);

template std::string fmha_bwd_convert_dq_get_name_();

// This is the public API, will be generated by script
// Runtime selector describing which kernel configuration the caller wants.
struct fmha_bwd_traits
{
    int hdim_q;
    int hdim_v;
    std::string data_type;
    bool is_group_mode;
    mask_enum mask_type;
    bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
    bool has_dbias;
    bool has_dropout;
    bool is_store_randval;
    bool is_deterministic;
    // TODO: padding check is inside this api
};

// Public entry point (declaration only; no definition is visible in this file).
// NOTE(review): template parameter list missing; restore from upstream. The float return is
// presumably a measured time — confirm against the definition.
template float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);