mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
add per_token_head quantization to fmha_fwd
This commit is contained in:
@@ -1062,6 +1062,20 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
|
||||
# PER_TOKEN_HEAD: FP8 fine-grained Q/K per-token-per-head + V per-head.
|
||||
# Currently only wired for fp8bf16 + qr_async + hdim=128 on gfx9 (MI300/MI308).
|
||||
# The non-paged fmha_fwd qr_async pipeline gained PER_TOKEN_HEAD support in
|
||||
# block_fmha_pipeline_qr_ks_vs_async.hpp; keep this scope tight until we
|
||||
# also extend qr / qr_async_trload / V3 pipelines.
|
||||
if dtype in cls._DT_FP8BF16 and hdim == 128:
|
||||
for logits, mask, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["f", "t"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, "no", "f", "f", "per_token_head", mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, "no", "f", "f", "per_token_head", mask, "f", "f", sink)) # fmt: skip
|
||||
return pipelines
|
||||
|
||||
|
||||
|
||||
@@ -305,6 +305,41 @@ struct FmhaFwdKernel
|
||||
const int32_t* seqstart_v_scale_ptr;
|
||||
};
|
||||
|
||||
// PER_TOKEN_HEAD descale layout (non-paged fmha_fwd):
|
||||
// q_descale: [B, total_q, nhead_q] fp32 (batch+token row stride, head col stride)
|
||||
// k_descale: [B, total_k, nhead_k] fp32
|
||||
// v_descale: [nhead_k] fp32 (per-head only; no batch / token dim)
|
||||
// For varlen (group) mode the batch dim is collapsed into total_q / total_k and the
|
||||
// per-batch offset is derived from cu_seqlens * stride_*_descale instead of
|
||||
// batch_stride_*_descale.
|
||||
// Kernel-argument fragment selected when QScaleEnum == PER_TOKEN_HEAD; it is the
// common base of the batch-mode and group-mode PER_TOKEN_HEAD kargs structs.
// NOTE(review): the q/k/v_descale_ptr members assigned in MakeKargs are presumably
// inherited from FmhaFwdCommonQScaleKargs -- confirm against its definition.
struct FmhaFwdCommonPerTokenHeadKargs : FmhaFwdCommonQScaleKargs
|
||||
{
|
||||
// Per-token (row) strides, in elements, of the Q / K descale buffers. The kernel
// uses them for the group-mode batch offset (query_start / key_start * stride) and
// forwards them to the pipeline, which indexes descale_ptr[token * stride].
// V is per-head only, so there is no per-token stride_v_descale field.
|
||||
ck_tile::index_t stride_q_descale;
|
||||
ck_tile::index_t stride_k_descale;
|
||||
|
||||
// Per-head strides, in elements. The kernel multiplies the Q stride by the Q head
// index and the K / V strides by the KV head index (i_nhead_ / nhead_ratio_qk).
ck_tile::index_t nhead_stride_q_descale;
|
||||
ck_tile::index_t nhead_stride_k_descale;
|
||||
ck_tile::index_t nhead_stride_v_descale;
|
||||
|
||||
// Unused under PER_TOKEN_HEAD but the qr_async pipeline takes it positionally;
|
||||
// keep it as a no-op field rather than baking a constant into the call site.
|
||||
ck_tile::index_t block_scale_size_kv = 128;
|
||||
};
|
||||
|
||||
// Batch-mode PER_TOKEN_HEAD kargs: adds the per-batch strides (in elements) that
// the kernel multiplies by i_batch to reach a batch's slice of each descale buffer.
struct FmhaFwdBatchPerTokenHeadKargs : FmhaFwdCommonPerTokenHeadKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_q_descale;
|
||||
ck_tile::index_t batch_stride_k_descale;
|
||||
// V descale has no batch dimension; a stride of 0 makes the per-batch V offset 0.
ck_tile::index_t batch_stride_v_descale; // callers should set this to 0 (V is per-head)
|
||||
};
|
||||
|
||||
// Group-mode (varlen) PER_TOKEN_HEAD kargs. Intentionally adds no members: the
// kernel computes the Q / K descale batch offsets as query_start / key_start times
// stride_*_descale and uses a V descale batch offset of 0.
struct FmhaFwdGroupPerTokenHeadKargs : FmhaFwdCommonPerTokenHeadKargs
|
||||
{
|
||||
// group mode resolves per-batch offsets from cu_seqlens * stride_*_descale,
|
||||
// so no dedicated batch_stride_*_descale is needed here.
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonLSEKargs
|
||||
{
|
||||
void* lse_ptr = nullptr;
|
||||
@@ -383,11 +418,16 @@ struct FmhaFwdKernel
|
||||
std::conditional_t<
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
|
||||
FmhaFwdBatchBlockScaleKargs,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::MX,
|
||||
FmhaFwdBatchMXKargs,
|
||||
FmhaFwdEmptyKargs<3>>>>,
|
||||
std::conditional_t<
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
|
||||
FmhaFwdBatchBlockScaleKargs,
|
||||
std::conditional_t<
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::MX,
|
||||
FmhaFwdBatchMXKargs,
|
||||
std::conditional_t<QScaleEnum ==
|
||||
BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD,
|
||||
FmhaFwdBatchPerTokenHeadKargs,
|
||||
FmhaFwdEmptyKargs<3>>>>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
{
|
||||
@@ -414,11 +454,16 @@ struct FmhaFwdKernel
|
||||
std::conditional_t<
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
|
||||
FmhaFwdGroupBlockScaleKargs,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::MX,
|
||||
FmhaFwdGroupMXKargs,
|
||||
FmhaFwdEmptyKargs<3>>>>,
|
||||
std::conditional_t<
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
|
||||
FmhaFwdGroupBlockScaleKargs,
|
||||
std::conditional_t<
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::MX,
|
||||
FmhaFwdGroupMXKargs,
|
||||
std::conditional_t<QScaleEnum ==
|
||||
BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD,
|
||||
FmhaFwdGroupPerTokenHeadKargs,
|
||||
FmhaFwdEmptyKargs<3>>>>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
|
||||
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
|
||||
@@ -612,6 +657,28 @@ struct FmhaFwdKernel
|
||||
kargs.batch_stride_k_descale = batch_stride_k_descale;
|
||||
kargs.batch_stride_v_descale = batch_stride_v_descale;
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
// PER_TOKEN_HEAD (non-paged fmha_fwd, batch mode) descale layout; the batch dimension is addressed separately through batch_stride_*_descale:
|
||||
// q_descale: [total_q, nhead_q] fp32 (per-token, per-head)
|
||||
// k_descale: [total_k, nhead_k] fp32 (per-token, per-head)
|
||||
// v_descale: [nhead_k] fp32 (per-head only)
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
|
||||
// per-token row strides (V is per-head scalar so stride_v_descale is unused).
|
||||
kargs.stride_q_descale = stride_q_descale;
|
||||
kargs.stride_k_descale = stride_k_descale;
|
||||
|
||||
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
kargs.batch_stride_q_descale = batch_stride_q_descale;
|
||||
kargs.batch_stride_k_descale = batch_stride_k_descale;
|
||||
kargs.batch_stride_v_descale = batch_stride_v_descale;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -1065,6 +1132,25 @@ struct FmhaFwdKernel
|
||||
|
||||
kargs.seqstart_v_scale_ptr = reinterpret_cast<const int32_t*>(seqstart_v_scale_ptr);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
// PER_TOKEN_HEAD (non-paged fmha_fwd, group mode) descale layout:
|
||||
// q_descale: [total_q, nhead_q] fp32 (per-token, per-head)
|
||||
// k_descale: [total_k, nhead_k] fp32 (per-token, per-head)
|
||||
// v_descale: [nhead_k] fp32 (per-head only)
|
||||
// Per-batch offsets are derived from query_start/key_start in the kernel,
|
||||
// so we don't need batch_stride_*_descale here (group mode).
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
|
||||
kargs.stride_q_descale = stride_q_descale;
|
||||
kargs.stride_k_descale = stride_k_descale;
|
||||
|
||||
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -1574,6 +1660,16 @@ struct FmhaFwdKernel
|
||||
batch_offset_k_descale = key_start * kargs.stride_k_descale;
|
||||
batch_offset_v_descale = kargs.seqstart_v_scale_ptr[i_batch];
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
// PER_TOKEN_HEAD descales: Q/K per-token-per-head, V per-head only.
|
||||
// q_descale: [total_q, nhead_q] -> per-batch offset = query_start * row_stride
|
||||
// k_descale: [total_k, nhead_k] -> per-batch offset = key_start * row_stride
|
||||
// v_descale: [nhead_k] -> no per-batch offset (shared across batches)
|
||||
batch_offset_q_descale = query_start * kargs.stride_q_descale;
|
||||
batch_offset_k_descale = key_start * kargs.stride_k_descale;
|
||||
batch_offset_v_descale = 0;
|
||||
}
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
// real logical lengths (exclude PAD)
|
||||
@@ -1642,8 +1738,12 @@ struct FmhaFwdKernel
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE ||
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::MX)
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::MX ||
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
// For PER_TOKEN_HEAD V is per-head only (no per-batch dimension);
|
||||
// callers should set kargs.batch_stride_v_descale = 0 in that case
|
||||
// so this expression evaluates to 0 here.
|
||||
batch_offset_q_descale =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q_descale;
|
||||
batch_offset_k_descale =
|
||||
@@ -2121,6 +2221,66 @@ struct FmhaFwdKernel
|
||||
make_null_tile_window(make_tuple()),
|
||||
sink_value);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
// PER_TOKEN_HEAD: Q/K per-(token, head), V per-head (FP8 fine-grained).
|
||||
// Resolve descale base ptrs to (batch, head); the pipeline indexes
|
||||
// per-token offsets within head using stride_q/k_descale.
|
||||
//
|
||||
// Layout convention (non-paged fmha_fwd):
|
||||
// q_descale: [total_q, nhead_q] fp32
|
||||
// k_descale: [total_k, nhead_k] fp32
|
||||
// v_descale: [nhead_k] fp32
|
||||
const float* q_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.q_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q_descale +
|
||||
batch_offset_q_descale;
|
||||
const float* k_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.k_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_k_descale +
|
||||
batch_offset_k_descale;
|
||||
const float* v_descale_ptr =
|
||||
reinterpret_cast<const float*>(kargs.v_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_v_descale +
|
||||
batch_offset_v_descale;
|
||||
|
||||
return FmhaPipeline{}(
|
||||
q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func - PER_TOKEN_HEAD applies its own per-(row,col) scale
|
||||
identity{}, // p_compute_element_func
|
||||
identity{}, // o_acc_element_func - V descale is folded in via 'o_acc += o_acc0 * v_descale'
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
kargs.block_scale_size_kv, // unused for PER_TOKEN_HEAD
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
sink_value,
|
||||
// PER_TOKEN_HEAD-only:
|
||||
q_descale_ptr,
|
||||
kargs.stride_q_descale,
|
||||
kargs.stride_k_descale);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX)
|
||||
{
|
||||
using QScaleDataType = typename FmhaPipeline::QScaleDataType;
|
||||
|
||||
@@ -204,7 +204,16 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const QScaleDramBlockWindowTmp&, // M0*(K0/kQKScaleGranularity) tile
|
||||
const KScaleDramBlockWindowTmp&, // N0*(K0/kQKScaleGranularity) tile
|
||||
const VScaleDramBlockWindowTmp&, // N1*(K1/kVScaleGranularity) tile
|
||||
const float sink_v) const
|
||||
const float sink_v,
|
||||
// PER_TOKEN_HEAD: per-(token, head) Q/K descales; V is per-head scalar.
|
||||
// q_descale_ptr / k_descale_ptr / v_descale_ptr are already
|
||||
// (batch, head)-resolved by the kernel; pipeline indexes
|
||||
// q_descale_ptr[(q_origin + i) * stride_q_descale_token]
|
||||
// k_descale_ptr[(k_origin + j) * stride_k_descale_token]
|
||||
// v_descale = *v_descale_ptr (loaded once per loop iter)
|
||||
const float* q_descale_ptr = nullptr,
|
||||
const index_t stride_q_descale_token = 0,
|
||||
const index_t stride_k_descale_token = 0) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -463,6 +472,87 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(1);
|
||||
|
||||
// PER_TOKEN_HEAD: dequantize QK with per-row Q descale and per-col K descale.
|
||||
// s_acc[i,j] *= q_descale[(q_origin + i) * stride_q_descale_token]
|
||||
// * k_descale[(k_origin + j) * stride_k_descale_token]
|
||||
// q_descale_ptr / k_descale_ptr are already (batch, head)-resolved by the kernel.
|
||||
//
|
||||
// qr_async (no trload) on gfx9 is very tight on SGPRs; folding both row + col
|
||||
// global loads into a single 2-D sweep over s_acc made the compiler spill the
|
||||
// K/V buffer-SRDs to VGPRs and produce invalid asm. We instead stage both
|
||||
// descale vectors through LDS first (details in the branch below), so the
|
||||
// per-element sweep is a pure LDS read + multiply and no descale global-load
|
||||
// pointer adds SGPR pressure inside the K/V async-load region.
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
// PER_TOKEN_HEAD dequant: s_acc[i,j] *= q_descale[i] * k_descale[j].
|
||||
//
|
||||
// qr_async on gfx9 is extremely tight on SGPRs (the K/V buffer-SRDs
|
||||
// must stay live across iterations of the main K-loop). When threads
|
||||
// issue scalar global loads from k_descale_ptr inside the per-tile
|
||||
// 2-D sweep, the compiler runs out of SGPR budget and spills the K
|
||||
// SRD into VGPRs, producing invalid `buffer_load_dword v_dst, v[..]
|
||||
// ... lds` asm on gfx942.
|
||||
//
|
||||
// To break that dependency we stage both Q-row and K-col descales
|
||||
// through LDS first (one block-strided load loop per descale tile), then
|
||||
// the per-element sweep is a pure LDS read + FP multiply. This keeps
|
||||
// the K SRD live-range contained inside the QK gemm, where the
|
||||
// existing schedule already accounts for it.
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
|
||||
const index_t q_row_base = q_origin.at(number<0>{});
|
||||
|
||||
// Mirror the V / randval path's sink-aware col-base derivation so we
|
||||
// never touch k_dram_block_window.get_window_origin() inside this
|
||||
// SGPR-tight region (that read leaks the K window state and was
|
||||
// independently observed to push the SRD into VGPRs).
|
||||
const bool in_sink_phase = (num_sink_loop > i_total_loops);
|
||||
const index_t k_col_base =
|
||||
in_sink_phase
|
||||
? (kv_load_start + i_total_loops * kN0)
|
||||
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// LDS staging tiles. Allocated per-block; kept small (kM0 + kN0 fp32).
|
||||
__shared__ float lds_q_descale[kM0];
|
||||
__shared__ float lds_k_descale[kN0];
|
||||
|
||||
const index_t tid_in_block =
|
||||
static_cast<index_t>(threadIdx.x + threadIdx.y * blockDim.x +
|
||||
threadIdx.z * blockDim.x * blockDim.y);
|
||||
const index_t threads_per_block =
|
||||
static_cast<index_t>(blockDim.x * blockDim.y * blockDim.z);
|
||||
|
||||
// Q-row descales (kM0 entries).
|
||||
for(index_t off = tid_in_block; off < kM0; off += threads_per_block)
|
||||
{
|
||||
lds_q_descale[off] =
|
||||
q_descale_ptr[(q_row_base + off) * stride_q_descale_token];
|
||||
}
|
||||
// K-col descales (kN0 entries).
|
||||
for(index_t off = tid_in_block; off < kN0; off += threads_per_block)
|
||||
{
|
||||
lds_k_descale[off] =
|
||||
k_descale_ptr[(k_col_base + off) * stride_k_descale_token];
|
||||
}
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
const index_t i = tile_idx.at(number<0>{});
|
||||
const index_t j = tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
s_acc(i_j_idx) *= lds_q_descale[i] * lds_k_descale[j];
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// dequant
|
||||
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
@@ -785,12 +875,19 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv;
|
||||
v_descale = v_descale_ptr[kv_idx];
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
// PER_TOKEN_HEAD: V scale is per-head only; v_descale_ptr is already
|
||||
// (batch, head)-resolved by the kernel, so load the single scalar.
|
||||
v_descale = *v_descale_ptr;
|
||||
}
|
||||
// STAGE 3, KV gemm
|
||||
auto o_acc0 = decltype(o_acc){};
|
||||
clear_tile(o_acc0);
|
||||
|
||||
auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE ||
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
return o_acc0;
|
||||
}
|
||||
@@ -879,7 +976,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
|
||||
}
|
||||
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE ||
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0);
|
||||
|
||||
Reference in New Issue
Block a user