mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Support fp8 dynamic quantization for fmha (#3206)
* Support qscale for dynamic quant, remove static quant * Support hdim=256 * Remove bias test case for fp8 --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This class is used for codegen pattern matching
|
||||
enum class BlockAttentionQuantScaleEnum
|
||||
{
|
||||
NO_SCALE = 0,
|
||||
PERTENSOR = 1,
|
||||
};
|
||||
|
||||
template <BlockAttentionQuantScaleEnum>
|
||||
struct BlockAttentionQuantScaleEnumToStr;
|
||||
|
||||
template <>
|
||||
struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::NO_SCALE>
|
||||
{
|
||||
static constexpr const char* name = "";
|
||||
};
|
||||
template <>
|
||||
struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::PERTENSOR>
|
||||
{
|
||||
static constexpr const char* name = "pertensor";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
|
||||
#include <string>
|
||||
@@ -36,6 +37,7 @@ struct FmhaFwdKernel
|
||||
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
|
||||
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
|
||||
using PDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::PDataType>;
|
||||
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
|
||||
using RandValOutputDataType =
|
||||
ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
|
||||
@@ -54,7 +56,7 @@ struct FmhaFwdKernel
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
|
||||
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
|
||||
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
@@ -112,7 +114,8 @@ struct FmhaFwdKernel
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
|
||||
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload");
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) +
|
||||
(QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? _SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr<QScaleEnum>::name)) + (kUseTrLoad ? "_trload" : "_ntrload");
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
@@ -204,10 +207,11 @@ struct FmhaFwdKernel
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
struct FmhaFwdFp8StaticQuantKargs
|
||||
struct FmhaFwdCommonQScaleKargs
|
||||
{
|
||||
float scale_p;
|
||||
float scale_o;
|
||||
const void* q_descale_ptr = nullptr;
|
||||
const void* k_descale_ptr = nullptr;
|
||||
const void* v_descale_ptr = nullptr;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonLSEKargs
|
||||
@@ -285,7 +289,9 @@ struct FmhaFwdKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
{
|
||||
@@ -309,7 +315,9 @@ struct FmhaFwdKernel
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
|
||||
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
|
||||
@@ -339,6 +347,9 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* q_descale_ptr,
|
||||
const void* k_descale_ptr,
|
||||
const void* v_descale_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
@@ -349,8 +360,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
@@ -408,7 +417,7 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
{}, // placeholder for qscale
|
||||
{}, // placeholder for dropout
|
||||
{}, // placeholder for logits_soft_cap
|
||||
batch_stride_q,
|
||||
@@ -440,10 +449,11 @@ struct FmhaFwdKernel
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
kargs.batch_stride_lse = batch_stride_lse;
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
kargs.scale_p = scale_p;
|
||||
kargs.scale_o = scale_o;
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
@@ -483,6 +493,9 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* q_descale_ptr,
|
||||
const void* k_descale_ptr,
|
||||
const void* v_descale_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
@@ -493,8 +506,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
@@ -530,6 +541,9 @@ struct FmhaFwdKernel
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
q_descale_ptr,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
@@ -540,8 +554,6 @@ struct FmhaFwdKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
logits_soft_cap,
|
||||
stride_q,
|
||||
stride_k,
|
||||
@@ -580,6 +592,9 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* q_descale_ptr,
|
||||
const void* k_descale_ptr,
|
||||
const void* v_descale_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
@@ -590,8 +605,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
@@ -627,6 +640,9 @@ struct FmhaFwdKernel
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
q_descale_ptr,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
@@ -637,8 +653,6 @@ struct FmhaFwdKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
logits_soft_cap,
|
||||
stride_q,
|
||||
stride_k,
|
||||
@@ -676,6 +690,9 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* q_descale_ptr,
|
||||
const void* k_descale_ptr,
|
||||
const void* v_descale_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
@@ -688,8 +705,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
@@ -741,7 +756,7 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
{}, // placeholder for qscale
|
||||
{}, // placeholder for dropout
|
||||
{}, // placeholder for logits_soft_cap
|
||||
{}, // placeholder for min_seqlen_q
|
||||
@@ -772,10 +787,11 @@ struct FmhaFwdKernel
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
kargs.scale_p = scale_p;
|
||||
kargs.scale_o = scale_o;
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
@@ -818,6 +834,9 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* q_descale_ptr,
|
||||
const void* k_descale_ptr,
|
||||
const void* v_descale_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
@@ -830,8 +849,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
@@ -861,6 +878,9 @@ struct FmhaFwdKernel
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
q_descale_ptr,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
@@ -873,8 +893,6 @@ struct FmhaFwdKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
logits_soft_cap,
|
||||
stride_q,
|
||||
stride_k,
|
||||
@@ -907,6 +925,9 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* q_descale_ptr,
|
||||
const void* k_descale_ptr,
|
||||
const void* v_descale_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
@@ -919,8 +940,6 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
@@ -950,6 +969,9 @@ struct FmhaFwdKernel
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
q_descale_ptr,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
@@ -962,8 +984,6 @@ struct FmhaFwdKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
logits_soft_cap,
|
||||
stride_q,
|
||||
stride_k,
|
||||
@@ -1527,14 +1547,24 @@ struct FmhaFwdKernel
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
// TODO - move global load of descale to pipeline
|
||||
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
|
||||
float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
|
||||
float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
|
||||
|
||||
float scale_s = kargs.scale_s * q_descale * k_descale;
|
||||
float scale_p =
|
||||
ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
|
||||
float scale_o = v_descale / scale_p;
|
||||
|
||||
auto o_acc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{kargs.scale_o});
|
||||
ck_tile::scales{scale_o});
|
||||
else
|
||||
return ck_tile::scales{kargs.scale_o};
|
||||
return ck_tile::scales{scale_o};
|
||||
}();
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
@@ -1546,13 +1576,13 @@ struct FmhaFwdKernel
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{kargs.scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
|
||||
@@ -60,7 +60,7 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr auto QScaleEnum = Traits::QScaleEnum;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -18,7 +19,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kHasBiasGrad_,
|
||||
bool kStoreLSE_,
|
||||
bool kHasDropout_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
BlockAttentionQuantScaleEnum QScaleEnum_,
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
|
||||
struct TileFmhaTraits
|
||||
@@ -32,7 +33,7 @@ struct TileFmhaTraits
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kStoreLSE = kStoreLSE_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr auto QScaleEnum = QScaleEnum_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user