Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 02:02:46 +00:00)
[CK_TILE][FMHA] Add FP8 support for batch_prefill kernel (#3425)
* Add fp8bf16 support for batch_prefill
* Fix wrong scale_s re-compute logic in batch_prefill
* Fix wrong scale_s re-compute logic in fmha fwd
* Fix batch_prefill codegen error
* Remove no-longer used GetName() function
* Add fp8 logits=True instances
* Update CHANGELOG.md
[ROCm/composable_kernel commit: 1c3151963b]
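For context on the scale handling that several of the hunks below implement: with per-tensor FP8 quantization the kernel folds the Q/K descales into the softmax scale (scale_s), stretches the softmax probabilities into the fp8 range (scale_p) before the P·V GEMM, and undoes both scale_p and the V quantization in the output epilogue (scale_o). The following NumPy sketch of that arithmetic is an editor's reference, not repository code; the 448 bound assumes OCP e4m3 fp8, whereas the kernel queries ck_tile::numeric<PDataType>::max(), so the actual limit follows the build's fp8 flavor.

```python
import numpy as np

FP8_MAX = 448.0  # OCP e4m3 max; stand-in for ck_tile::numeric<PDataType>::max()

def fp8_prefill_scales_reference(q_fp8, k_fp8, v_fp8, softmax_scale,
                                 q_descale, k_descale, v_descale):
    """Float reference of the per-tensor FP8 scale handling (fp8 tensors passed as float arrays)."""
    # Q ~ q_fp8 * q_descale and K ~ k_fp8 * k_descale, so both descales fold into scale_s.
    scale_s = softmax_scale * q_descale * k_descale
    s = scale_s * (q_fp8 @ k_fp8.T)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    # P in [0, 1] is stretched to the fp8 range before the P @ V GEMM ...
    scale_p = FP8_MAX
    # ... and the output epilogue divides it back out while applying the V descale.
    scale_o = v_descale / scale_p
    return scale_o * ((p * scale_p) @ v_fp8)
```

Numerically the scale_p factor cancels; it exists only to keep P representable in fp8 between the two GEMMs.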
@@ -10,6 +10,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
 * Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines.
 * Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline.
 * Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel.
+* Added FP8 KV cache support for FMHA batch prefill.

 ### Changed

@@ -24,8 +24,15 @@ from codegen.cpp_symbol_map import (
 )
 from codegen.utils import update_file


-DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
+DTYPE_BITS = {
+    "fp32": 32,
+    "fp16": 16,
+    "bf16": 16,
+    "fp8": 8,
+    "fp8bf16": 8,
+    "fp8fp32": 8,
+    "bf8": 8,
+}

 K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}

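For orientation, the table above maps the generator's dtype strings to storage bit widths; a minimal check of the two new entries (editor's sketch, not repository code):

```python
# Both new generator dtypes are 8-bit storage types, matching the existing fp8/bf8 entries.
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16,
              "fp8": 8, "fp8bf16": 8, "fp8fp32": 8, "bf8": 8}
bytes_per_elem = {name: bits // 8 for name, bits in DTYPE_BITS.items()}
assert bytes_per_elem["fp8bf16"] == bytes_per_elem["fp8"] == 1
```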
@@ -108,7 +115,7 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
 {{
     using k_ = fmha_kernel_{F_idx};
     if(s.log_level_ > 0)
-        std::cout << ", " << k_::GetName() << std::flush;
+        std::cout << ", {F_kname}" << std::flush;
     auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
     const dim3 blocks = k_::BlockSize();
     constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;

@@ -494,6 +501,7 @@ class FmhaFwdKernel:
     @property
     def template(self) -> str:
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
+            F_kname=self.name,
             F_idx=self.F_idx,
             F_hdim=self.F_hdim,
             F_dtype=FWD_DTYPE_MAP[self.F_dtype],

@@ -576,10 +584,14 @@ class FmhaFwdKernel:
 class KernelComponentFactory:
     @staticmethod
     def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
-        if dtype == "fp16" or dtype == "bf16":
+        if dtype in ["fp16", "bf16"]:
             return {
                 128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
             } # fmt: skip
+        elif dtype in ["fp8bf16"]:
+            return {
+                128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
+            } # fmt: skip
         else:
             return None

@@ -589,9 +601,9 @@ class KernelComponentFactory:
         # TODO: the order of List matters! the later in this list will be also be checked later
         # TODO: currently for qr pipeline, let 't' padding to appear later!!
         # TODO: how to design this more generic?
-        qscale = "no"
         pipelines = []
         if dtype in ["fp16", "bf16"]:
+            qscale = "no"
             for logits, mask, bias, lse, dropout in itertools.product(
                 ["t", "f"],
                 get_mask_map(mask_impl).keys(),
@@ -599,10 +611,16 @@ class KernelComponentFactory:
                 ["t", "f"],
                 ["t", "f"],
             ):
                 pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
                 pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
                 # pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
                 # pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
+        elif dtype in ["fp8bf16"]:
+            # no need lse/dropout kernels
+            for logits, qscale, mask, bias in itertools.product(
+                ["t", "f"],
+                ["pertensor"],
+                get_mask_map(mask_impl).keys(),
+                ["no"],
+            ):
+                pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip
         else:
             assert False
         return pipelines

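To make the new branch concrete, here is a standalone sketch (editor's example; the mask names are stand-ins for `get_mask_map(mask_impl).keys()` and the tuple field order is an assumption read off the `FmhaFwdPipeline(...)` call above) of the combinations it emits: logits soft-cap on/off, per-tensor qscale, every mask variant, no bias, LSE and dropout pinned off.

```python
import itertools

masks = ["no", "causal"]  # stand-ins for get_mask_map(mask_impl).keys()

fp8bf16_pipelines = [
    # assumed order: (tag, vlayout, four padding flags, logits, bias, lse, dropout, qscale, mask)
    ("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)
    for logits, qscale, mask, bias in itertools.product(
        ["t", "f"], ["pertensor"], masks, ["no"]
    )
]
print(len(fp8bf16_pipelines))  # 2 logits variants x 1 qscale x len(masks) x 1 bias
```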
@@ -612,7 +630,7 @@ class CustomFactory(KernelComponentFactory):
     @staticmethod
     def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
         result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
-        if dtype == "fp16" or dtype == "bf16":
+        if dtype in ["fp16", "bf16"]:
             if 128 in result.keys():
                 result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
         return result

@@ -695,15 +713,14 @@ def get_fwd_blobs(
                     continue
             # Aiter(mha_batch_prefill) integration
             elif receipt == 200:
-                cond = dtype in ["fp16", "bf16"]
+                cond = dtype in ["fp16", "bf16", "fp8bf16"]
                 cond &= mode == "group"
                 cond &= pipeline.F_vlayout == "row"
                 cond &= pipeline.F_qscale == "no"
                 if not cond:
                     continue
             # aiter::mha_batch_prefill C++ api integration
             elif receipt == 600:
-                cond = dtype in ["fp16", "bf16"]
+                cond = dtype in ["fp16", "bf16", "fp8bf16"]
                 cond &= mode == "group"
                 cond &= pipeline.F_vlayout == "row"
                 cond &= pipeline.F_qscale == "no"

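The receipt-200/600 branches act as instance filters during code generation; a simplified standalone version of the check after this change (editor's sketch; the real filter takes the generator's pipeline object and also constrains F_qscale, omitted here):

```python
def keep_for_batch_prefill(dtype: str, mode: str, vlayout: str) -> bool:
    # fp8bf16 now joins fp16/bf16 for the aiter mha_batch_prefill receipts;
    # only group-mode, row-major-V kernels survive.
    cond = dtype in ["fp16", "bf16", "fp8bf16"]
    cond &= mode == "group"
    cond &= vlayout == "row"
    return cond

assert keep_for_batch_prefill("fp8bf16", "group", "row")
assert not keep_for_batch_prefill("fp8bf16", "batch", "row")
```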
@@ -1017,7 +1017,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
         elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32:
             # no need lse/dropout kernels
             for logits, qscale, mask, bias, sink in itertools.product(
-                ["f"],
+                ["t", "f"],
                 ["no", "pertensor"],
                 get_mask_map(mask_impl).keys(),
                 ["no"],

@@ -500,6 +500,9 @@ struct fmha_batch_prefill_args
     const void* k_ptr;
     const void* v_ptr;
     const void* bias_ptr; // bias or alibi_slope pointer
+    const void* q_descale_ptr;
+    const void* k_descale_ptr;
+    const void* v_descale_ptr;
     void* rand_val_ptr;
     void* lse_ptr;
     void* o_ptr;

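Each of the three new pointers references a single float per tensor. A hypothetical host-side helper (editor's sketch; the function name is not from the repository, and the 448 bound assumes OCP e4m3 while AMD fnuz fp8 would use 240) showing how such per-tensor descales are typically produced:

```python
import numpy as np

FP8_MAX = 448.0  # OCP e4m3; fnuz fp8 would use 240.0

def quantize_per_tensor(x: np.ndarray):
    """Return (simulated fp8 values, descale) such that x ~= x_q * descale."""
    descale = max(float(np.abs(x).max()), 1e-8) / FP8_MAX
    x_q = np.clip(x / descale, -FP8_MAX, FP8_MAX)  # device code would cast to fp8 here
    return x_q.astype(np.float32), np.float32(descale)

q_q, q_descale = quantize_per_tensor(np.random.randn(128, 64).astype(np.float32))
# q_descale / k_descale / v_descale are the single floats the kernel reads through
# q_descale_ptr / k_descale_ptr / v_descale_ptr when per-tensor qscale is enabled.
```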
@@ -1118,6 +1121,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
         args.k_ptr,
         args.v_ptr,
         args.bias_ptr,
+        args.q_descale_ptr,
+        args.k_descale_ptr,
+        args.v_descale_ptr,
         args.rand_val_ptr,
         args.lse_ptr,
         args.o_ptr,

@@ -1166,6 +1172,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
         args.k_ptr,
         args.v_ptr,
         args.bias_ptr,
+        args.q_descale_ptr,
+        args.k_descale_ptr,
+        args.v_descale_ptr,
         args.rand_val_ptr,
         args.lse_ptr,
         args.o_ptr,

@@ -36,6 +36,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
     using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
     using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
     using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
+    using PDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::PDataType>;
     using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
     using RandValOutputDataType =
         ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;

@@ -61,52 +62,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel

     static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;

-    // clang-format off
-    template <typename T> struct t2s;
-    template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
-    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
-    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
-    template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
-    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
-    // clang-format on
-
-    CK_TILE_HOST static std::string GetName()
-    {
-        // sync with generate.py
-        // clang-format off
-        using bfs = typename FmhaPipeline::BlockFmhaShape;
-        using g0br = typename bfs::Gemm0BlockWarps;
-        using g1br = typename bfs::Gemm1BlockWarps;
-        using g0wt = typename bfs::Gemm0WarpTile;
-        using g1wt = typename bfs::Gemm1WarpTile;
-        #define _SS_ std::string
-        #define _TS_ std::to_string
-        auto pn = [&] () {
-            std::string n;
-            if (kPadSeqLenQ) n += "s";
-            if (kPadSeqLenK) n += "sk";
-            if (kPadHeadDimQ) n += "d";
-            if (kPadHeadDimV) n += "dv";
-            return n.empty() ? n : std::string("p") + n; }();
-        return
-            _SS_("fmha_batch_prefill_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
-            "_" + (kIsGroupMode ? "group" : "batch") + "_"
-            "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
-            _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
-            "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
-            "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
-            "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
-            "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
-            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
-            "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
-            (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
-            (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) +
-            (QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? _SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr<QScaleEnum>::name));
-        #undef _SS_
-        #undef _TS_
-        // clang-format on
-    }
-
     template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
                                   // arg
     struct FmhaFwdEmptyKargs

@@ -211,6 +166,13 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
         ck_tile::index_t batch_stride_lse = 0;
     };

+    struct FmhaFwdCommonQScaleKargs
+    {
+        const void* q_descale_ptr = nullptr;
+        const void* k_descale_ptr = nullptr;
+        const void* v_descale_ptr = nullptr;
+    };
+
     struct FmhaFwdDropoutSeedOffset
     {
         template <typename T>

@@ -274,8 +236,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
                              FmhaFwdEmptyKargs<0>>>,
           std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
           std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
-          std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<3>>,
-          std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>
+          std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
+                             FmhaFwdCommonQScaleKargs,
+                             FmhaFwdEmptyKargs<3>>,
+          std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
+          std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
     {
         ck_tile::index_t batch_stride_q;
         ck_tile::index_t batch_stride_k;

@@ -292,8 +257,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
                              FmhaFwdEmptyKargs<0>>>,
           std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
           std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
-          std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<3>>,
-          std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>
+          std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
+                             FmhaFwdCommonQScaleKargs,
+                             FmhaFwdEmptyKargs<3>>,
+          std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
+          std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
     {
         const int32_t* seqstart_q_ptr;
         ck_tile::index_t batch_stride_k;

@@ -315,6 +283,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
             const void* k_ptr,
             const void* v_ptr,
             const void* bias_ptr,
+            const void* q_descale_ptr,
+            const void* k_descale_ptr,
+            const void* v_descale_ptr,
             void* rand_val_ptr,
             void* lse_ptr,
             void* o_ptr,

@@ -396,6 +367,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
             {}, // placeholder for bias
             {}, // placeholder for mask
             {}, // placeholder for lse
+            {}, // placeholder for qscale
             {}, // placeholder for dropout
             {}, // placeholder for logits_soft_cap
             batch_stride_q,

@@ -428,6 +400,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
             kargs.nhead_stride_lse = nhead_stride_lse;
             kargs.batch_stride_lse = batch_stride_lse;
         }
+        if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
+        {
+            kargs.q_descale_ptr = q_descale_ptr;
+            kargs.k_descale_ptr = k_descale_ptr;
+            kargs.v_descale_ptr = v_descale_ptr;
+        }
         if constexpr(kHasDropout)
         {
             if(drop_seed_offset.index() == 0) // seed & offset come from host

@@ -463,6 +441,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
             const void* k_ptr,
             const void* v_ptr,
             const void* bias_ptr,
+            const void* q_descale_ptr,
+            const void* k_descale_ptr,
+            const void* v_descale_ptr,
             void* rand_val_ptr,
             void* lse_ptr,
             void* o_ptr,

@@ -539,6 +520,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
             {}, // placeholder for bias
             {}, // placeholder for mask
             {}, // placeholder for lse
+            {}, // placeholder for qscale
             {}, // placeholder for dropout
             {}, // placeholder for logits_soft_cap
             reinterpret_cast<const int32_t*>(seqstart_q_ptr),

@@ -568,6 +550,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
             kargs.lse_ptr = lse_ptr;
             kargs.nhead_stride_lse = nhead_stride_lse;
         }
+        if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
+        {
+            kargs.q_descale_ptr = q_descale_ptr;
+            kargs.k_descale_ptr = k_descale_ptr;
+            kargs.v_descale_ptr = v_descale_ptr;
+        }
         if constexpr(kHasDropout)
         {
             if(drop_seed_offset.index() == 0) // seed & offset come from host

@@ -1055,37 +1043,96 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel

         AttentionVariant variant;
         const auto variant_params = [&] {
+            const float scale_s = [&] {
+                if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
+                {
+                    float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
+                    float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
+
+                    return kargs.scale_s * q_descale * k_descale;
+                }
+                else
+                {
+                    return kargs.scale_s;
+                }
+            }();
+
             if constexpr(kHasLogitsSoftCap)
             {
                 return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
-                    mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
+                    mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
             }
             else
             {
-                return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
+                return ck_tile::StandardAttentionParams<FmhaMask>{mask, scale_s};
             }
         }();

         BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};

-        auto o_acc_tile = [&]() {
-            return FmhaPipeline{}(q_dram_window,
-                                  k_dram_window,
-                                  v_dram_window,
-                                  bias_dram_window,
-                                  randval_dram_window,
-                                  lse_dram_window,
-                                  mask,
-                                  position_encoding,
-                                  kargs.scale_s,
-                                  variant,
-                                  variant_params,
-                                  block_indices,
-                                  smem_ptr,
-                                  kargs.kv_page_indices,
-                                  kargs.stride_k,
-                                  kargs.stride_v,
-                                  dropout);
+        auto o_acc_tile = [&] {
+            if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
+            {
+                // TODO - move global load of descale to pipeline
+                float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
+
+                float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
+                float scale_o = v_descale / scale_p;
+
+                auto o_acc_element_func = [&]() {
+                    if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
+                        return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
+                                                 ck_tile::scales{scale_o});
+                    else
+                        return ck_tile::scales{scale_o};
+                }();
+
+                return FmhaPipeline{}(q_dram_window,
+                                      identity{}, // q_element_func
+                                      k_dram_window,
+                                      identity{}, // k_element_func
+                                      v_dram_window,
+                                      identity{}, // v_element_func
+                                      bias_dram_window,
+                                      identity{}, // bias_element_func
+                                      randval_dram_window,
+                                      lse_dram_window,
+                                      identity{}, // lse_element_func
+                                      identity{}, // s_acc_element_func
+                                      scales{scale_p}, // p_compute_element_func
+                                      o_acc_element_func, // o_acc_element_func
+                                      mask,
+                                      position_encoding,
+                                      variant_params.sm_scale,
+                                      variant,
+                                      variant_params,
+                                      block_indices,
+                                      smem_ptr,
+                                      kargs.kv_page_indices,
+                                      kargs.stride_k,
+                                      kargs.stride_v,
+                                      dropout);
+            }
+            else
+            {
+                return FmhaPipeline{}(q_dram_window,
+                                      k_dram_window,
+                                      v_dram_window,
+                                      bias_dram_window,
+                                      randval_dram_window,
+                                      lse_dram_window,
+                                      mask,
+                                      position_encoding,
+                                      variant_params.sm_scale,
+                                      variant,
+                                      variant_params,
+                                      block_indices,
+                                      smem_ptr,
+                                      kargs.kv_page_indices,
+                                      kargs.stride_k,
+                                      kargs.stride_v,
+                                      dropout);
+            }
         }();

         // O DRAM and O DRAM window

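For reference (editor's sketch, not kernel code), the fp8-output epilogue built above amounts to a per-element rescale followed by a clamp into the fp8 range; the saturation step only applies when ODataType is itself fp8, and the 448 bound again assumes OCP e4m3 (fnuz fp8 would use 240).

```python
FP8_MAX = 448.0  # stand-in for ck_tile::numeric<fp8_t>::max()

def o_acc_epilogue(o_acc: float, scale_o: float, o_is_fp8: bool) -> float:
    # models scales{scale_o} then saturates<fp8_t> from o_acc_element_func
    o = o_acc * scale_o
    return max(-FP8_MAX, min(FP8_MAX, o)) if o_is_fp8 else o
```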
@@ -1499,14 +1499,28 @@ struct FmhaFwdKernel

         AttentionVariant variant;
         const auto variant_params = [&] {
+            const float scale_s = [&] {
+                if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
+                {
+                    float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
+                    float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
+
+                    return kargs.scale_s * q_descale * k_descale;
+                }
+                else
+                {
+                    return kargs.scale_s;
+                }
+            }();
+
             if constexpr(kHasLogitsSoftCap)
             {
                 return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
-                    mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
+                    mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
             }
             else
             {
-                return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
+                return ck_tile::StandardAttentionParams<FmhaMask>{mask, scale_s};
             }
         }();

@@ -1516,11 +1530,8 @@ struct FmhaFwdKernel
             if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
             {
                 // TODO - move global load of descale to pipeline
-                float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
-                float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
                 float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));

-                float scale_s = kargs.scale_s * q_descale * k_descale;
                 float scale_p =
                     ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
                 float scale_o = v_descale / scale_p;

@@ -1548,7 +1559,7 @@ struct FmhaFwdKernel
                                       o_acc_element_func, // o_acc_element_func
                                       mask,
                                       position_encoding,
-                                      scale_s,
+                                      variant_params.sm_scale,
                                       variant,
                                       variant_params,
                                       block_indices,

@@ -1565,7 +1576,7 @@ struct FmhaFwdKernel
                                       lse_dram_window,
                                       mask,
                                       position_encoding,
-                                      kargs.scale_s,
+                                      variant_params.sm_scale,
                                       variant,
                                       variant_params,
                                       block_indices,