Update for xformers (#2372)

* update api

* update kernel api

* clang-format
This commit is contained in:
Max Podkorytov
2025-06-22 00:28:30 -07:00
committed by GitHub
parent cebdee4d9e
commit 0366fb2abc
2 changed files with 18 additions and 2 deletions

View File

@@ -847,6 +847,7 @@ struct FmhaFwdKernel
window_size_left,
window_size_right,
mask_type,
0, // min_seqlen_q
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));

View File

@@ -28,6 +28,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
@@ -54,6 +55,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
@@ -127,7 +129,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const QElementFunction& q_element_func,
@@ -146,6 +150,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& /* unused */,
const AttentionVariantParams& /* unused */,
const BlockIndices& /* unused */,
void* smem_ptr,
DropoutType& dropout) const
{
@@ -890,7 +897,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
@@ -901,6 +910,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout) const
{
@@ -921,6 +933,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout);
}