Support fp8 dynamic quantization for fmha (#3206)

* Support qscale for dynamic quant, remove static quant

* Support hdim=256

* Remove bias test case for fp8

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
Author: rocking
Date: 2025-11-24 16:28:25 +08:00
Committed by: GitHub
Parent: 096f0a3b23
Commit: 5948dbffe4
17 changed files with 369 additions and 280 deletions
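This change replaces the host-precomputed scale_p/scale_o kernel arguments with per-tensor descale pointers read on the device, so the scales can come from runtime (dynamic) quantization. Below is a minimal host-side sketch of producing such descale factors; the compute_descale helper is hypothetical, and only the three descale-pointer kernel arguments in the diff below come from this commit:

#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical helper: one dequantization factor per tensor, chosen so that
// x ≈ x_fp8 * descale after quantizing with x_fp8 = saturate(x / descale).
// fp8_max is the largest representable value of the target fp8 format.
float compute_descale(const std::vector<float>& tensor, float fp8_max)
{
    float amax = 0.f;
    for(float v : tensor)
        amax = std::max(amax, std::fabs(v));
    return amax > 0.f ? amax / fp8_max : 1.f;
}

The resulting per-tensor floats live in device memory and their addresses are passed as q_descale_ptr / k_descale_ptr / v_descale_ptr; the kernel derives scale_s, scale_p, and scale_o from them at runtime (see the descale hunk near the end of the kernel diff).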


@@ -0,0 +1,31 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+#pragma once
+
+#include <string>
+
+namespace ck_tile {
+
+// This class is used for codegen pattern matching
+enum class BlockAttentionQuantScaleEnum
+{
+    NO_SCALE  = 0,
+    PERTENSOR = 1,
+};
+
+template <BlockAttentionQuantScaleEnum>
+struct BlockAttentionQuantScaleEnumToStr;
+
+template <>
+struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::NO_SCALE>
+{
+    static constexpr const char* name = "";
+};
+
+template <>
+struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::PERTENSOR>
+{
+    static constexpr const char* name = "pertensor";
+};
+
+} // namespace ck_tile
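
As a quick usage sketch (assuming only the header added above), the string mapping is what surfaces in generated kernel names; this mirrors the suffix-building expression in the kernel diff below:

// Sketch: map the enum to the kernel-name suffix used for pattern matching.
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include <iostream>
#include <string>

int main()
{
    constexpr auto qs = ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR;
    const std::string suffix =
        (qs == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE)
            ? std::string("_nqscale")
            : std::string("_") + ck_tile::BlockAttentionQuantScaleEnumToStr<qs>::name;
    std::cout << suffix << '\n'; // prints "_pertensor"
}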


@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <string>
@@ -36,6 +37,7 @@ struct FmhaFwdKernel
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
+using PDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::PDataType>;
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
using RandValOutputDataType =
ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
@@ -54,7 +56,7 @@ struct FmhaFwdKernel
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
-static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
+static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
@@ -112,7 +114,8 @@ struct FmhaFwdKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
-(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload");
+(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) +
+(QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? _SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr<QScaleEnum>::name)) + (kUseTrLoad ? "_trload" : "_ntrload");
#undef _SS_
#undef _TS_
// clang-format on
@@ -204,10 +207,11 @@ struct FmhaFwdKernel
ck_tile::GenericAttentionMaskEnum mask_type;
};
-struct FmhaFwdFp8StaticQuantKargs
+struct FmhaFwdCommonQScaleKargs
{
-float scale_p;
-float scale_o;
+const void* q_descale_ptr = nullptr;
+const void* k_descale_ptr = nullptr;
+const void* v_descale_ptr = nullptr;
};
struct FmhaFwdCommonLSEKargs
@@ -285,7 +289,9 @@ struct FmhaFwdKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
-std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
+std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
+FmhaFwdCommonQScaleKargs,
+FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
@@ -309,7 +315,9 @@ struct FmhaFwdKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
-std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
+std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
+FmhaFwdCommonQScaleKargs,
+FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
@@ -339,6 +347,9 @@ struct FmhaFwdKernel
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
+const void* q_descale_ptr,
+const void* k_descale_ptr,
+const void* v_descale_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
@@ -349,8 +360,6 @@ struct FmhaFwdKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
-float scale_p,
-float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
@@ -408,7 +417,7 @@ struct FmhaFwdKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for lse
-{}, // placeholder for fp8_static_quant args
+{}, // placeholder for qscale
{}, // placeholder for dropout
{}, // placeholder for logits_soft_cap
batch_stride_q,
@@ -440,10 +449,11 @@ struct FmhaFwdKernel
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
-if constexpr(kDoFp8StaticQuant)
+if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
-kargs.scale_p = scale_p;
-kargs.scale_o = scale_o;
+kargs.q_descale_ptr = q_descale_ptr;
+kargs.k_descale_ptr = k_descale_ptr;
+kargs.v_descale_ptr = v_descale_ptr;
}
if constexpr(kHasDropout)
{
@@ -483,6 +493,9 @@ struct FmhaFwdKernel
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
+const void* q_descale_ptr,
+const void* k_descale_ptr,
+const void* v_descale_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
@@ -493,8 +506,6 @@ struct FmhaFwdKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
-float scale_p,
-float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
@@ -530,6 +541,9 @@ struct FmhaFwdKernel
k_ptr,
v_ptr,
bias_ptr,
+q_descale_ptr,
+k_descale_ptr,
+v_descale_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
@@ -540,8 +554,6 @@ struct FmhaFwdKernel
num_head_q,
nhead_ratio_qk,
scale_s,
-scale_p,
-scale_o,
logits_soft_cap,
stride_q,
stride_k,
@@ -580,6 +592,9 @@ struct FmhaFwdKernel
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
+const void* q_descale_ptr,
+const void* k_descale_ptr,
+const void* v_descale_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
@@ -590,8 +605,6 @@ struct FmhaFwdKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
-float scale_p,
-float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
@@ -627,6 +640,9 @@ struct FmhaFwdKernel
k_ptr,
v_ptr,
bias_ptr,
+q_descale_ptr,
+k_descale_ptr,
+v_descale_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
@@ -637,8 +653,6 @@ struct FmhaFwdKernel
num_head_q,
nhead_ratio_qk,
scale_s,
-scale_p,
-scale_o,
logits_soft_cap,
stride_q,
stride_k,
@@ -676,6 +690,9 @@ struct FmhaFwdKernel
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
+const void* q_descale_ptr,
+const void* k_descale_ptr,
+const void* v_descale_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
@@ -688,8 +705,6 @@ struct FmhaFwdKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
-float scale_p,
-float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
@@ -741,7 +756,7 @@ struct FmhaFwdKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for lse
-{}, // placeholder for fp8_static_quant args
+{}, // placeholder for qscale
{}, // placeholder for dropout
{}, // placeholder for logits_soft_cap
{}, // placeholder for min_seqlen_q
@@ -772,10 +787,11 @@ struct FmhaFwdKernel
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
}
-if constexpr(kDoFp8StaticQuant)
+if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
-kargs.scale_p = scale_p;
-kargs.scale_o = scale_o;
+kargs.q_descale_ptr = q_descale_ptr;
+kargs.k_descale_ptr = k_descale_ptr;
+kargs.v_descale_ptr = v_descale_ptr;
}
if constexpr(kHasDropout)
{
@@ -818,6 +834,9 @@ struct FmhaFwdKernel
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
+const void* q_descale_ptr,
+const void* k_descale_ptr,
+const void* v_descale_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
@@ -830,8 +849,6 @@ struct FmhaFwdKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
-float scale_p,
-float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
@@ -861,6 +878,9 @@ struct FmhaFwdKernel
k_ptr,
v_ptr,
bias_ptr,
+q_descale_ptr,
+k_descale_ptr,
+v_descale_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
@@ -873,8 +893,6 @@ struct FmhaFwdKernel
num_head_q,
nhead_ratio_qk,
scale_s,
-scale_p,
-scale_o,
logits_soft_cap,
stride_q,
stride_k,
@@ -907,6 +925,9 @@ struct FmhaFwdKernel
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
+const void* q_descale_ptr,
+const void* k_descale_ptr,
+const void* v_descale_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
@@ -919,8 +940,6 @@ struct FmhaFwdKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
-float scale_p,
-float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
@@ -950,6 +969,9 @@ struct FmhaFwdKernel
k_ptr,
v_ptr,
bias_ptr,
+q_descale_ptr,
+k_descale_ptr,
+v_descale_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
@@ -962,8 +984,6 @@ struct FmhaFwdKernel
num_head_q,
nhead_ratio_qk,
scale_s,
-scale_p,
-scale_o,
logits_soft_cap,
stride_q,
stride_k,
@@ -1527,14 +1547,24 @@ struct FmhaFwdKernel
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
auto o_acc_tile = [&]() {
-if constexpr(kDoFp8StaticQuant)
+if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
+// TODO - move global load of descale to pipeline
+float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
+float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
+float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
+float scale_s = kargs.scale_s * q_descale * k_descale;
+float scale_p =
+ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
+float scale_o = v_descale / scale_p;
auto o_acc_element_func = [&]() {
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
-ck_tile::scales{kargs.scale_o});
+ck_tile::scales{scale_o});
else
-return ck_tile::scales{kargs.scale_o};
+return ck_tile::scales{scale_o};
}();
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
@@ -1546,13 +1576,13 @@ struct FmhaFwdKernel
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
-identity{}, // lse_element_func
-identity{}, // s_acc_element_func
-scales{kargs.scale_p}, // p_compute_element_func
-o_acc_element_func, // o_acc_element_func
+identity{}, // lse_element_func
+identity{}, // s_acc_element_func
+scales{scale_p}, // p_compute_element_func
+o_acc_element_func, // o_acc_element_func
mask,
position_encoding,
-kargs.scale_s,
+scale_s,
variant,
variant_params,
block_indices,
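
Why the kernel derives the three scales this way: a sketch of the rescaling identities implied by the hunk above (notation is mine). Let d_Q, d_K, d_V be the per-tensor dequant factors behind the descale pointers, s the softmax scale, and let a subscript q mark the stored fp8 tensors, so Q ≈ d_Q Q_q and likewise for K and V. Then

$$ S = s\,Q K^\top \approx (s\, d_Q d_K)\, Q_q K_q^\top \quad\Rightarrow\quad \mathtt{scale\_s} = s \cdot d_Q \cdot d_K $$

The softmax result P lies in [0, 1], so it is stretched by scale_p = max(PDataType) to use the fp8 range before the PV GEMM, and the stretch is undone together with V's dequantization on the accumulator:

$$ O = P\,V \approx \frac{d_V}{\mathtt{scale\_p}}\,(\mathtt{scale\_p}\, P)_q\, V_q \quad\Rightarrow\quad \mathtt{scale\_o} = \frac{d_V}{\mathtt{scale\_p}} $$

This is what makes the precomputed scale_p/scale_o kernel arguments removable: everything is recovered at runtime from the three descale values.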


@@ -60,7 +60,7 @@ struct BlockFmhaPipelineProblem
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kHasDropout = Traits::kHasDropout;
-static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
+static constexpr auto QScaleEnum = Traits::QScaleEnum;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};


@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
namespace ck_tile {
@@ -18,7 +19,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kHasBiasGrad_,
bool kStoreLSE_,
bool kHasDropout_,
-bool kDoFp8StaticQuant_,
+BlockAttentionQuantScaleEnum QScaleEnum_,
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
struct TileFmhaTraits
@@ -32,7 +33,7 @@ struct TileFmhaTraits
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr bool kHasDropout = kHasDropout_;
-static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
+static constexpr auto QScaleEnum = QScaleEnum_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};
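
A hedged compile-time sketch of consuming the new trait: downstream code branches on Traits::QScaleEnum instead of the removed boolean, as the kernel's if constexpr dispatch above does. DemoTraits is a stand-in for a concrete instantiation (the real TileFmhaTraits carries many more parameters):

#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"

// Stand-in trait exposing only the member this commit changed.
struct DemoTraits
{
    static constexpr auto QScaleEnum = ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR;
};

template <typename Traits>
constexpr bool uses_pertensor_qscale =
    Traits::QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR;

// Selects the per-tensor dynamic-quant kargs/path at compile time.
static_assert(uses_pertensor_qscale<DemoTraits>, "pertensor qscale expected");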