[CK_TILE][FMHA] Add sparse attention VSA (#3341)

* add sparse attention VSA

* fix the pre-commit

* Add jenga test and pre-commit

* add bf16 for vsa

* add jenga support bf16

* remove lse arg

* split kernel code to block & kernel

* fix the pre-commit

* fix the pre-commit

* fix the copyrights

* fix the copyright

* fix the copyright & rename block to pipeline

* fix the copyright and pipeline

* remove lse & dropout & add fmt

* fix the jenga&VSA code review

* remove the useless code & resolve the comments

* remove useless code

* remove useless code

* Clean up code

* Remove more unused code

* Re-format .hpp

* Refactor codegen scripts

---------

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
jiangyon.ren
2026-01-31 00:59:47 +08:00
committed by GitHub
parent 2377a62837
commit 4d2f8c111e
22 changed files with 6058 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,73 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
# Short dtype tag -> generated C++ type-config struct name.
FWD_DTYPE_MAP = {
    "fp16": "FmhaSparseFwdFp16",
    "bf16": "FmhaSparseFwdBf16",
}
# "simplified" mask implementation: tag -> C++ mask type.
_MASK_SIMPLIFIED_MAP = {
    "s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
    "s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
}
# "generic" mask implementation: tag -> C++ mask type.
_MASK_MAP = {
    "no": "FmhaMasks::NoMask",
    "causal": "FmhaMasks::CausalMask",
    "generic": "FmhaMasks::GenericMask",
}
def get_mask_map(mask: str) -> dict:
    """Return the mask-tag -> C++ mask-type map for the given mask implementation.

    Parameters:
        mask: mask implementation style, either "generic" or "simplified".

    Raises:
        ValueError: if *mask* is not a known implementation style.  (The
        previous ``assert False`` is stripped under ``python -O`` and would
        have silently returned None.)
    """
    if mask == "generic":
        return _MASK_MAP
    if mask == "simplified":
        return _MASK_SIMPLIFIED_MAP
    raise ValueError(f"unknown mask implementation: {mask!r}")
# "generic" mask implementation: tag -> C++ runtime check on the traits struct.
_MASK_CHECK_MAP = {
    "no": "t.mask_type == mask_enum::no_mask",
    "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
    "generic": "t.mask_type == mask_enum::window_generic",
}
# "simplified" mask implementation: tag -> C++ runtime check on the traits struct.
_MASK_SIMPLIFIED_CHECK_MAP = {
    "s_no": "t.mask_type == mask_enum::no_mask",
    "s_mask": "t.mask_type != mask_enum::no_mask",
}
def get_mask_check_map(mask: str) -> dict:
    """Return the mask-tag -> C++ runtime-check map for the given mask implementation.

    Parameters:
        mask: mask implementation style, either "generic" or "simplified".

    Raises:
        ValueError: if *mask* is not a known implementation style.  (The
        previous ``assert False`` is stripped under ``python -O`` and would
        have silently returned None.)
    """
    if mask == "generic":
        return _MASK_CHECK_MAP
    if mask == "simplified":
        return _MASK_SIMPLIFIED_CHECK_MAP
    raise ValueError(f"unknown mask implementation: {mask!r}")
# Mode tag -> C++ bool template argument ("batch" -> "false"; presumably the
# kIsGroupMode flag — confirm against the kernel template).  Only batch mode
# is generated for the sparse kernels.
MODE_MAP = {"batch": "false"}
# V-tensor layout -> C++ bool for the is_v_rowmajor template argument.
LAYOUT_MAP = {"row": "true", "col": "false"}
# Pipeline tag -> C++ pipeline class template.
PIPELINE_MAP = {
    "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga",
    "qr_async_vsa": "ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA",
}
# Pipeline tag -> C++ BlockFmhaPipelineEnum value (both tags map to QRKSVS_ASYNC).
PIPELINE_ENUM_MAP = {
    "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
    "qr_async_vsa": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}
# Maps both the short "t"/"f" tags and Python bools to C++ bool literals.
BOOL_MAP = {
    "t": "true",
    "f": "false",
    True: "true",
    False: "false",
}

View File

@@ -0,0 +1,3 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,867 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass, field
import fnmatch
import itertools
import os
import os.path as path
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cpp_symbol_map import (
BOOL_MAP,
FWD_DTYPE_MAP,
LAYOUT_MAP,
MODE_MAP,
PIPELINE_ENUM_MAP,
PIPELINE_MAP,
get_mask_check_map,
get_mask_map,
)
GEN_DIR = ""


def update_file(file_path, content):
    """Write *content* to *file_path* only when it differs from the on-disk text.

    Leaving identical files untouched keeps their mtimes stable, which avoids
    triggering unnecessary rebuilds.
    """
    if path.exists(file_path):
        with open(file_path, "r") as existing:
            if existing.read() == content:
                return
    with open(file_path, "w") as out:
        out.write(content)
# Bit width of each supported element type (used to derive 128-bit vector widths).
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
# hdim -> k0max sub-tile size.  NOTE(review): not referenced elsewhere in this
# file — confirm before removing.
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
# Common header prepended to every generated translation unit.
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd_trek.hpp"
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp"
#include "kernel/fmha_fwd_jenga_kernel.hpp"
"""
# NOTE: Jenga sparse attention kernel has the following restrictions enforced by static_assert:
# - Group mode: NOT supported (batch mode only)
# - Bias: NOT supported (NO_BIAS only)
# - LSE output: NOT supported (false only)
# - Dropout: NOT supported (false only)
# - Logits soft-cap: NOT supported (false only)
# - FP8 static quantization: NOT supported (NO_SCALE only)
# The template below hardcodes these unsupported features accordingly.
# C++ template for one generated kernel instantiation.  {F_*} placeholders are
# filled in by FmhaFwdKernel.template; {{ / }} render as literal braces.
FMHA_FWD_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
false, // has_logits_soft_cap - NOT supported
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
false, // store_lse - NOT supported
false, // has_dropout - NOT supported
false, // has_randval - NOT supported
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
{F_occupancy},
false>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
{F_trload},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdJengaKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
#include <iostream>
template<>
float fmha_jenga_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_jenga_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << "{F_kernel_name}" << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_API_FILENAME = "fmha_jenga_fwd_api.cpp"
# Template for the dispatch API translation unit; {F_dispatch} receives the
# nested if/else dispatch chain built by FmhaFwdApiPool.api.
FMHA_FWD_API = """
#include <cstdio>
#include <hip/hip_runtime.h>
namespace {{
bool get_num_cus(unsigned& num_cus) {{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device");
return false;
}}
hipDeviceProp_t props{{}};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device properties");
return false;
}}
num_cus = props.multiProcessorCount;
return true;
}}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace
float fmha_jenga_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
return r;
}}
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
const bool has_load_tr = ck_tile::is_load_tr_supported();
{F_dispatch}
return r;
}}
"""
# Dispatch chain level 1: one branch per transpose-load capability.
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
{F_dtype_case}
}}
"""
# Dispatch chain level 2: one branch per data type.
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
# Dispatch chain level 3: one branch per (hdim_q, hdim_v) upper bound.
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
# Innermost dispatch: every runtime check for one concrete kernel instantiation.
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
return fmha_jenga_fwd_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
    """A C++ boolean expression guarding kernel dispatch.

    ``None`` (the default) means "no constraint" and renders as ``"true"``.
    Constraints compose with ``&``, which emits a parenthesized ``&&``
    expression.
    """

    # None -> unconstrained; otherwise a C++ boolean expression string.
    # (Fixed annotation: was `str = None`, which is not a valid default for `str`.)
    bool_expr: Optional[str] = None

    def __str__(self) -> str:
        if self.bool_expr is None:
            return "true"
        else:
            return f"{self.bool_expr}"

    def __and__(self, other) -> "CppConstraint":
        return CppConstraint(f"({str(self)}) && ({str(other)})")
@dataclass
class FmhaFwdApiTrait:
    """One dispatchable kernel variant.

    Mirrors the C++ ``fmha_jenga_fwd_traits_<>`` fields and produces the C++
    boolean expressions the dispatch chain uses to select this variant.
    """

    pipeline_tag: str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim: str
    dtype: str  # data type
    mode: str  # value from MODE_MAP
    bm0: int  # tile size along q seqlen (block size)
    bn0: int  # tile size along qk seqlen
    bk0: int  # tile size along qk gemm unroll
    bn1: int  # tile size along v head_dim
    bk1: int  # tile size along kv gemm unroll
    bk0max: int
    vlayout: str
    logits: str
    mask: str
    spad: str
    skpad: str
    dpad: str
    dvpad: str
    tr_load: str
    constraint: CppConstraint

    @property
    def name(self) -> str:
        # NOTE(review): the second `bn0` below looks like it was meant to be
        # `bn1` — confirm before changing, names may be load-bearing.
        return (
            f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
            + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
        )

    @property
    def scheck(self) -> str:
        # C++ expression guarding seqlen_q padding support.
        if self.mode == "group":
            return "true/*group mode spad always true*/"  # group mode only generate spad/skpad == true
        if self.spad == "t":
            return "true"  # always support
        # NOTE(review): both branches return "true", so spad never constrains
        # dispatch here — confirm whether this is intentional.
        return "true"

    @property
    def seqtune(self) -> str:
        # Smaller tiles only dispatch when seqlen_q fits; the 128 tile is the fallback.
        if self.bm0 == 128:
            return "true/*fall back to largest tile*/"
        else:
            return f"a.seqlen_q <= {self.bm0}"

    @property
    def skcheck(self) -> str:
        # C++ expression guarding seqlen_k padding support.
        if self.mode == "group":
            return "true/*group mode skpad always true*/"  # group mode only generate spad/skpad == true
        if self.skpad == "t":
            return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
        return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"

    @property
    def dcheck(self) -> str:
        # hdim_q must be a multiple of the 128-bit vector width for this dtype.
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dpad == "t":
            return f"a.hdim_q % {vec} == 0"
        assert False  # only dpad == "t" kernels are generated

    @property
    def dvcheck(self) -> str:
        # hdim_v must be a multiple of the 128-bit vector width for this dtype.
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dvpad == "t":
            return f"a.hdim_v % {vec} == 0"
        assert False  # only dvpad == "t" kernels are generated
@dataclass
class FmhaFwdPipeline:
    """Pipeline/trait selection for one kernel instance.

    The `name` property encodes every option into a short, stable fragment
    used in generated file and symbol names.
    """

    tag: str
    F_vlayout: str  # row/col
    F_spad: str  # true/false
    F_skpad: str  #
    F_dpad: str  #
    F_dvpad: str  #
    F_logits: str  # t/f
    F_mask: str  # value from MASK_MAP
    F_trload: str  # true/false
    F_constraint: CppConstraint = field(default_factory=CppConstraint)

    @property
    def name(self) -> str:
        """Encode all pipeline options into a name fragment."""

        def pad_name() -> str:
            # e.g. spad + dpad + dvpad -> "psddv"; empty when nothing is padded.
            n = ""
            if self.F_spad == "t":
                n += "s"
            if self.F_skpad == "t":
                n += "sk"
            if self.F_dpad == "t":
                n += "d"
            if self.F_dvpad == "t":
                n += "dv"
            if n != "":
                n = "p" + n
            return n

        pn = pad_name()
        n = f"{self.tag}_v{self.F_vlayout[0]}"
        if pn != "":
            n += f"_{pn}"
        else:
            n += "_npad"
        if self.F_logits == "t":
            n += "_logits"
        else:
            n += "_nlogits"
        n += "_nbias"  # bias is never supported, so this part is fixed
        if self.F_mask[0:2] == "s_":
            # simplified masks: s_mask -> "_mask", s_no -> "_nmask"
            if self.F_mask == "s_mask":
                n += "_mask"
            else:
                n += "_nmask"
        else:
            # generic masks: first letter of the mask kind; "no" -> "_nmask"
            if self.F_mask != "no":
                n += f"_m{self.F_mask[0]}"
            else:
                n += "_nmask"
        n += "_nskip"  # is_v_rowmajor_skip not supported
        n += "_nsquant"  # FP8 static quantization not supported
        if self.F_trload == "t":
            n += "_trload"
        else:
            n += "_ntrload"
        return n
class FmhaFwdApiPool:
    """Collects FmhaFwdApiTrait instances and renders the C++ dispatch API source."""

    def __init__(self, mask_impl):
        # pool layout: dtype -> (hdim, hdim_v) -> [FmhaFwdApiTrait]
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait: FmhaFwdApiTrait) -> None:
        """Index a (copied) trait by dtype and (hdim, hdim_v)."""
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        # bn1 is the v head-dim tile, used here as hdim_v.
        hdim = trait.hdim, trait.bn1
        if hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][hdim] = list()
        self.pool[trait.dtype][hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Render the complete dispatch .cpp source (header + nested dispatch chains)."""
        # "t" kernels require hardware transpose-load, checked at runtime.
        tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
        per_tr_load = str()
        for tr_load in ["t", "f"]:
            per_dtypes = str()
            for i, dtype in enumerate(self.pool.keys()):
                per_hdim_case = str()
                for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
                    # Only traits matching the current transpose-load branch.
                    traits = [
                        t
                        for t in self.pool[dtype][(hdim, hdim_v)]
                        if tr_load == t.tr_load
                    ]
                    inners = str()
                    for k, trait in enumerate(traits):
                        if_k = "if" if k == 0 else "else if"
                        inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
                            F_if=if_k,
                            F_vlayout=LAYOUT_MAP[trait.vlayout],
                            F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                            # F_logits removed - hardcoded to false (NOT supported)
                            F_mask=get_mask_map(self.mask_impl)[trait.mask],
                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                            F_trload=BOOL_MAP[trait.tr_load],
                            F_scheck=trait.scheck,
                            F_seqtune=trait.seqtune,
                            F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck,
                            F_dvcheck=trait.dvcheck,
                            F_constraint=trait.constraint,
                            F_spad=BOOL_MAP[trait.spad],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad],
                            F_dvpad=BOOL_MAP[trait.dvpad],
                            F_bm0=trait.bm0,
                            F_bn0=trait.bn0,
                            F_bk0=trait.bk0,
                            F_bn1=trait.bn1,
                            F_bk1=trait.bk1,
                            F_bk0max=trait.bk0max,
                            F_hdim=hdim,
                            F_dtype=FWD_DTYPE_MAP[dtype],
                        )
                    if_j = "if" if j == 0 else "else if"
                    per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
                        F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
                    )
                if_i = "if" if i == 0 else "else if"
                per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
                    F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
                )
            # NOTE(review): F_if is always "if" here, so the two tr-load
            # branches are independent `if`s rather than if/else — confirm intended.
            per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
                F_if="if",
                F_trload_cond=tr_load_cond_map[tr_load],
                F_dtype_case=per_dtypes,
            )
        if not per_tr_load:
            # empty string we add some ignore to suppress warning in api
            per_tr_load += " (void)t ; (void)s ; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
@dataclass
class FmhaFwdTileSize:
    """Block/warp tiling parameters for one kernel instance; `name` encodes them."""

    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along qk gemm unroll
    F_bn1: int  # tile size along v head_dim
    F_bk1: int  # tile size along kv gemm unroll
    F_bk0max: int  # total length of K0, used for pipeline that need load Q at once (or repeatedly load Q as a whole tile)
    F_rm0: int  # number of warps for gemm0 along q seqlen
    F_rn0: int  # number of warps for gemm0 along k seqlen
    F_rk0: int  # number of warps for gemm0 along head dim q (not used)
    F_rm1: int  # number of warps for gemm1 along q seqlen
    F_rn1: int  # number of warps for gemm1 along head dim v
    F_rk1: int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0: int  # gemm0 warp size along m
    F_wn0: int  # gemm0 warp size along n
    F_wk0: int  # gemm0 warp size along k
    F_wm1: int  # gemm1 warp size along m
    F_wn1: int  # gemm1 warp size along n
    F_wk1: int  # gemm1 warp size along k
    F_occupancy: int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
    F_constraint: CppConstraint = field(default_factory=CppConstraint)

    @property
    def name(self) -> str:
        # e.g. "b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16"
        return (
            f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
            + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
            + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
            + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
        )
@dataclass
class FmhaFwdKernel:
    """One concrete kernel instance.

    Renders its generated .cpp source (`template`), its file/symbol name
    (`name`/`filename`), and its dispatch-side descriptor (`api_trait`).
    """

    F_idx: int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim: int  # hdim
    F_dtype: str  # data type
    F_mode: str  # value from MODE_MAP
    F_tile: FmhaFwdTileSize
    F_pipeline: FmhaFwdPipeline
    mask_impl: str

    @property
    def template(self) -> str:
        """Render the complete generated translation unit for this kernel."""
        # kernel_body removed - unused
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
            F_idx=self.F_idx,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bm0=self.F_tile.F_bm0,
            F_bn0=self.F_tile.F_bn0,
            F_bk0=self.F_tile.F_bk0,
            F_bn1=self.F_tile.F_bn1,
            F_bk1=self.F_tile.F_bk1,
            F_bk0max=self.F_tile.F_bk0max,
            F_rm0=self.F_tile.F_rm0,
            F_rn0=self.F_tile.F_rn0,
            F_rk0=self.F_tile.F_rk0,
            F_rm1=self.F_tile.F_rm1,
            F_rn1=self.F_tile.F_rn1,
            F_rk1=self.F_tile.F_rk1,
            F_wm0=self.F_tile.F_wm0,
            F_wn0=self.F_tile.F_wn0,
            F_wk0=self.F_tile.F_wk0,
            F_wm1=self.F_tile.F_wm1,
            F_wn1=self.F_tile.F_wn1,
            F_wk1=self.F_tile.F_wk1,
            F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
            F_spad=BOOL_MAP[self.F_pipeline.F_spad],
            F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
            F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
            F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
            # F_logits removed - hardcoded to false in template (NOT supported)
            F_occupancy=self.F_tile.F_occupancy,
            F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
            F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
            F_mode=MODE_MAP[self.F_mode],
            F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
            F_trload=BOOL_MAP[self.F_pipeline.F_trload],
            F_kernel_name=self.name,
        )

    @property
    def name(self) -> str:
        # TODO: we don't encode idx here
        return (
            f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
            + self.F_tile.name
            + "_"
            + self.F_pipeline.name
        )

    @property
    def filename(self) -> str:
        # One generated .cpp file per kernel instance.
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        """Build the dispatch descriptor matching this kernel's template parameters."""
        return FmhaFwdApiTrait(
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=self.F_tile.F_bm0,
            bn0=self.F_tile.F_bn0,
            bk0=self.F_tile.F_bk0,
            bn1=self.F_tile.F_bn1,
            bk1=self.F_tile.F_bk1,
            bk0max=self.F_tile.F_bk0max,
            vlayout=self.F_pipeline.F_vlayout,
            mask=self.F_pipeline.F_mask,
            logits=self.F_pipeline.F_logits,
            spad=self.F_pipeline.F_spad,
            skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad,
            dvpad=self.F_pipeline.F_dvpad,
            tr_load=self.F_pipeline.F_trload,
            constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
        )
class KernelComponentFactory:
    """Supplies the tile-size table and the candidate pipeline list per dtype/hdim."""

    # TODO: design a more practical way to do it
    # this is current supported tile size per hdim
    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        """Return {(hdim_q, hdim_v): [FmhaFwdTileSize, ...]} or None for unsupported dtypes."""
        if dtype not in ("fp16", "bf16"):
            return None
        # Args: bm0, bn0, bk0, bn1, bk1, bk0max, rm0, rn0, rk0, rm1, rn1, rk1,
        #       wm0, wn0, wk0, wm1, wn1, wk1, occupancy
        # Tuning entries kept for reference:
        # (32, 32)  : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
        # (64, 64)  : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
        #              FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
        #              FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
        # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
        # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
        # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
        # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
        # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
        return {
            (128, 128): [
                FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
                FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            ],
        }

    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    # support this in future
    @staticmethod
    def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
        """Enumerate candidate pipelines; list order matters — later entries are checked later."""
        # NOTE: logits soft-cap is NOT supported (enforced by static_assert on the C++ side)
        if dtype not in ("fp16", "bf16"):
            assert False
        pipelines = []
        for logits, mask in itertools.product(
            ["f"],  # logits soft-cap NOT supported, always false
            get_mask_map(mask_impl).keys(),
        ):
            if hdim == 256 and hdim_v == 256:
                # jenga fmha only supports dim <= 192 for now.
                continue
            # One variant without seqlen_k padding, then one with it.
            for skpad in ("f", "t"):
                pipelines.append(
                    FmhaFwdPipeline("qr_async", "row", "t", skpad, "t", "t", logits, mask, "f")
                )
        return pipelines
class CustomFactory(KernelComponentFactory):
    """Factory variant that prepends an extra 64x128 tile guarded by a CU-utilization constraint."""

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
        if dtype in ("fp16", "bf16") and (128, 128) in result:
            # Prefer this tile only when the default 128-row tile would leave
            # too many CUs idle (constraint is evaluated in the generated C++).
            extra = FmhaFwdTileSize(
                64, 128, 64, 128, 64, 128,
                4, 1, 1, 4, 1, 1,
                16, 16, 16, 16, 16, 16,
                -1,
                CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"),
            )
            result[(128, 128)].insert(0, extra)
        return result
def get_fwd_blobs(
    kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    """Enumerate every kernel instance to generate plus the API pool dispatching to them.

    Parameters:
        kernel_filter: fnmatch-style pattern on kernel names; "" disables filtering.
        optdim_list: head dims to keep, or [-1] for all.
        receipt: integration id selecting extra filtering rules (see branches below).
        mask_impl: mask implementation style ("generic" or "simplified").
    """
    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)
    # Opt-in factory override via environment variable.
    factory = (
        CustomFactory
        if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
        else KernelComponentFactory
    )
    # Only generate fp16/bf16 kernels for now.
    # NOTE: Jenga sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert)
    for dtype in ["fp16", "bf16"]:
        d = factory.get_hdim_tile_size_dict(dtype)
        if d is None:
            continue
        for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
            for tile, pipeline in itertools.product(
                tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
            ):
                # Only the largest 128x128 tile with the qr_async pipeline is emitted for now.
                if tile.F_bm0 != 128 or tile.F_bn0 != 128:
                    continue
                if pipeline.tag != "qr_async":
                    continue
                k = FmhaFwdKernel(
                    F_idx=2,
                    F_hdim=hdim,
                    F_dtype=dtype,
                    F_mode=mode,
                    F_tile=tile,
                    F_pipeline=pipeline,
                    mask_impl=mask_impl,
                )
                if kernel_filter != "":
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if optdim_list != [-1]:
                    if hdim not in optdim_list:
                        continue
                # 2 - Flash attention integration
                if receipt in (2, 3):
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # PyTorch integration
                elif receipt == 4:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    cond &= mode == "batch"
                    cond &= pipeline.F_logits == "f"
                    if not cond:
                        continue
                # Aiter(mha_fwd) integration
                elif receipt == 100:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "batch"
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # Aiter(mha_varlen_fwd) integration
                elif receipt == 200:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "group"
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # aiter::mha_fwd C++ api integration
                elif receipt == 600:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)
    return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    # Emit one generated .cpp per kernel instance (no-op when content is unchanged).
    update_file(autogen_dir / kernel.filename, kernel.template)


def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
    # Emit the dispatch API translation unit (no-op when content is unchanged).
    update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
def write_blobs(
    output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    """Generate every kernel .cpp plus the dispatch API file into *output_dir*."""
    api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    for kernel in kernels:
        write_single_fwd_kernel(kernel, output_dir)
    write_fwd_api(api_pool, output_dir)
def list_blobs(
    file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    """Append the paths of all files that would be generated (one per line) to *file_path*."""
    _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    base = file_path.parent / GEN_DIR
    entries = [str(base / kernel.filename) for kernel in kernels]
    entries.append(str(base / FMHA_FWD_API_FILENAME))
    with file_path.open("a") as f:
        f.write("\n".join(entries) + "\n")

View File

@@ -0,0 +1,867 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass, field
import fnmatch
import itertools
import os
import os.path as path
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cpp_symbol_map import (
BOOL_MAP,
FWD_DTYPE_MAP,
LAYOUT_MAP,
MODE_MAP,
PIPELINE_ENUM_MAP,
PIPELINE_MAP,
get_mask_check_map,
get_mask_map,
)
GEN_DIR = ""


def update_file(file_path, content):
    """Write *content* to *file_path* only when it differs from the on-disk text.

    Leaving identical files untouched keeps their mtimes stable, which avoids
    triggering unnecessary rebuilds.
    """
    if path.exists(file_path):
        with open(file_path, "r") as existing:
            if existing.read() == content:
                return
    with open(file_path, "w") as out:
        out.write(content)
# Bit width of each supported element type (used to derive 128-bit vector widths).
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
# hdim -> k0max sub-tile size.  NOTE(review): not referenced elsewhere in this
# file — confirm before removing.
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
# Common header prepended to every generated translation unit.
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd_trek.hpp"
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp"
#include "kernel/fmha_fwd_vsa_kernel.hpp"
"""
# NOTE: VSA sparse attention kernel has the following restrictions enforced by static_assert:
# - Group mode: NOT supported (batch mode only)
# - Bias: NOT supported (NO_BIAS only)
# - LSE output: NOT supported (false only)
# - Dropout: NOT supported (false only)
# - Logits soft-cap: NOT supported (false only)
# - FP8 static quantization: NOT supported (NO_SCALE only)
# The template below hardcodes these unsupported features accordingly.
# C++ template for one generated VSA kernel instantiation.  {F_*} placeholders
# are filled in by FmhaFwdKernel.template; {{ / }} render as literal braces.
# NOTE(review): unlike the Jenga generator, the pipeline type below is
# hardcoded to BlockFmhaPipelineQRKSVSAsyncVSA rather than using a
# {F_pipeline} placeholder (str.format ignores the unused kwarg) — confirm
# this is intentional.
FMHA_FWD_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
false, // has_logits_soft_cap - NOT supported
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
false, // store_lse - NOT supported
false, // has_dropout - NOT supported
false, // has_randval - NOT supported
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
{F_occupancy},
false>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
{F_trload},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdVSAKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
#include <iostream>
template<>
float fmha_vsa_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_vsa_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << "{F_kernel_name}" << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_API_FILENAME = "fmha_vsa_fwd_api.cpp"
# Template for the VSA dispatch API translation unit; {F_dispatch} receives the
# nested if/else dispatch chain built by FmhaFwdApiPool.api.
FMHA_FWD_API = """
#include <cstdio>
#include <hip/hip_runtime.h>
namespace {{
bool get_num_cus(unsigned& num_cus) {{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device");
return false;
}}
hipDeviceProp_t props{{}};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device properties");
return false;
}}
num_cus = props.multiProcessorCount;
return true;
}}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace
float fmha_vsa_fwd(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
return r;
}}
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
const bool has_load_tr = ck_tile::is_load_tr_supported();
{F_dispatch}
return r;
}}
"""
# Dispatch chain level 1: one branch per transpose-load capability.
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
{F_dtype_case}
}}
"""
# Dispatch chain level 2: one branch per data type.
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
# Dispatch chain level 3: one branch per (hdim_q, hdim_v) upper bound.
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
# Innermost dispatch: every runtime check for one concrete kernel instantiation.
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
return fmha_vsa_fwd_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
    """A C++ boolean expression guarding one kernel dispatch at runtime.

    ``None`` (the default) is the unconditional guard and renders as the
    C++ literal ``true``.  Constraints compose with ``&`` into a C++
    logical AND.
    """

    # Raw C++ expression text, or None for the always-true constraint.
    # FIX: was annotated `str = None`; the default is legitimately None.
    bool_expr: Optional[str] = None

    def __str__(self) -> str:
        """Render the constraint as C++ source text."""
        return "true" if self.bool_expr is None else self.bool_expr

    def __and__(self, other: "CppConstraint") -> "CppConstraint":
        """Combine two constraints into `(lhs) && (rhs)`."""
        return CppConstraint(f"({str(self)}) && ({str(other)})")
@dataclass
class FmhaFwdApiTrait:
    """Dispatch-side description of one kernel instance.

    Mirrors the C++ fmha_fwd_traits<> template arguments; the *check
    properties render the C++ runtime conditions that must hold for this
    kernel to be selected.
    """

    pipeline_tag: str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim: str
    dtype: str  # data type
    mode: str  # value from MODE_MAP
    bm0: int  # tile size along q seqlen (block size)
    bn0: int  # tile size along qk seqlen
    bk0: int  # tile size along qk gemm unroll
    bn1: int  # tile size along v head_dim
    bk1: int  # tile size along kv gemm unroll
    bk0max: int
    vlayout: str
    logits: str
    mask: str
    spad: str
    skpad: str
    dpad: str
    dvpad: str
    tr_load: str
    constraint: CppConstraint

    @property
    def name(self) -> str:
        """Unique trait identifier string.

        BUGFIX: the fourth tile number previously repeated ``bn0``; it must
        be ``bn1`` to follow the bm0-bn0-bk0-bn1-bk1-bk0max tile ordering
        (same order as FmhaFwdTileSize.name).
        """
        return (
            f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
            + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
        )

    @property
    def scheck(self) -> str:
        """C++ condition on q seqlen padding (currently always satisfiable)."""
        if self.mode == "group":
            return "true/*group mode spad always true*/"  # group mode only generate spad/skpad == true
        if self.spad == "t":
            return "true"  # always support
        # NOTE: the unpadded case is also accepted unconditionally today.
        return "true"

    @property
    def seqtune(self) -> str:
        """C++ condition steering small q seqlens onto smaller bm0 tiles."""
        if self.bm0 == 128:
            return "true/*fall back to largest tile*/"  # group mode only generate spad/skpad == true
        else:
            return f"a.seqlen_q <= {self.bm0}"

    @property
    def skcheck(self) -> str:
        """C++ condition on k seqlen padding/divisibility by bn0."""
        if self.mode == "group":
            return "true/*group mode skpad always true*/"  # group mode only generate spad/skpad == true
        if self.skpad == "t":
            return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
        return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"

    @property
    def dcheck(self) -> str:
        """C++ condition that hdim_q allows the 128-bit vector load width."""
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dpad == "t":
            return f"a.hdim_q % {vec} == 0"
        assert False  # unpadded-d kernels are never generated for VSA

    @property
    def dvcheck(self) -> str:
        """C++ condition that hdim_v allows the 128-bit vector load width."""
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dvpad == "t":
            return f"a.hdim_v % {vec} == 0"
        assert False  # unpadded-dv kernels are never generated for VSA
@dataclass
class FmhaFwdPipeline:
    """Pipeline configuration (layout, padding, mask, tr-load) of one kernel."""

    tag: str  # pipeline implementation tag
    F_vlayout: str  # row/col
    F_spad: str  # t/f: pad q seqlen
    F_skpad: str  # t/f: pad k seqlen
    F_dpad: str  # t/f: pad qk head dim
    F_dvpad: str  # t/f: pad v head dim
    F_logits: str  # t/f
    F_mask: str  # value from MASK_MAP
    F_trload: str  # t/f: transposed load
    F_constraint: CppConstraint = field(default_factory=CppConstraint)

    @property
    def name(self) -> str:
        """Encode this pipeline configuration as a kernel-name suffix."""
        pad_flags = (
            (self.F_spad, "s"),
            (self.F_skpad, "sk"),
            (self.F_dpad, "d"),
            (self.F_dvpad, "dv"),
        )
        pads = "".join(short for flag, short in pad_flags if flag == "t")
        parts = [f"{self.tag}_v{self.F_vlayout[0]}"]
        parts.append(f"p{pads}" if pads else "npad")
        parts.append("logits" if self.F_logits == "t" else "nlogits")
        parts.append("nbias")
        # simplified masks are keyed "s_*"; generic masks by their first letter
        if self.F_mask.startswith("s_"):
            parts.append("mask" if self.F_mask == "s_mask" else "nmask")
        elif self.F_mask != "no":
            parts.append(f"m{self.F_mask[0]}")
        else:
            parts.append("nmask")
        parts.append("nskip")
        parts.append("nsquant")
        parts.append("trload" if self.F_trload == "t" else "ntrload")
        return "_".join(parts)
class FmhaFwdApiPool:
    """Accumulates kernel dispatch traits and renders the C++ dispatch API.

    Traits are grouped by dtype, then by (hdim, bn1); the `api` property
    turns the pool into a nested if/else-if C++ cascade.
    """

    def __init__(self, mask_impl):
        # pool layout: {dtype: {(hdim, bn1): [FmhaFwdApiTrait, ...]}}
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait: FmhaFwdApiTrait) -> None:
        """Record one kernel's dispatch trait (copied), grouped by dtype/hdims."""
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        hdim = trait.hdim, trait.bn1
        if hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][hdim] = list()
        self.pool[trait.dtype][hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Render the full fmha_vsa_fwd_api.cpp source text.

        Cascade order (outer to inner): tr-load support -> dtype ->
        head-dim bounds -> per-trait runtime checks.
        """
        tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
        per_tr_load = str()
        # outermost split: kernels requiring transposed-load support vs not
        for tr_load in ["t", "f"]:
            per_dtypes = str()
            for i, dtype in enumerate(self.pool.keys()):
                per_hdim_case = str()
                for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
                    # keep only the traits belonging to this tr_load branch
                    traits = [
                        t
                        for t in self.pool[dtype][(hdim, hdim_v)]
                        if tr_load == t.tr_load
                    ]
                    inners = str()
                    for k, trait in enumerate(traits):
                        if_k = "if" if k == 0 else "else if"
                        inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
                            F_if=if_k,
                            F_vlayout=LAYOUT_MAP[trait.vlayout],
                            F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                            # F_logits removed - hardcoded to false (NOT supported)
                            F_mask=get_mask_map(self.mask_impl)[trait.mask],
                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                            F_trload=BOOL_MAP[trait.tr_load],
                            F_scheck=trait.scheck,
                            F_seqtune=trait.seqtune,
                            F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck,
                            F_dvcheck=trait.dvcheck,
                            F_constraint=trait.constraint,
                            F_spad=BOOL_MAP[trait.spad],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad],
                            F_dvpad=BOOL_MAP[trait.dvpad],
                            F_bm0=trait.bm0,
                            F_bn0=trait.bn0,
                            F_bk0=trait.bk0,
                            F_bn1=trait.bn1,
                            F_bk1=trait.bk1,
                            F_bk0max=trait.bk0max,
                            F_hdim=hdim,
                            F_dtype=FWD_DTYPE_MAP[dtype],
                        )
                    if_j = "if" if j == 0 else "else if"
                    per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
                        F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
                    )
                if_i = "if" if i == 0 else "else if"
                per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
                    F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
                )
            per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
                F_if="if",
                F_trload_cond=tr_load_cond_map[tr_load],
                F_dtype_case=per_dtypes,
            )
        if not per_tr_load:
            # empty string we add some ignore to suppress warning in api
            per_tr_load += " (void)t ; (void)s ; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
@dataclass
class FmhaFwdTileSize:
    """Block-tile / warp-tile configuration for one forward kernel."""

    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along qk gemm unroll
    F_bn1: int  # tile size along v head_dim
    F_bk1: int  # tile size along kv gemm unroll
    F_bk0max: int  # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
    F_rm0: int  # number of warps for gemm0 along q seqlen
    F_rn0: int  # number of warps for gemm0 along k seqlen
    F_rk0: int  # number of warps for gemm0 along head dim q (not used)
    F_rm1: int  # number of warps for gemm1 along q seqlen
    F_rn1: int  # number of warps for gemm1 along head dim v
    F_rk1: int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0: int  # gemm0 warp size along m
    F_wn0: int  # gemm0 warp size along n
    F_wk0: int  # gemm0 warp size along k
    F_wm1: int  # gemm1 warp size along m
    F_wn1: int  # gemm1 warp size along n
    F_wk1: int  # gemm1 warp size along k
    F_occupancy: int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
    F_constraint: CppConstraint = field(default_factory=CppConstraint)

    @property
    def name(self) -> str:
        """Encode the tile configuration into the kernel-name fragment."""
        block = "x".join(
            str(v)
            for v in (self.F_bm0, self.F_bn0, self.F_bk0, self.F_bn1, self.F_bk1, self.F_bk0max)
        )
        warps0 = f"r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}"
        warps1 = f"r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
        wtile0 = f"w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}"
        wtile1 = f"w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
        occ = "" if self.F_occupancy == -1 else f"_o{self.F_occupancy}"
        return f"b{block}_{warps0}_{warps1}_{wtile0}_{wtile1}{occ}"
@dataclass
class FmhaFwdKernel:
    """One concrete forward-kernel instance; renders its C++ source and name."""

    F_idx: int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim: int  # hdim
    F_dtype: str  # data type
    F_mode: str  # value from MODE_MAP
    F_tile: FmhaFwdTileSize
    F_pipeline: FmhaFwdPipeline
    mask_impl: str

    @property
    def template(self) -> str:
        """Render the complete C++ translation unit for this kernel."""
        # kernel_body removed - unused
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
            F_idx=self.F_idx,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bm0=self.F_tile.F_bm0,
            F_bn0=self.F_tile.F_bn0,
            F_bk0=self.F_tile.F_bk0,
            F_bn1=self.F_tile.F_bn1,
            F_bk1=self.F_tile.F_bk1,
            F_bk0max=self.F_tile.F_bk0max,
            F_rm0=self.F_tile.F_rm0,
            F_rn0=self.F_tile.F_rn0,
            F_rk0=self.F_tile.F_rk0,
            F_rm1=self.F_tile.F_rm1,
            F_rn1=self.F_tile.F_rn1,
            F_rk1=self.F_tile.F_rk1,
            F_wm0=self.F_tile.F_wm0,
            F_wn0=self.F_tile.F_wn0,
            F_wk0=self.F_tile.F_wk0,
            F_wm1=self.F_tile.F_wm1,
            F_wn1=self.F_tile.F_wn1,
            F_wk1=self.F_tile.F_wk1,
            F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
            F_spad=BOOL_MAP[self.F_pipeline.F_spad],
            F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
            F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
            F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
            # F_logits removed - hardcoded to false in template (NOT supported)
            F_occupancy=self.F_tile.F_occupancy,
            F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
            F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
            F_mode=MODE_MAP[self.F_mode],
            F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
            F_trload=BOOL_MAP[self.F_pipeline.F_trload],
            F_kernel_name=self.name,
        )

    @property
    def name(self) -> str:
        """Unique kernel base name: dtype, mode, tile name, pipeline name."""
        # TODO: we don't encode idx here
        return (
            f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
            + self.F_tile.name
            + "_"
            + self.F_pipeline.name
        )

    @property
    def filename(self) -> str:
        """Generated .cpp file name for this kernel instance."""
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        """Project this kernel into the dispatch trait used by FmhaFwdApiPool."""
        return FmhaFwdApiTrait(
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=self.F_tile.F_bm0,
            bn0=self.F_tile.F_bn0,
            bk0=self.F_tile.F_bk0,
            bn1=self.F_tile.F_bn1,
            bk1=self.F_tile.F_bk1,
            bk0max=self.F_tile.F_bk0max,
            vlayout=self.F_pipeline.F_vlayout,
            mask=self.F_pipeline.F_mask,
            logits=self.F_pipeline.F_logits,
            spad=self.F_pipeline.F_spad,
            skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad,
            dvpad=self.F_pipeline.F_dvpad,
            tr_load=self.F_pipeline.F_trload,
            # a kernel is dispatchable only if both tile and pipeline agree
            constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
        )
class KernelComponentFactory:
    """Supplies the tile sizes and pipelines used to instantiate VSA kernels."""

    # TODO: design a more practical way to do it
    # this is the currently supported tile size per hdim
    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        """Map (hdim_qk, hdim_v) -> tile candidates; None for unsupported dtypes."""
        if dtype not in ("fp16", "bf16"):
            return None
        # Disabled head-dim configurations, kept for reference:
        # (32, 32)  : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
        # (64, 64)  : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
        #              FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
        #              FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
        # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
        # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
        # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
        # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
        # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
        # fmt: off
        return {
            (128, 128): [
                FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
                FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            ],
        }
        # fmt: on

    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    # support this in future
    @staticmethod
    def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
        """Enumerate candidate pipelines; list order defines dispatch order.

        NOTE: logits soft-cap is NOT supported by VSA sparse attention
        (enforced by static_assert), so "f" is the only logits value.
        """
        if dtype not in ("fp16", "bf16"):
            assert False
        pipelines: List[FmhaFwdPipeline] = []
        for logits, mask in itertools.product(["f"], get_mask_map(mask_impl).keys()):
            if hdim == 256 and hdim_v == 256:
                # vsa fmha only supports dim <= 192 for now.
                continue
            # one variant without k-seqlen padding, and one with it
            pipelines.append(
                FmhaFwdPipeline("qr_async_vsa", "row", "t", "f", "t", "t", logits, mask, "f")
            )
            pipelines.append(
                FmhaFwdPipeline("qr_async_vsa", "row", "t", "t", "t", "t", logits, mask, "f")
            )
        return pipelines
class CustomFactory(KernelComponentFactory):
    """Factory variant that prepends an extra small-grid tile for (128, 128)."""

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        """Like the base factory, plus a CU-utilization-guarded 64x128 tile."""
        tiles = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
        if dtype in ("fp16", "bf16") and (128, 128) in tiles:
            # Preferred when the default launch would under-utilize the CUs;
            # the constraint is evaluated at runtime in the generated C++.
            extra = FmhaFwdTileSize(
                64, 128, 64, 128, 64, 128,
                4, 1, 1, 4, 1, 1,
                16, 16, 16, 16, 16, 16,
                -1,
                CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"),
            )
            tiles[(128, 128)].insert(0, extra)
        return tiles
def get_fwd_blobs(
    kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    """Enumerate all VSA forward kernel instances and the dispatch API pool.

    Args:
        kernel_filter: fnmatch-style pattern applied to kernel names;
            ``None`` or ``""`` disables filtering.
        receipt: integration id selecting framework-specific filters.
        optdim_list: head dims to keep, or ``[-1]`` to keep all.
        mask_impl: mask implementation selector ("generic"/"simplified").

    Returns:
        Tuple of (api_pool, kernels) ready to be written to disk.
    """
    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)
    # opt-in factory override for experimentation
    factory = (
        CustomFactory
        if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
        else KernelComponentFactory
    )
    # Only generate fp16/bf16 kernels for now.
    # NOTE: VSA sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert)
    for dtype in ["fp16", "bf16"]:
        d = factory.get_hdim_tile_size_dict(dtype)
        if d is None:
            continue
        for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
            for tile, pipeline in itertools.product(
                tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
            ):
                # VSA currently only ships the 128x128 qr_async_vsa configs
                if tile.F_bm0 != 128 or tile.F_bn0 != 128:
                    continue
                if pipeline.tag != "qr_async_vsa":
                    continue
                k = FmhaFwdKernel(
                    F_idx=1,
                    F_hdim=hdim,
                    F_dtype=dtype,
                    F_mode=mode,
                    F_tile=tile,
                    F_pipeline=pipeline,
                    mask_impl=mask_impl,
                )
                # BUGFIX: kernel_filter is Optional[str]; the previous
                # `kernel_filter != ""` guard let None through to fnmatch,
                # which raises TypeError. Treat None and "" as "no filter".
                if kernel_filter:
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if optdim_list != [-1]:
                    if hdim not in optdim_list:
                        continue
                # 2 - Flash attention integration
                if receipt in (2, 3):
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # PyTorch integration
                elif receipt == 4:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    cond &= mode == "batch"
                    cond &= pipeline.F_logits == "f"
                    if not cond:
                        continue
                # Aiter(mha_fwd) integration
                elif receipt == 100:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "batch"
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # Aiter(mha_varlen_fwd) integration
                elif receipt == 200:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "group"
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # aiter::mha_fwd C++ api integration
                elif receipt == 600:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)
    return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    """Write (or refresh) one kernel's generated .cpp file under autogen_dir."""
    update_file(autogen_dir / kernel.filename, kernel.template)
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
    """Write (or refresh) the generated dispatch API translation unit."""
    update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
def write_blobs(
    output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    """Generate every kernel source file plus the dispatch API into output_dir."""
    api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    for k in kernels:
        write_single_fwd_kernel(k, output_dir)
    write_fwd_api(api_pool, output_dir)
def list_blobs(
    file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    """Append the paths of all files that would be generated to file_path."""
    _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    gen_dir = file_path.parent / GEN_DIR
    # opened in append mode: callers accumulate lists from several generators
    with file_path.open("a") as out:
        for k in kernels:
            out.write(f"{gen_dir / k.filename}\n")
        out.write(f"{gen_dir / FMHA_FWD_API_FILENAME}\n")