mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Fix batch prefill compile fail in aiter (#3279)
* Fix batch prefill aiter compile fail * Fix compile error
This commit is contained in:
@@ -20,6 +20,8 @@ from codegen.cpp_symbol_map import (
|
||||
FWD_DTYPE_MAP,
|
||||
BOOL_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
QSCALE_CHECK_MAP,
|
||||
QSCALE_MAP,
|
||||
)
|
||||
from codegen.utils import update_file
|
||||
|
||||
@@ -60,7 +62,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
false,
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_qscale},
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
@@ -98,7 +100,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -175,9 +177,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -216,7 +218,7 @@ class FmhaFwdApiTrait:
|
||||
bias: str #
|
||||
lse: str #
|
||||
dropout: str
|
||||
squant: str #
|
||||
qscale: str #
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
@@ -227,7 +229,7 @@ class FmhaFwdApiTrait:
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -312,7 +314,7 @@ class FmhaFwdPipeline:
|
||||
F_bias: str # true/false
|
||||
F_lse: str #
|
||||
F_dropout: str #
|
||||
F_squant: str #
|
||||
F_qscale: str # no/pertensor
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@@ -370,10 +372,10 @@ class FmhaFwdPipeline:
|
||||
else:
|
||||
n += "_ndropout"
|
||||
|
||||
if self.F_squant == "t":
|
||||
n += "_squant"
|
||||
if self.F_qscale != "no":
|
||||
n += f"_{self.F_qscale}"
|
||||
else:
|
||||
n += "_nsquant"
|
||||
n += "_nqscale"
|
||||
return n
|
||||
|
||||
|
||||
@@ -413,7 +415,8 @@ class FmhaFwdApiPool:
|
||||
F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse],
|
||||
F_dropout=BOOL_MAP[trait.dropout],
|
||||
F_squant=BOOL_MAP[trait.squant],
|
||||
F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
|
||||
F_qscale=QSCALE_MAP[trait.qscale],
|
||||
F_scheck=trait.scheck,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
@@ -522,7 +525,7 @@ class FmhaFwdKernel:
|
||||
F_bias=BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse=BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
@@ -562,7 +565,7 @@ class FmhaFwdKernel:
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
qscale=self.F_pipeline.F_qscale,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
@@ -587,7 +590,7 @@ class KernelComponentFactory:
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = "t" if dtype == "fp8" else "f"
|
||||
qscale = "no"
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(
|
||||
@@ -597,10 +600,10 @@ class KernelComponentFactory:
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
|
||||
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
|
||||
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
@@ -672,7 +675,7 @@ def get_fwd_blobs(
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
@@ -680,7 +683,7 @@ def get_fwd_blobs(
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
@@ -688,7 +691,7 @@ def get_fwd_blobs(
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
@@ -696,7 +699,7 @@ def get_fwd_blobs(
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
@@ -704,7 +707,7 @@ def get_fwd_blobs(
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user