Fix batch prefill compile fail in aiter (#3279)

* Fix batch prefill aiter compile fail

* Fix compile error
This commit is contained in:
rocking
2025-11-25 09:46:32 +08:00
committed by GitHub
parent de6a9590ab
commit 229d43ea0c
2 changed files with 55 additions and 102 deletions

View File

@@ -20,6 +20,8 @@ from codegen.cpp_symbol_map import (
FWD_DTYPE_MAP,
BOOL_MAP,
PIPELINE_ENUM_MAP,
QSCALE_CHECK_MAP,
QSCALE_MAP,
)
from codegen.utils import update_file
@@ -60,7 +62,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
false,
{F_lse},
{F_dropout},
{F_squant},
{F_qscale},
{F_occupancy}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
@@ -98,7 +100,7 @@ using fmha_kernel_{F_idx} =
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
#include <iostream>
@@ -175,9 +177,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@@ -216,7 +218,7 @@ class FmhaFwdApiTrait:
bias: str #
lse: str #
dropout: str
squant: str #
qscale: str #
spad: str
skpad: str
dpad: str
@@ -227,7 +229,7 @@ class FmhaFwdApiTrait:
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
)
@property
@@ -312,7 +314,7 @@ class FmhaFwdPipeline:
F_bias: str # true/false
F_lse: str #
F_dropout: str #
F_squant: str #
F_qscale: str # no/pertensor
F_mask: str # value from MASK_MAP
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
@@ -370,10 +372,10 @@ class FmhaFwdPipeline:
else:
n += "_ndropout"
if self.F_squant == "t":
n += "_squant"
if self.F_qscale != "no":
n += f"_{self.F_qscale}"
else:
n += "_nsquant"
n += "_nqscale"
return n
@@ -413,7 +415,8 @@ class FmhaFwdApiPool:
F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse],
F_dropout=BOOL_MAP[trait.dropout],
F_squant=BOOL_MAP[trait.squant],
F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
F_qscale=QSCALE_MAP[trait.qscale],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
@@ -522,7 +525,7 @@ class FmhaFwdKernel:
F_bias=BIAS_MAP[self.F_pipeline.F_bias],
F_lse=BOOL_MAP[self.F_pipeline.F_lse],
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
F_occupancy=self.F_tile.F_occupancy,
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
@@ -562,7 +565,7 @@ class FmhaFwdKernel:
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
squant=self.F_pipeline.F_squant,
qscale=self.F_pipeline.F_qscale,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
@@ -587,7 +590,7 @@ class KernelComponentFactory:
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = "t" if dtype == "fp8" else "f"
qscale = "no"
pipelines = []
if dtype in ["fp16", "bf16"]:
for logits, mask, bias, lse, dropout in itertools.product(
@@ -597,10 +600,10 @@ class KernelComponentFactory:
["t", "f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
else:
assert False
return pipelines
@@ -672,7 +675,7 @@ def get_fwd_blobs(
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# PyTorch integration
@@ -680,7 +683,7 @@ def get_fwd_blobs(
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_fwd) integration
@@ -688,7 +691,7 @@ def get_fwd_blobs(
cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_batch_prefill) integration
@@ -696,7 +699,7 @@ def get_fwd_blobs(
cond = dtype in ["fp16", "bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
@@ -704,7 +707,7 @@ def get_fwd_blobs(
cond = dtype in ["fp16", "bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_qscale == "no"
if not cond:
continue

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <string>
@@ -53,7 +54,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
@@ -99,7 +100,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" );
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) +
(QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? _SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr<QScaleEnum>::name));
#undef _SS_
#undef _TS_
// clang-format on
@@ -202,12 +204,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::GenericAttentionMaskEnum mask_type;
};
struct FmhaFwdFp8StaticQuantKargs
{
float scale_p;
float scale_o;
};
struct FmhaFwdCommonLSEKargs
{
void* lse_ptr = nullptr;
@@ -278,9 +274,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -297,9 +292,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>
{
const int32_t* seqstart_q_ptr;
ck_tile::index_t batch_stride_k;
@@ -337,8 +331,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t page_block_size,
#endif
float scale_s,
float scale_p,
float scale_o,
[[maybe_unused]] float scale_p,
[[maybe_unused]] float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
@@ -401,7 +395,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
{}, // placeholder for dropout
{}, // placeholder for logits_soft_cap
batch_stride_q,
@@ -433,11 +426,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kDoFp8StaticQuant)
{
kargs.scale_p = scale_p;
kargs.scale_o = scale_o;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -489,8 +477,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t page_block_size,
#endif
float scale_s,
float scale_p,
float scale_o,
[[maybe_unused]] float scale_p,
[[maybe_unused]] float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
@@ -548,7 +536,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
{}, // placeholder for dropout
{}, // placeholder for logits_soft_cap
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
@@ -577,11 +564,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
}
if constexpr(kDoFp8StaticQuant)
{
kargs.scale_p = scale_p;
kargs.scale_o = scale_o;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -1082,55 +1064,23 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
return FmhaPipeline{}(
q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
kargs.kv_page_indices,
kargs.stride_k,
kargs.stride_v,
dropout);
}
else
{
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
lse_dram_window,
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
kargs.kv_page_indices,
kargs.stride_k,
kargs.stride_v,
dropout);
}
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
lse_dram_window,
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
kargs.kv_page_indices,
kargs.stride_k,
kargs.stride_v,
dropout);
}();
// O DRAM and O DRAM window