// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"

#include "mask.hpp"
#include "bias.hpp"

#include <type_traits>
#include <utility>
#include <variant>

struct FmhaBwdFp32
{
};

struct FmhaBwdFp16
{
};

struct FmhaBwdBf16
{
};

template <typename DataType>
struct FmhaBwdTypeConfig;

template <>
struct FmhaBwdTypeConfig<FmhaBwdFp32>
{
    using QDataType             = float;
    using KDataType             = float;
    using VDataType             = float;
    using GemmDataType          = float;
    using BiasDataType          = float;
    using LSEDataType           = float;
    using AccDataType           = float; // data type for gemm accumulation
    using DDataType             = float;
    using RandValOutputDataType = uint8_t;
    using ODataType             = float;
    using OGradDataType         = float;
    using QGradDataType         = float;
    using KGradDataType         = float;
    using VGradDataType         = float;
    using BiasGradDataType      = float;
};

template <>
struct FmhaBwdTypeConfig<FmhaBwdFp16>
{
    using QDataType             = ck_tile::half_t;
    using KDataType             = ck_tile::half_t;
    using VDataType             = ck_tile::half_t;
    using GemmDataType          = ck_tile::half_t;
    using BiasDataType          = ck_tile::half_t;
    using LSEDataType           = float;
    using AccDataType           = float; // data type for gemm accumulation
    using DDataType             = float;
    using RandValOutputDataType = uint8_t;
    using ODataType             = ck_tile::half_t;
    using OGradDataType         = ck_tile::half_t;
    using QGradDataType         = ck_tile::half_t;
    using KGradDataType         = ck_tile::half_t;
    using VGradDataType         = ck_tile::half_t;
    using BiasGradDataType      = ck_tile::half_t;
};

template <>
struct FmhaBwdTypeConfig<FmhaBwdBf16>
{
    using QDataType             = ck_tile::bf16_t;
    using KDataType             = ck_tile::bf16_t;
    using VDataType             = ck_tile::bf16_t;
    using GemmDataType          = ck_tile::bf16_t;
    using BiasDataType          = ck_tile::bf16_t;
    using LSEDataType           = float;
    using AccDataType           = float; // data type for gemm accumulation
    using DDataType             = float;
    using RandValOutputDataType = uint8_t;
    using ODataType             = ck_tile::bf16_t;
    using OGradDataType         = ck_tile::bf16_t;
    using QGradDataType         = ck_tile::bf16_t;
    using KGradDataType         = ck_tile::bf16_t;
    using VGradDataType         = ck_tile::bf16_t;
    using BiasGradDataType      = ck_tile::bf16_t;
};
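
// Example (illustrative only): kernel code can pull concrete element types
// from a config specialization, e.g.
//
//   using TypeCfg     = FmhaBwdTypeConfig<FmhaBwdFp16>;
//   using QDataType   = typename TypeCfg::QDataType;   // ck_tile::half_t
//   using AccDataType = typename TypeCfg::AccDataType; // float accumulation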

struct FmhaMasks
{
    using NoMask      = ck_tile::GenericAttentionMask<false>;
    using GenericMask = ck_tile::GenericAttentionMask<true, true>;
    using CausalMask  = ck_tile::GenericAttentionMask<true, false>;
};
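
// Illustrative mapping from runtime mask settings to these alias types (the
// actual dispatch lives in the generated kernel-selection code; mask_enum is
// defined in mask.hpp):
//
//   no masking                       -> FmhaMasks::NoMask
//   pure causal masking              -> FmhaMasks::CausalMask
//   sliding-window / generic masking -> FmhaMasks::GenericMask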

// Runtime args; some are passed to the kernel args (kargs), some are used to
// compute the launch grids/blocks.
struct fmha_bwd_args
{
    const void* q_ptr;
    const void* k_ptr;
    const void* v_ptr;
    const void* bias_ptr; // bias or alibi_slope pointer
    const void* o_ptr;
    const void* lse_ptr;
    const void* do_ptr;
    void* d_ptr;
    void* rand_val_ptr;
    void* dq_ptr;
    void* dk_ptr;
    void* dv_ptr;
    void* dbias_ptr;
    void* dq_acc_ptr;

    // Usage notes for sequence length pointer parameters:
    //
    // Group mode packs variable-length sequences contiguously and locates each
    // sequence via the seqstart_* prefix-sum arrays; Batch mode assumes one
    // fixed-size sequence slot per batch entry, addressed via batch_stride_*.
    //
    // With padding:
    //   Group mode:
    //     - seqstart_q_ptr, seqstart_k_ptr: record cumulative physical (including
    //       padding) sequence lengths. [array size: batch + 1]
    //     - seqlen_q_ptr/seqlen_k_ptr: record the logical (excluding padding)
    //       length of each sequence. [array size: batch]
    //     - cu_seqlen_q_ptr/cu_seqlen_k_ptr: record cumulative logical (excluding
    //       padding) sequence lengths. [array size: batch + 1]
    //     - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical)
    //       are mutually exclusive. Use one set, not both.
    //
    //   Batch mode:
    //     - cu_seqlen_q_ptr/cu_seqlen_k_ptr: record cumulative logical (excluding
    //       padding) sequence lengths. [array size: batch + 1]
    //     - seqstart_* and seqlen_* pointers must be nullptr.
    //
    // Without padding (physical length equals logical length):
    //   Group mode:
    //     - seqstart_q_ptr, seqstart_k_ptr: record cumulative physical sequence
    //       lengths. [array size: batch + 1]
    //     - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be
    //       nullptr.
    //
    //   Batch mode:
    //     - all sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must
    //       be nullptr.
    //
    const void* seqstart_q_ptr =
        nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
    const void* seqstart_k_ptr =
        nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
    const void* seqlen_q_ptr = nullptr;    // Per-sequence logical (excluding padding) length array
                                           // [batch]. (Used in Group mode with padding)
    const void* seqlen_k_ptr = nullptr;    // Per-sequence logical (excluding padding) length array
                                           // [batch]. (Used in Group mode with padding)
    const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
                                           // array [batch + 1]. (Used with padding)
    const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
                                           // array [batch + 1]. (Used with padding)
    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_k;
    ck_tile::index_t batch;
    ck_tile::index_t max_seqlen_q;
    ck_tile::index_t max_seqlen_k;
    ck_tile::index_t hdim_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k;
    float scale;
    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_bias; // for alibi slopes: set to h for a per-batch [b, h]
                                  // slope layout, or 0 for a shared [1, h] layout
    ck_tile::index_t stride_o;
    ck_tile::index_t stride_randval;
    ck_tile::index_t stride_do;
    ck_tile::index_t stride_dq_acc;
    ck_tile::index_t stride_dq;
    ck_tile::index_t stride_dk;
    ck_tile::index_t stride_dv;
    ck_tile::index_t stride_dbias;
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_bias;
    ck_tile::index_t nhead_stride_o;
    ck_tile::index_t nhead_stride_randval;
    ck_tile::index_t nhead_stride_do;
    ck_tile::index_t nhead_stride_lsed;
    ck_tile::index_t nhead_stride_dq_acc;
    ck_tile::index_t nhead_stride_dq;
    ck_tile::index_t nhead_stride_dk;
    ck_tile::index_t nhead_stride_dv;
    ck_tile::index_t nhead_stride_dbias;
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_bias;
    ck_tile::index_t batch_stride_o;
    ck_tile::index_t batch_stride_randval;
    ck_tile::index_t batch_stride_do;
    ck_tile::index_t batch_stride_lsed;
    ck_tile::index_t batch_stride_dq_acc;
    ck_tile::index_t batch_stride_dq;
    ck_tile::index_t batch_stride_dk;
    ck_tile::index_t batch_stride_dv;
    ck_tile::index_t batch_stride_dbias;
    ck_tile::index_t split_stride_dq_acc;
    ck_tile::index_t window_size_left;
    ck_tile::index_t window_size_right;
    ck_tile::index_t mask_type;
    float p_drop;
    float p_undrop;
    std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
        drop_seed_offset;
};
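
// A minimal batch-mode argument-filling sketch for contiguous row-major
// [b, s, h, d] tensors (device pointers such as q_dev/dq_acc_dev are
// hypothetical; group mode would additionally set the seqstart_*/seqlen_*
// pointers described above):
//
//   fmha_bwd_args args{};
//   args.q_ptr      = q_dev;
//   args.k_ptr      = k_dev;
//   args.v_ptr      = v_dev;
//   args.o_ptr      = o_dev;
//   args.lse_ptr    = lse_dev;
//   args.do_ptr     = do_dev;
//   args.d_ptr      = d_dev;
//   args.dq_ptr     = dq_dev;
//   args.dk_ptr     = dk_dev;
//   args.dv_ptr     = dv_dev;
//   args.dq_acc_ptr = dq_acc_dev;
//   args.batch = b; args.nhead_q = h; args.nhead_k = h; // nhead_k < h for GQA/MQA
//   args.seqlen_q = s; args.seqlen_k = s;
//   args.max_seqlen_q = s; args.max_seqlen_k = s;
//   args.hdim_q = d; args.hdim_v = d;
//   args.scale = 1.0f / std::sqrt(static_cast<float>(d));
//   // strides for the contiguous [b, s, h, d] layout:
//   args.stride_q = h * d; args.nhead_stride_q = d; args.batch_stride_q = s * h * d;
//   // ... fill the remaining stride/dropout/mask fields analogously.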

template <typename FmhaBwdDQDKDVKernel>
auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = [&] {
        // select the dQ destination: kernel variants with kMaxSeqLenQ == 0 write
        // into the fp32 dq_acc buffer (converted to dq afterwards by the
        // convert-dq kernel); otherwise they write dq directly
        constexpr bool dq_use_acc  = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0;
        const auto dq_ptr          = dq_use_acc ? args.dq_acc_ptr : args.dq_ptr;
        const auto stride_dq       = dq_use_acc ? args.stride_dq_acc : args.stride_dq;
        const auto nhead_stride_dq = dq_use_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq;
        const auto batch_stride_dq = dq_use_acc ? args.batch_stride_dq_acc : args.batch_stride_dq;

        // create group mode kernel arguments
        if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
        {
            return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
                                                      args.k_ptr,
                                                      args.v_ptr,
                                                      args.bias_ptr,
                                                      args.lse_ptr,
                                                      args.do_ptr,
                                                      args.d_ptr,
                                                      args.rand_val_ptr,
                                                      args.dk_ptr,
                                                      args.dv_ptr,
                                                      args.dbias_ptr,
                                                      dq_ptr,
                                                      args.seqstart_q_ptr,
                                                      args.seqstart_k_ptr,
                                                      args.seqlen_q_ptr,
                                                      args.seqlen_k_ptr,
                                                      args.cu_seqlen_q_ptr,
                                                      args.cu_seqlen_k_ptr,
                                                      args.hdim_q,
                                                      args.hdim_v,
                                                      args.nhead_q,
                                                      args.nhead_q / args.nhead_k,
                                                      args.scale,
                                                      args.stride_q,
                                                      args.stride_k,
                                                      args.stride_v,
                                                      args.stride_bias,
                                                      args.stride_randval,
                                                      args.stride_do,
                                                      stride_dq,
                                                      args.stride_dk,
                                                      args.stride_dv,
                                                      args.stride_dbias,
                                                      args.nhead_stride_q,
                                                      args.nhead_stride_k,
                                                      args.nhead_stride_v,
                                                      args.nhead_stride_bias,
                                                      args.nhead_stride_randval,
                                                      args.nhead_stride_do,
                                                      args.nhead_stride_lsed,
                                                      nhead_stride_dq,
                                                      args.nhead_stride_dk,
                                                      args.nhead_stride_dv,
                                                      args.nhead_stride_dbias,
                                                      args.split_stride_dq_acc,
                                                      args.window_size_left,
                                                      args.window_size_right,
                                                      args.mask_type,
                                                      args.p_drop,
                                                      args.drop_seed_offset);
        }
        else
        { // create batch mode kernel arguments
            return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
                                                      args.k_ptr,
                                                      args.v_ptr,
                                                      args.bias_ptr,
                                                      args.lse_ptr,
                                                      args.do_ptr,
                                                      args.d_ptr,
                                                      args.rand_val_ptr,
                                                      args.dk_ptr,
                                                      args.dv_ptr,
                                                      args.dbias_ptr,
                                                      dq_ptr,
                                                      args.seqlen_q,
                                                      args.seqlen_k,
                                                      args.hdim_q,
                                                      args.hdim_v,
                                                      args.nhead_q,
                                                      args.nhead_q / args.nhead_k,
                                                      args.scale,
                                                      args.stride_q,
                                                      args.stride_k,
                                                      args.stride_v,
                                                      args.stride_bias,
                                                      args.stride_randval,
                                                      args.stride_do,
                                                      stride_dq,
                                                      args.stride_dk,
                                                      args.stride_dv,
                                                      args.stride_dbias,
                                                      args.nhead_stride_q,
                                                      args.nhead_stride_k,
                                                      args.nhead_stride_v,
                                                      args.nhead_stride_bias,
                                                      args.nhead_stride_randval,
                                                      args.nhead_stride_do,
                                                      args.nhead_stride_lsed,
                                                      nhead_stride_dq,
                                                      args.nhead_stride_dk,
                                                      args.nhead_stride_dv,
                                                      args.nhead_stride_dbias,
                                                      args.batch_stride_q,
                                                      args.batch_stride_k,
                                                      args.batch_stride_v,
                                                      args.batch_stride_bias,
                                                      args.batch_stride_randval,
                                                      args.batch_stride_do,
                                                      args.batch_stride_lsed,
                                                      batch_stride_dq,
                                                      args.batch_stride_dk,
                                                      args.batch_stride_dv,
                                                      args.batch_stride_dbias,
                                                      args.split_stride_dq_acc,
                                                      args.window_size_left,
                                                      args.window_size_right,
                                                      args.mask_type,
                                                      args.p_drop,
                                                      args.drop_seed_offset);
        }
    }();

    dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k);
    return ck_tile::make_tuple(kargs, grids);
}
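
// A minimal launch sketch for the tuple returned above (the concrete kernel
// type is a placeholder, and the exact ck_tile launch-helper signatures may
// differ across versions):
//
//   using Kernel = /* a concrete FmhaBwdDQDKDVKernel instantiation */;
//   auto [kargs, grids]   = fmha_bwd_dq_dk_dv_create_kargs_and_grids<Kernel>(args);
//   constexpr dim3 blocks = Kernel::BlockSize();
//   float time = ck_tile::launch_kernel(
//       stream_config,
//       ck_tile::make_kernel<blocks.x, Kernel::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));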

template <typename FmhaBwdOGradDotOKernel>
auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
{
    auto kargs = [&] {
        // create group mode kernel arguments
        if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode)
        {
            return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
                                                     args.do_ptr,
                                                     args.d_ptr,
                                                     args.p_undrop,
                                                     args.seqstart_q_ptr,
                                                     args.seqlen_q_ptr,
                                                     args.cu_seqlen_q_ptr,
                                                     args.hdim_v,
                                                     args.stride_do,
                                                     args.stride_o,
                                                     args.nhead_stride_do,
                                                     args.nhead_stride_o,
                                                     args.nhead_stride_lsed);
        }
        else
        { // create batch mode kernel arguments
            return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
                                                     args.do_ptr,
                                                     args.d_ptr,
                                                     args.p_undrop,
                                                     args.seqlen_q,
                                                     args.hdim_v,
                                                     args.stride_do,
                                                     args.stride_o,
                                                     args.nhead_stride_do,
                                                     args.nhead_stride_o,
                                                     args.nhead_stride_lsed,
                                                     args.batch_stride_do,
                                                     args.batch_stride_o,
                                                     args.batch_stride_lsed);
        }
    }();

    dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
    return ck_tile::make_tuple(kargs, grids);
}

template <typename FmhaBwdConvertQGradKernel>
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
{
    auto kargs = [&] {
        // create group mode kernel arguments
        if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
        {
            return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
                                                        args.dq_ptr,
                                                        args.seqstart_q_ptr,
                                                        args.seqstart_k_ptr,
                                                        args.seqlen_q_ptr,
                                                        args.seqlen_k_ptr,
                                                        args.cu_seqlen_q_ptr,
                                                        args.cu_seqlen_k_ptr,
                                                        args.hdim_q,
                                                        args.stride_dq,
                                                        args.stride_dq_acc,
                                                        args.nhead_stride_dq,
                                                        args.nhead_stride_dq_acc,
                                                        args.split_stride_dq_acc);
        }
        else
        { // create batch mode kernel arguments
            return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
                                                        args.dq_ptr,
                                                        args.seqlen_q,
                                                        args.seqlen_k,
                                                        args.hdim_q,
                                                        args.stride_dq,
                                                        args.stride_dq_acc,
                                                        args.nhead_stride_dq,
                                                        args.nhead_stride_dq_acc,
                                                        args.batch_stride_dq,
                                                        args.batch_stride_dq_acc,
                                                        args.split_stride_dq_acc);
        }
    }();

    dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
    return ck_tile::make_tuple(kargs, grids);
}
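
// The backward pass is a three-kernel pipeline; a sketch of the launch order
// (each *_create_kargs_and_grids above feeds the matching kernel):
//
//   1. FmhaBwdOGradDotOKernel   : computes D = rowsum(dO * O) into d_ptr
//   2. FmhaBwdDQDKDVKernel      : main backward, producing dK/dV (and dQ or dq_acc)
//   3. FmhaBwdConvertQGradKernel: converts the fp32 dq_acc buffer into dq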

// This trait is used to pattern-match the internal kernel implementation, not
// to instantiate a kernel.
template <ck_tile::index_t HDim_,
          typename DataType_,
          bool kIsGroupMode_,
          typename FmhaMask_,
          typename FmhaDropout_,
          ck_tile::BlockAttentionBiasEnum BiasEnum_,
          bool kHasBiasGrad_,
          ck_tile::index_t kPadD_,
          ck_tile::index_t kPadDv_,
          bool kIsDeterministic_,
          bool kUseTrLoad_,
          ck_tile::index_t MaxSeqLenQ_,
          ck_tile::index_t kN0>
struct fmha_bwd_dq_dk_dv_traits_
{
};

template <typename Traits_, typename Arch = void>
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);

template <typename Traits_, typename Arch = void>
void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);

template <typename Traits_, typename Arch = void>
std::string fmha_bwd_dq_dk_dv_get_name_();

template <typename Traits_, typename Arch = void>
int fmha_bwd_dq_dk_dv_maxq_();

template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
struct fmha_bwd_dot_do_o_traits_
{
    static constexpr ck_tile::index_t HDim = HDim_;
    using DataType                         = ck_tile::remove_cvref_t<DataType_>;
    static constexpr bool kIsGroupMode     = kIsGroupMode_;
    static constexpr bool kPadS            = kPadS_;
    static constexpr bool kPadDv           = kPadDv_;
};

template <typename Traits_, typename Arch = void>
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);

template <typename Traits_, typename Arch = void>
void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);

template <typename Traits_, typename Arch = void>
std::string fmha_bwd_dot_do_o_get_name_();

template <ck_tile::index_t HDim_,
          typename DataType_,
          bool kIsGroupMode_,
          bool kPadS_,
          bool kPadD_,
          bool kIsDeterministic_,
          ck_tile::index_t kN0>
struct fmha_bwd_convert_dq_traits_
{
};

template <typename Traits_, typename Arch = void>
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);

template <typename Traits_, typename Arch = void>
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);

template <typename Traits_, typename Arch = void>
std::string fmha_bwd_convert_dq_get_name_();

// This is the public API; its implementation will be generated by script.
struct fmha_bwd_traits
{
    int hdim_q;
    int hdim_v;
    std::string data_type;
    bool is_group_mode;
    mask_enum mask_type;
    bias_enum bias_type; // 0: no bias, 1: elementwise bias, 2: alibi;
                         // keep in sync with BlockAttentionBiasEnum
    bool has_dbias;
    bool has_dropout;
    bool is_store_randval;
    bool is_deterministic;
    // TODO: padding check is done inside this API
};

template <int Version = 2>
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
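
// A minimal end-to-end usage sketch (field values are placeholders; the
// generated dispatch selects a kernel matching the traits, and by convention
// in this example a negative return value indicates no supported kernel):
//
//   fmha_bwd_traits traits{};
//   traits.hdim_q = 128; traits.hdim_v = 128;
//   traits.data_type        = "fp16";
//   traits.is_group_mode    = false;
//   traits.mask_type        = mask_enum::no_mask;
//   traits.bias_type        = bias_enum::no_bias;
//   traits.has_dbias        = false;
//   traits.has_dropout      = false;
//   traits.is_store_randval = false;
//   traits.is_deterministic = false;
//   float t = fmha_bwd(traits, args, ck_tile::stream_config{});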