mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE][FMHA] Integrate FAv2 & FAv3 (WIP) in the single fmha_fwd() API (#3153)
* Let fmha_fwd_v3() compatible with fmha_fwd() * Decouple get_fwd_blobs() and FmhaFwdKernel * Decouple compatibility checks from get_fwd_blobs() * Extract product feature checks out from get_fwd_blobs() * Remove duplicated code in factories and redundant checks * Remove FmhaFwdKernel<>::GetName() * Let FmhaFwdApiPool support pipelines with different mask_impl * Add tile setting for fmha fwd v3 pipeline * Add fwd v3 instances to tile_example_fmha_fwd manually * Remove unused function import * Undo irrelevant changes * Remove fwd v3 instances from tile_example_fmha_fwd * Finish fmha fwd v3 kernel instance codegen * Fix formatting * Remove unused F_idx attribute * Add is_generic_attention_mask<> traits * Add constraints to the fmha fwd v3 pipeline * Unify traits & problem used for fmha fwd v3 * Unify kernel launch code for fmha fwd v2 & v3 * Unify kernel template selection logic * Use same kernel codegen template for both v2 & v3 * Rename api() property as render() method * Allow specifying filter for fmha fwd api pool * Allow specifying function name when rendering api pool items * Separate fmha fwd v3 kernel dispatching logic from v2 * Remove lambda assignment * Add simple v2/v3 dispatch logic * Stop generating empty if-clauses Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them. * Use "".join() to concatenate fmha fwd api string content * Add more feature checks for fmha fwd v3 pipeline * Check features before dispatch to fmha_fwd_v3() * Add more feature checks for fmha_fwd_v3() * Add missing filter call * Use Tuple to reserve the dtype orders * Fix wrong pipeline matching logic * Add fmha fwd v3 group mode instances * Add functor_transform<> * Add type constraints to make_tile_window() * Remove fmha fwd v3 example * Fix wrong product(aiter mha_fwd()) config * Fix wrong fmha fwd v2/v3 selection logic * Fix formatting * Add comment to warning v3 kernel users * Fix wrong codegen logics * Remove unnecessary param * Fix format --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -600,6 +600,19 @@ struct SimplifiedRatioAttentionMask
|
||||
mdiv y_ratio_mdiv;
|
||||
};
|
||||
|
||||
template <typename>
|
||||
struct is_generic_attention_mask : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <bool IsMasking, bool IsLocal>
|
||||
struct is_generic_attention_mask<GenericAttentionMask<IsMasking, IsLocal>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Mask>
|
||||
static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask<Mask>::value;
|
||||
|
||||
// TODO: prefer use this function in host code
|
||||
// can convert from the FA style left/right to our generic coordinate
|
||||
// if left_size < 0 && right_size = 0, it is normal causal mask
|
||||
|
||||
@@ -73,54 +73,6 @@ struct FmhaFwdKernel
|
||||
#endif
|
||||
static constexpr std::string_view kPipelineName = FmhaPipeline::name;
|
||||
|
||||
// clang-format off
|
||||
template <typename T1, typename T2 = T1> struct t2s;
|
||||
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
|
||||
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
|
||||
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
template <> struct t2s<ck_tile::fp8_t, ck_tile::bf16_t> { static constexpr const char * name = "fp8bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t, ck_tile::fp32_t> { static constexpr const char * name = "fp8fp32"; };
|
||||
// clang-format on
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
// sync with generate.py
|
||||
// clang-format off
|
||||
using bfs = typename FmhaPipeline::BlockFmhaShape;
|
||||
using g0br = typename bfs::Gemm0BlockWarps;
|
||||
using g1br = typename bfs::Gemm1BlockWarps;
|
||||
using g0wt = typename bfs::Gemm0WarpTile;
|
||||
using g1wt = typename bfs::Gemm1WarpTile;
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
auto pn = [&] () {
|
||||
std::string n;
|
||||
if (kPadSeqLenQ) n += "s";
|
||||
if (kPadSeqLenK) n += "sk";
|
||||
if (kPadHeadDimQ) n += "d";
|
||||
if (kPadHeadDimV) n += "dv";
|
||||
return n.empty() ? n : std::string("p") + n; }();
|
||||
return
|
||||
_SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType, ODataType>::name) +
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_"
|
||||
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
|
||||
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
|
||||
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
|
||||
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
|
||||
"w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
|
||||
"w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
|
||||
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) +
|
||||
(QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? _SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr<QScaleEnum>::name)) + (kUseTrLoad ? "_trload" : "_ntrload");
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
|
||||
// arg
|
||||
struct FmhaFwdEmptyKargs
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// NOTICE: This kernel is a work in progress and is awaiting upcoming compiler fixes and
|
||||
/// instruction scheduling optimizations.
|
||||
template <typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
struct FmhaFwdV3Kernel
|
||||
{
|
||||
@@ -103,8 +105,8 @@ struct FmhaFwdV3Kernel
|
||||
|
||||
// Optional cumulative sequence length pointers for batch mode
|
||||
// If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
|
||||
const ck_tile::index_t* cu_seqlen_k_ptr = nullptr; // [batch+1]
|
||||
};
|
||||
|
||||
struct FmhaFwdGroupModeKargs
|
||||
@@ -114,12 +116,13 @@ struct FmhaFwdV3Kernel
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_q_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
// Optional cumulative padded sequence starts (including PAD tokens)
|
||||
// Used solely to compute memory offsets when sequences are physically padded.
|
||||
const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1]
|
||||
const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1]
|
||||
const int32_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
|
||||
const int32_t* cu_seqlen_k_ptr = nullptr; // [batch+1]
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
|
||||
@@ -156,8 +159,8 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t remap_opt,
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -199,8 +202,8 @@ struct FmhaFwdV3Kernel
|
||||
kargs.batch_stride_lse = batch_stride_lse;
|
||||
}
|
||||
|
||||
kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
|
||||
kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
|
||||
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
|
||||
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
|
||||
return kargs;
|
||||
}
|
||||
|
||||
@@ -213,6 +216,7 @@ struct FmhaFwdV3Kernel
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
@@ -232,8 +236,8 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t remap_opt,
|
||||
const void* seqstart_padded_q_ptr = nullptr,
|
||||
const void* seqstart_padded_k_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -258,6 +262,7 @@ struct FmhaFwdV3Kernel
|
||||
{}, // placeholder for lse
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
|
||||
if constexpr(kHasMask)
|
||||
@@ -273,30 +278,29 @@ struct FmhaFwdV3Kernel
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
}
|
||||
|
||||
kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
|
||||
kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
|
||||
kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
|
||||
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
|
||||
return kargs;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
if constexpr(kHasMask)
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
return dim3(nhead_,
|
||||
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
|
||||
batch_size_);
|
||||
return dim3(nhead,
|
||||
batch_size,
|
||||
ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(nhead_,
|
||||
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
|
||||
batch_size_);
|
||||
return dim3(nhead,
|
||||
ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1),
|
||||
batch_size);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -344,13 +348,20 @@ struct FmhaFwdV3Kernel
|
||||
// FmhaPipeline::kN1);
|
||||
|
||||
// assume that num_tile_n1 is always 1
|
||||
if constexpr(kHasMask)
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
const index_t i_nhead = blockIdx.x;
|
||||
const index_t i_block = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
const index_t i_batch = blockIdx.y;
|
||||
const index_t i_block = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
return ck_tile::make_tuple(gridDim.z - 1 - i_block, 0, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -358,7 +369,14 @@ struct FmhaFwdV3Kernel
|
||||
const index_t i_block = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -390,32 +408,36 @@ struct FmhaFwdV3Kernel
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
|
||||
// Use seqstart_q_ptr and seqstart_k_ptr for physical starts
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
|
||||
? kargs.seqstart_padded_q_ptr[i_batch]
|
||||
: query_start_unpadded;
|
||||
const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
|
||||
? kargs.seqstart_padded_k_ptr[i_batch]
|
||||
: key_start_unpadded;
|
||||
|
||||
batch_offset_q = query_start_padded * kargs.stride_q;
|
||||
batch_offset_k = key_start_padded * kargs.stride_k;
|
||||
batch_offset_v = key_start_padded * kargs.stride_v;
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
// LSE layout is [nhead, total_seqlen], index by unpadded start
|
||||
batch_offset_lse = query_start_unpadded;
|
||||
batch_offset_lse = query_start;
|
||||
}
|
||||
batch_offset_o = query_start_padded * kargs.stride_o;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
// real logical lengths (exclude PAD)
|
||||
// Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
|
||||
if(kargs.seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else if(kargs.cu_seqlen_q_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
|
||||
}
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
@@ -427,10 +449,14 @@ struct FmhaFwdV3Kernel
|
||||
{
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else if(kargs.cu_seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
|
||||
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
|
||||
kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -450,10 +476,10 @@ struct FmhaFwdV3Kernel
|
||||
kargs.seqlen_q =
|
||||
kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
|
||||
}
|
||||
if(kargs.cu_seqlen_kv_ptr != nullptr)
|
||||
if(kargs.cu_seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k =
|
||||
kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
|
||||
kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
@@ -246,6 +248,8 @@ CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
/// NOTICE: This pipeline is a work in progress and is awaiting upcoming compiler fixes and
|
||||
/// instruction scheduling optimizations.
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
|
||||
struct BlockFmhaFwdV3Pipeline
|
||||
{
|
||||
@@ -261,12 +265,16 @@ struct BlockFmhaFwdV3Pipeline
|
||||
using OaccDataType = ck_tile::remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename Problem::FmhaMask>;
|
||||
static_assert(is_generic_attention_mask_v<FmhaMask>);
|
||||
|
||||
static_assert(std::is_same_v<SaccDataType, SMPLComputeDataType>,
|
||||
"we will the same dist tensor 'sp_compute' for both gemm0 & softmax");
|
||||
|
||||
using BlockFmhaShape = ck_tile::remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
|
||||
@@ -277,14 +285,24 @@ struct BlockFmhaFwdV3Pipeline
|
||||
static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
static_assert(kQKHeaddim == 128 && kSubQKHeaddim == 128, "only supports hdim=hdim_v=128");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ;
|
||||
static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS &&
|
||||
!kStoreLSE && !kHasDropout &&
|
||||
(QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) &&
|
||||
!kSkipMinSeqlenQ),
|
||||
"enable unsupported features");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
|
||||
@@ -12,6 +12,7 @@ enum class BlockFmhaPipelineEnum
|
||||
QRKSVS_ASYNC,
|
||||
QSKSVS,
|
||||
QRKSVS_ASYNC_TRLOAD,
|
||||
QRKSVS_ASYNC_TRLOAD_V3,
|
||||
};
|
||||
|
||||
template <BlockFmhaPipelineEnum>
|
||||
|
||||
@@ -264,47 +264,4 @@ struct BlockFmhaFwdAppendKVPipelineProblem
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename LSEDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
typename ODataType_,
|
||||
typename BlockFmhaShape_,
|
||||
bool kIsGroupMode_,
|
||||
typename FmhaMask_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaFwdV3PipelineProblem
|
||||
{
|
||||
using QDataType = remove_cvref_t<QDataType_>;
|
||||
using KDataType = remove_cvref_t<KDataType_>;
|
||||
using VDataType = remove_cvref_t<VDataType_>;
|
||||
using SaccDataType = remove_cvref_t<SaccDataType_>;
|
||||
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
|
||||
using LSEDataType = remove_cvref_t<LSEDataType_>;
|
||||
using PDataType = remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = remove_cvref_t<OaccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
|
||||
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -166,20 +166,4 @@ struct TileFmhaBwdConvertQGradTraits
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
bool kStoreLSE_,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileFmhaFwdV3Traits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
|
||||
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr bool kStoreLSE = kStoreLSE_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user