[rocm-libraries] ROCm/rocm-libraries#6574 (commit b3db057)

[CK_TILE] Add SageAttention v2 forward kernel with
 multi-granularity quantization (#6574)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

Add a CK_TILE forward kernel implementing [SageAttention
v2](https://arxiv.org/abs/2411.10958) — an attention algorithm that
applies multi-granularity quantization to Q/K/V before computing
attention, trading minimal accuracy loss for higher throughput on
low-precision hardware.

### Quantization design

| Tensor | Supported data types | Scale granularity options |
|--------|---------------------|--------------------------|
| Q | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (32 tokens), per-thread (4 tokens) |
| K | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (64 tokens), per-thread (16 tokens) |
| V | fp8 | per-channel (always) |
| O | bf16 | — |

Three precision combinations are supported: `fp8/bf16` (QKV fp8, O
bf16), `i8/fp8/bf16` (QK int8, V fp8, O bf16), and `i4/fp8/bf16` (QK
int4, V fp8, O bf16).

### Architecture support

- **gfx9** (CDNA2/3, e.g. gfx90a, gfx942) — fp8-family dtypes use an
N-per-block of 64
- **gfx950** (CDNA4) — fp8-family dtypes use an N-per-block of 128

### Implementation

- Two pipeline variants: `QRKSVS` (synchronous) and `QRKSVS_ASYNC`
(async copy)
- Masking support: no mask, causal (top-left / bottom-right), and
generic windowed
- Batch and group (variable-length) modes
- Head dimension: d=128, d_v=128
- Python codegen under `example/ck_tile/49_sageattention/codegen/`
generates kernel instances per target/dtype/tile combination
- Smoke tests included via `tile_example_sageattn_fwd`

### Test commands

```bash
# fp8 QKV
./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 \
    -kname=1 -prec=fp8bf16 -qscale=3 -init=3

# int8 QK, fp8 V
./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 \
    -kname=1 -prec=i8fp8bf16 -qscale=3 -init=3
```

`-qscale` values: 1=per-tensor, 2=per-block, 3=per-warp, 4=per-thread
This commit is contained in:
ltqin
2026-04-30 18:33:36 +00:00
committed by assistant-librarian[bot]
parent e8d64ad5c6
commit de0a61e5c2
30 changed files with 7809 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,42 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
from dataclasses import dataclass, field
from typing import Any, List, Callable
@dataclass(frozen=True)
class ArchTrait:
    """Immutable description of one GPU architecture handled by the generator.

    Only ``name`` is mandatory. Any other field left as ``None`` is filled in
    from ``name`` by ``__post_init__``. Instances are frozen (and therefore
    hashable), so they can be used as dictionary keys.
    """

    # Architecture name, e.g. "gfx9" or "gfx950".
    name: str
    # C++ preprocessor condition guarding device code for this architecture.
    preprocessor_check: str = field(default=None)
    # C++ boolean expression testing the runtime ``device_name`` string.
    device_name_check: str = field(default=None)
    # C++ tag type spelled into generated template arguments.
    tag: str = field(default=None)
    # Suffix appended to generated file names.
    filename_suffix: str = field(default=None)

    def __post_init__(self):
        """Replace every field still set to ``None`` with its name-derived default."""
        arch = self.name
        derived_defaults = (
            ("preprocessor_check", f"defined(__{arch}__)"),
            (
                "device_name_check",
                f'device_name.compare(0, {len(arch)}, "{arch}") == 0',
            ),
            ("tag", f"ck_tile::{arch}_t"),
            ("filename_suffix", f"_{arch}"),
        )
        for attr, fallback in derived_defaults:
            if getattr(self, attr) is None:
                # The dataclass is frozen, so bypass its __setattr__ guard.
                object.__setattr__(self, attr, fallback)
def get_factories_for_targets(
    targets: List[str], get_factory: Callable[[str], Any]
) -> List[Any]:
    """Resolve build targets to a de-duplicated, ordered list of factories.

    Each target is passed to ``get_factory``. Factories are de-duplicated by
    ``factory.arch.name`` (the last factory seen for a name wins) and returned
    with longer architecture names first, so that more specific architectures
    precede more general ones.
    """
    unique_by_arch = {}
    for factory in map(get_factory, targets):
        unique_by_arch[factory.arch.name] = factory
    ordered = list(unique_by_arch.values())
    # Place more specific architectures first (stable for equal-length names).
    ordered.sort(key=lambda factory: len(factory.arch.name), reverse=True)
    return ordered

View File

@@ -0,0 +1,4 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Generate kernel instances to speed up compilation.
# Sub-directory for generated files; left empty because with CMake the files
# have to be generated in the same folder.
GEN_DIR = ""

View File

@@ -0,0 +1,103 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
# Precision tag -> name of the C++ type spelled into generated code
# (used as the ``{F_dtype}`` template argument by the generator).
FWD_DTYPE_MAP = {
    "fp16": "SageAttentionFwdFp16",
    "bf16": "SageAttentionFwdBf16",
    "fp8bf16": "SageAttentionFwdFp8Bf16",
    "i8fp8bf16": "SageAttentionFwdI8Fp8Bf16",
    "i4fp8bf16": "SageAttentionFwdI4Fp8Bf16",
}
# Mask name -> C++ mask type for the "simplified" mask implementation.
# Every key carries the "s_" prefix that get_mask_impl() keys on.
_MASK_SIMPLIFIED_MAP = {
    "s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
    "s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
}
# Mask name -> C++ mask type for the "generic" mask implementation.
_MASK_MAP = {
    "no": "SageAttnMasks::NoMask",
    "causal": "SageAttnMasks::CausalMask",
    "generic": "SageAttnMasks::GenericMask",
}
def get_mask_map(mask_impl: str):
    """Return the mask-name -> C++ mask type table for a mask implementation.

    Args:
        mask_impl: Either ``"generic"`` or ``"simplified"``.

    Returns:
        ``_MASK_MAP`` for ``"generic"``, ``_MASK_SIMPLIFIED_MAP`` for
        ``"simplified"``.

    Raises:
        ValueError: If ``mask_impl`` is any other value.
    """
    if mask_impl == "generic":
        return _MASK_MAP
    if mask_impl == "simplified":
        return _MASK_SIMPLIFIED_MAP
    # Raise explicitly instead of ``assert False``: asserts are removed under
    # ``python -O``, which would make this function silently return None.
    raise ValueError(f"unsupported mask_impl={mask_impl!r}")
def get_mask_impl(mask: str) -> str:
    """Infer the mask implementation from a mask name.

    Names starting with ``"s_"`` belong to the ``"simplified"`` implementation;
    every other name is treated as ``"generic"``.
    """
    if mask.startswith("s_"):
        return "simplified"
    return "generic"
def get_mask_cpp_type(mask: str) -> str:
    """Return the C++ mask type for ``mask``.

    The implementation ("generic" or "simplified") is inferred from the mask
    name itself; an unknown name raises ``KeyError`` from the table lookup.
    """
    impl = get_mask_impl(mask)
    cpp_types = get_mask_map(impl)
    return cpp_types[mask]
# Mask name -> C++ boolean expression on the runtime traits object ``t`` that
# selects kernels built with that mask ("generic" mask implementation).
_MASK_CHECK_MAP = {
    "no": "t.mask_type == mask_enum::no_mask",
    "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
    "generic": "t.mask_type == mask_enum::window_generic",
}
# Same as above for the "simplified" mask implementation: the only
# distinction made is "no mask" versus "any mask".
_MASK_SIMPLIFIED_CHECK_MAP = {
    "s_no": "t.mask_type == mask_enum::no_mask",
    "s_mask": "t.mask_type != mask_enum::no_mask",
}
def get_mask_check_map(mask: str):
    """Return the mask-name -> C++ check-expression table for an implementation.

    Args:
        mask: Despite its name, this is the mask *implementation*
            (``"generic"`` or ``"simplified"``), not a mask name; it is what
            ``get_mask_impl`` returns.

    Returns:
        ``_MASK_CHECK_MAP`` for ``"generic"``, ``_MASK_SIMPLIFIED_CHECK_MAP``
        for ``"simplified"``.

    Raises:
        ValueError: If ``mask`` is any other value.
    """
    if mask == "generic":
        return _MASK_CHECK_MAP
    if mask == "simplified":
        return _MASK_SIMPLIFIED_CHECK_MAP
    # Raise explicitly instead of ``assert False``: asserts are removed under
    # ``python -O``, which would make this function silently return None.
    raise ValueError(f"unsupported mask implementation {mask!r}")
def get_mask_cpp_check_expr(mask: str) -> str:
    """Return the C++ runtime check expression that selects ``mask``.

    The implementation ("generic" or "simplified") is inferred from the mask
    name itself; an unknown name raises ``KeyError`` from the table lookup.
    """
    impl = get_mask_impl(mask)
    check_exprs = get_mask_check_map(impl)
    return check_exprs[mask]
# Quantization-scale granularity name -> C++ enumerator used as a kernel
# template argument.
QSCALE_MAP = {
    "no": "ck_tile::BlockSageAttentionQuantScaleEnum::NO_SCALE",
    "pertensor": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTENSOR",
    "blockscale": "ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE",
    "perwarp": "ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP",
    "perthread": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD",
}
# Quantization-scale granularity name -> C++ enumerator compared against the
# runtime ``t.qscale_type`` in the generated dispatch code.
QSCALE_CHECK_MAP = {
    "no": "quant_scale_enum::no_scale",
    "pertensor": "quant_scale_enum::pertensor",
    "blockscale": "quant_scale_enum::blockscale",
    "perwarp": "quant_scale_enum::perwarp",
    "perthread": "quant_scale_enum::perthread",
}
# Mode name -> C++ bool literal (compared with ``t.is_group_mode``).
MODE_MAP = {"batch": "false", "group": "true"}
# V layout name -> C++ bool literal (compared with ``t.is_v_rowmajor``).
LAYOUT_MAP = {"row": "true", "col": "false"}
# Pipeline tag -> C++ pipeline class template.
PIPELINE_MAP = {
    "qr": "ck_tile::BlockSageAttentionPipelineQRKSVS",
    "qr_async": "ck_tile::BlockSageAttentionPipelineQRKSVSAsync",
}
# Pipeline tag -> C++ pipeline enumerator.
PIPELINE_ENUM_MAP = {
    "qr": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS",
    "qr_async": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS_ASYNC",
}
# Python flag ("t"/"f" string or real bool) -> C++ bool literal.
BOOL_MAP = {
    "t": "true",
    "f": "false",
    True: "true",
    False: "false",
}

View File

@@ -0,0 +1,2 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,992 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import copy
import fnmatch
import itertools
import os
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, ClassVar, Iterable, List, Optional, Tuple
from codegen.arch import ArchTrait, get_factories_for_targets
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
LAYOUT_MAP,
BOOL_MAP,
PIPELINE_MAP,
PIPELINE_ENUM_MAP,
MODE_MAP,
FWD_DTYPE_MAP,
get_mask_map,
get_mask_cpp_type,
get_mask_cpp_check_expr,
QSCALE_CHECK_MAP,
QSCALE_MAP,
)
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
# Precision tag -> bit width used when deriving the head-dim vector length
# for the qr_async pipeline (see SageAttnFwdApiTrait.dcheck / dvcheck).
# NOTE(review): for the mixed tags this looks like the width of the Q/K
# element type -- confirm against the C++ type configs.
DTYPE_BITS = {
    "fp16": 16,
    "bf16": 16,
    "fp8bf16": 8,
    "i8fp8bf16": 8,
    "i4fp8bf16": 4,
}
# bk0max -> divisor used in the head-dim divisibility checks of the sync
# ("qr") pipeline. Most entries map to themselves; 80 -> 96 and 96 -> 128
# map to a larger value.
# NOTE(review): the rationale for the two non-identity entries is not visible
# in this file -- confirm against the pipeline implementation.
K0_MAX_SUBMAX_MAP = {
    32: 32,
    48: 48,
    64: 64,
    80: 96,
    96: 128,
    128: 128,
    192: 192,
    256: 256,
}
# Preamble prepended verbatim to every generated kernel source file (see
# SageAttnFwdKernel.render). It is not passed through str.format. The "\n"
# escape on the copyright line produces one extra blank line in the output.
SAGEATTN_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp"
#include "sageattn_fwd.hpp"
"""
# str.format template for the body of one generated kernel source file; the
# ``{F_*}`` placeholders are filled in by SageAttnFwdKernel.render and
# ``{{`` / ``}}`` are literal C++ braces. The body defines the kernel type
# aliases and one explicit specialization of ``sageattn_fwd_`` that launches
# the kernel. Everything is wrapped in a preprocessor guard built from
# ``F_arch.preprocessor_check``.
SAGEATTN_FWD_KERNEL_BODY_TEMPLATE = """
#include <iostream>
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
using sageattn_dtype = {F_dtype};
using sageattn_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using sageattn_shape = ck_tile::TileSageAttnShape<sageattn_block_tile,
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using sageattn_traits = ck_tile::TileSageAttnTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_qscale},
{F_occupancy},
{F_skip}>;
using sageattn_variant = ck_tile::ComposedAttention<false * ck_tile::LOGITS_SOFT_CAP, true>;
using sageattn_mask_type = {F_mask};
using sageattn_pipeline_problem = ck_tile::BlockSageAttnPipelineProblem<
typename SageAttentionFwdTypeConfig<sageattn_dtype>::QDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::KDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::VDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::SaccDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::SMPLComputeDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::PDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::OaccDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::ODataType,
sageattn_shape,
{F_mode},
sageattn_variant,
sageattn_mask_type,
sageattn_traits>;
using sageattn_pipeline = {F_pipeline}<
sageattn_pipeline_problem>;
using sageattn_epilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename SageAttentionFwdTypeConfig<sageattn_dtype>::OaccDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::ODataType,
{F_spad}, {F_dvpad}>>;
using sageattn_kernel = {F_kernel}<sageattn_pipeline, sageattn_epilogue>;
using trait = sageattn_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, sageattn_mask_type, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
template<>
float sageattn_fwd_<trait, {F_arch.tag}>(const ck_tile::stream_config& s, sageattn_fwd_args a)
{{
using k_ = sageattn_kernel;
if(s.log_level_ > 0)
std::cout << ", {F_kname}" << std::flush;
auto [kargs, grids] = {F_kargs_creator}<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
}}
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""
# File name of the generated dispatch source.
# NOTE(review): the code that writes this file is not visible here -- confirm.
SAGEATTN_FWD_API_FILENAME = "sageattn_fwd_api.cpp"
# Preamble of the generated dispatch source. It contains unescaped C++ braces,
# so it has to be emitted verbatim and must not be passed through str.format.
# It defines two helpers in an anonymous namespace: get_num_cus() and
# get_num_thread_blocks(); the latter is used by the ``get_num_blocks`` lambda
# in SAGEATTN_FWD_API_FUNC_TEMPLATE.
SAGEATTN_FWD_API_HEADER = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include <cstdio>
#include <hip/hip_runtime.h>
#include "sageattn_fwd.hpp"
namespace {
bool get_num_cus(unsigned& num_cus) {
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {
fprintf(stderr, "failed to get device");
return false;
}
hipDeviceProp_t props{};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {
fprintf(stderr, "failed to get device properties");
return false;
}
num_cus = props.multiProcessorCount;
return true;
}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}
} // namespace
"""
# str.format template for the dispatch function (filled in by
# SageAttnFwdApiPool.render with ``F_func_name`` and ``F_dispatch``).
# The generated function returns -1 when no dispatch branch matches or when
# the CU count cannot be queried. ``num_cus``, ``min_cu_util_rate`` and
# ``get_num_blocks`` are locals that tile constraints may reference.
SAGEATTN_FWD_API_FUNC_TEMPLATE = """
namespace {{
float {F_func_name}([[maybe_unused]] sageattn_fwd_traits t, [[maybe_unused]] sageattn_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if(!get_num_cus(num_cus)) {{
return r;
}}
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
{F_dispatch}
return r;
}}
}} // namespace
"""
# Public entry point appended to the dispatch source. Despite the name it
# contains unescaped braces and no placeholders, so it has to be emitted
# verbatim rather than through str.format. It forwards to
# ``sageattn_fwd_impl``.
# NOTE(review): this presumably requires render() to be called with
# func_name="sageattn_fwd_impl" -- confirm at the call site.
SAGEATTN_FWD_API_FOOTER_TEMPLATE = """
// Public API entry point - unified for SageAttention
float sageattn_fwd(sageattn_fwd_traits traits, sageattn_fwd_args args, const ck_tile::stream_config& config) {
return sageattn_fwd_impl(traits, args, config);
}
"""
# The four str.format templates below form the nested if / else-if dispatch
# chain assembled by SageAttnFwdApiPool.render, outermost first:
# architecture -> dtype -> head dims -> individual kernel instance.
# ``{F_if}`` is filled with the result of ``if_(index)`` for each branch.

# Level 1: branch on the runtime device name.
SAGEATTN_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{
{F_dtype_case}
}}
"""
# Level 2: branch on the requested data type string.
SAGEATTN_FWD_API_PER_DTYPE = """{F_if}(t.data_type.compare(\"{F_dtype}\") == 0) {{
{F_hdim_case}
}}
"""
# Level 3: branch on the head dimensions (upper bounds for hdim_q / hdim_v).
SAGEATTN_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
# Level 4: select one kernel instance from the remaining traits and runtime
# argument checks, then call its ``sageattn_fwd_`` specialization.
SAGEATTN_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = sageattn_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
return sageattn_fwd_<trait_, {F_arch.tag}>(s, a);
}}
"""
@dataclass
class CppConstraint:
    """A C++ boolean expression attached to a tile or pipeline.

    An instance without an expression renders as the literal ``true``.
    Two constraints can be combined with ``&``.
    """

    # C++ source of the condition; None means "no constraint".
    bool_expr: str = None

    def __str__(self):
        """Render the constraint as C++ source text."""
        return "true" if self.bool_expr is None else f"{self.bool_expr}"

    def __and__(self, other):
        """Return the conjunction of both operands, each parenthesised."""
        lhs, rhs = str(self), str(other)
        return CppConstraint(f"({lhs}) && ({rhs})")
@dataclass
class SageAttnFwdApiTrait:
    """Dispatch-side description of one generated kernel instance.

    The ``*check`` properties and ``seqtune`` render the C++ boolean
    sub-expressions that ``SageAttnFwdApiPool.render`` combines into the guard
    of the branch selecting this instance.

    Unsupported pipeline/padding combinations raise ``ValueError``. The
    original used ``assert False`` there, which is removed under
    ``python -O``; the properties then returned ``None`` and that value would
    have been rendered into the generated C++.
    """

    arch: ArchTrait
    pipeline_tag: str
    # sync with sageattn_fwd_traits<>, to generate fallback calls
    hdim: str
    dtype: str  # data type
    mode: str  # value from MODE_MAP
    bm0: int  # tile size along q seqlen (block size)
    bn0: int  # tile size along qk seqlen
    bk0: int  # tile size along qk gemm unroll
    bn1: int  # tile size along v head_dim
    bk1: int  # tile size along kv gemm unroll
    bk0max: int
    vlayout: str
    mask: str
    qscale: str  # key of QSCALE_MAP
    spad: str  # "t"/"f" padding flags
    skpad: str
    dpad: str
    dvpad: str
    skip: str
    constraint: CppConstraint

    def _unsupported_pipeline(self, check: str) -> ValueError:
        """Build the error raised when ``pipeline_tag`` is not handled by ``check``."""
        return ValueError(
            f"{check}: unsupported pipeline_tag={self.pipeline_tag!r}"
        )

    @property
    def name(self) -> str:
        """Dash-separated key made of every trait field except arch/pipeline/constraint."""
        return (
            f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
            + f"{self.vlayout}-{self.mask}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
        )

    @property
    def scheck(self) -> str:
        """C++ condition on ``a.seqlen_q`` under which this instance is usable.

        Raises:
            ValueError: If ``pipeline_tag`` is not one of qr_async/qr/qs.
        """
        if self.mode == "group":
            # group mode only generate spad/skpad == true
            return "true/*group mode spad always true*/"
        if self.pipeline_tag == "qr_async":
            return "true"
        if self.pipeline_tag in ("qr", "qs"):
            if self.spad == "t":
                # TODO: order of get_pipelines() matters! (ugly)
                return f"true /*a.seqlen_q % {self.bm0} != 0*/"
            return f"a.seqlen_q % {self.bm0} == 0"
        raise self._unsupported_pipeline("scheck")

    def seqtune(self, max_bm0: int) -> str:
        """C++ condition restricting smaller tiles to short query sequences.

        The largest tile of the group (``max_bm0``) and any tile with
        ``bm0 == 64`` accept every sequence length.
        """
        if self.bm0 == max_bm0 or self.bm0 == 64:
            return "true/*fall back to largest tile*/"
        return f"a.seqlen_q <= {self.bm0}"

    @property
    def skcheck(self) -> str:
        """C++ condition on ``a.seqlen_k`` / ``a.cu_seqlen_k_ptr`` for this instance.

        Raises:
            ValueError: If ``pipeline_tag`` is not one of qr_async/qr/qs.
        """
        if self.mode == "group":
            # group mode only generate spad/skpad == true
            return "true/*group mode skpad always true*/"
        aligned = f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
        if self.pipeline_tag == "qr_async":
            if self.skpad == "t":
                return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
            return aligned
        if self.pipeline_tag in ("qr", "qs"):
            if self.skpad == "t":
                # TODO: order of get_pipelines() matters! (ugly)
                return f"true /*a.seqlen_k % {self.bn0} != 0*/"
            return aligned
        raise self._unsupported_pipeline("skcheck")

    @property
    def dcheck(self) -> str:
        """C++ condition on ``a.hdim_q`` under which this instance is usable.

        Raises:
            ValueError: If ``pipeline_tag`` is not one of qr_async/qr/qs, or
                if a qr_async instance has ``dpad != "t"``.
        """
        if self.pipeline_tag == "qr_async":
            # Number of elements in 32 * 4 = 128 bits.
            # NOTE(review): presumably the width of one vector load -- confirm.
            vec = (32 * 4) // DTYPE_BITS[self.dtype]
            if self.dpad == "t":
                return f"a.hdim_q % {vec} == 0"
            raise ValueError(
                f"dcheck: qr_async pipeline requires dpad='t', got {self.dpad!r}"
            )
        if self.pipeline_tag in ("qr", "qs"):
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dpad == "t":
                # TODO: order of get_pipelines() matters! (ugly)
                return f"true /*a.hdim_q % {bk0submax} != 0*/"
            return f"a.hdim_q % {bk0submax} == 0"
        raise self._unsupported_pipeline("dcheck")

    @property
    def dvcheck(self) -> str:
        """C++ condition on ``a.hdim_v`` under which this instance is usable.

        Raises:
            ValueError: If ``pipeline_tag`` is not one of qr_async/qr/qs, or
                if a qr_async instance has ``dvpad != "t"``.
        """
        if self.pipeline_tag == "qr_async":
            # Same 128-bit element count as in dcheck.
            vec = (32 * 4) // DTYPE_BITS[self.dtype]
            if self.dvpad == "t":
                return f"a.hdim_v % {vec} == 0"
            raise ValueError(
                f"dvcheck: qr_async pipeline requires dvpad='t', got {self.dvpad!r}"
            )
        if self.pipeline_tag in ("qr", "qs"):
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dvpad == "t":
                # TODO: order of get_pipelines() matters! (ugly)
                return f"true /*a.hdim_v % {bk0submax} != 0*/"
            # F_dvpad="f": Causal mask requires hdim_v <= kN1 (num_tile_n1 == 1 for tile reversal)
            # Non-causal requires hdim_v % kN1 == 0 (epilogue writes full tiles)
            if self.mask == "causal":
                return f"(a.hdim_v % {bk0submax} == 0) && (a.hdim_v <= {self.bn1})"
            return f"(a.hdim_v % {bk0submax} == 0) && (a.hdim_v % {self.bn1} == 0)"
        raise self._unsupported_pipeline("dvcheck")
@dataclass
class SageAttnFwdPipeline:
    """One pipeline configuration: pipeline tag plus layout, padding, scale and mask flags."""

    tag: str
    F_vlayout: str  # row/col
    F_spad: str  # "t"/"f"
    F_skpad: str  # "t"/"f"
    F_dpad: str  # "t"/"f"
    F_dvpad: str  # "t"/"f"
    F_qscale: str  # no/pertensor/blockscale/perwarp/perthread
    F_mask: str  # value from MASK_MAP
    F_skip: str  # "t"/"f"
    F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())

    @property
    def name(self) -> str:
        """Underscore-separated name fragment encoding this configuration.

        Layout: ``<tag>_v<first letter of vlayout>_<padding>_<mask>_<skip>_<qscale>``.
        """
        pad_flags = (
            (self.F_spad, "s"),
            (self.F_skpad, "sk"),
            (self.F_dpad, "d"),
            (self.F_dvpad, "dv"),
        )
        padded = "".join(code for flag, code in pad_flags if flag == "t")

        parts = [f"{self.tag}_v{self.F_vlayout[0]}"]
        parts.append(f"p{padded}" if padded else "npad")

        if self.F_mask.startswith("s_"):
            # Simplified masks only distinguish "mask" from "no mask".
            parts.append("mask" if self.F_mask == "s_mask" else "nmask")
        elif self.F_mask != "no":
            parts.append(f"m{self.F_mask[0]}")
        else:
            parts.append("nmask")

        parts.append("skip" if self.F_skip == "t" else "nskip")
        parts.append(f"{self.F_qscale}" if self.F_qscale != "no" else "nqscale")
        return "_".join(parts)
class SageAttnFwdApiPool:
    """Registry of kernel traits, nested as arch -> dtype -> (hdim, bn1) -> list.

    ``render`` turns the registry into the C++ dispatch function.
    """

    def __init__(self):
        self.pool = OrderedDict()

    def register_traits(self, trait: SageAttnFwdApiTrait) -> None:
        """Store a copy of ``trait`` in the bucket for its arch, dtype and head dims."""
        hdim_key = (trait.hdim, trait.bn1)
        by_arch = self.pool.setdefault(trait.arch, OrderedDict())
        by_dtype = by_arch.setdefault(trait.dtype, OrderedDict())
        bucket = by_dtype.setdefault(hdim_key, [])
        check_duplicates_and_paddings(bucket, trait)
        bucket.append(copy.copy(trait))

    def get_num_traits(
        self, filter_fn: Optional[Callable[[SageAttnFwdApiTrait], bool]] = None
    ) -> int:
        """Count the registered traits accepted by ``filter_fn`` (all when it is None)."""
        keep = filter_fn if filter_fn is not None else (lambda trait: True)
        total = 0
        for by_arch in self.pool.values():
            for by_dtype in by_arch.values():
                for bucket in by_dtype.values():
                    total += sum(1 for trait in bucket if keep(trait))
        return total

    def _render_trait(self, arch, dtype, hdim, position, trait, max_bm0) -> str:
        """Render the innermost dispatch branch for one trait."""
        return SAGEATTN_FWD_API_INNER_DISPATCH.format(
            F_if=if_(position),
            F_arch=arch,
            F_mode=MODE_MAP[trait.mode],
            F_vlayout=LAYOUT_MAP[trait.vlayout],
            F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
            F_mask=get_mask_cpp_type(trait.mask),
            F_mask_check=get_mask_cpp_check_expr(trait.mask),
            F_skip=BOOL_MAP[trait.skip],
            F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
            F_qscale=QSCALE_MAP[trait.qscale],
            F_scheck=trait.scheck,
            F_seqtune=trait.seqtune(max_bm0),
            F_skcheck=trait.skcheck,
            F_dcheck=trait.dcheck,
            F_dvcheck=trait.dvcheck,
            F_constraint=trait.constraint,
            F_spad=BOOL_MAP[trait.spad],
            F_skpad=BOOL_MAP[trait.skpad],
            F_dpad=BOOL_MAP[trait.dpad],
            F_dvpad=BOOL_MAP[trait.dvpad],
            F_bm0=trait.bm0,
            F_bn0=trait.bn0,
            F_bk0=trait.bk0,
            F_bn1=trait.bn1,
            F_bk1=trait.bk1,
            F_bk0max=trait.bk0max,
            F_hdim=hdim,
            F_dtype=FWD_DTYPE_MAP[dtype],
        )

    def render(
        self,
        func_name,
        filter_fn: Optional[Callable[[SageAttnFwdApiTrait], bool]] = None,
    ) -> str:
        """Render the C++ dispatch function named ``func_name``.

        Only traits accepted by ``filter_fn`` (all when it is None) are
        emitted; registry levels without any accepted trait are skipped so
        that no empty branch is generated.
        """
        keep = filter_fn if filter_fn is not None else (lambda trait: True)

        def any_kept(node) -> bool:
            """Tell whether any trait below ``node`` passes the filter."""
            if isinstance(node, list):
                return any(keep(elem) for elem in node)
            if isinstance(node, OrderedDict):
                return any(any_kept(child) for child in node.values())
            return False

        def live_items(mapping):
            """Yield the (key, value) pairs whose value holds an accepted trait."""
            return (entry for entry in mapping.items() if any_kept(entry[1]))

        arch_blocks = []
        for arch_idx, (arch, by_arch) in enumerate(live_items(self.pool)):
            dtype_blocks = []
            for dtype_idx, (dtype, by_dtype) in enumerate(live_items(by_arch)):
                hdim_blocks = []
                for hdim_idx, ((hdim, hdim_v), bucket) in enumerate(
                    live_items(by_dtype)
                ):
                    kept = [trait for trait in bucket if keep(trait)]
                    # Largest q-seqlen tile among the emitted traits; it acts
                    # as the fall-back tile in seqtune().
                    max_bm0 = max((trait.bm0 for trait in kept), default=0)
                    inner = "".join(
                        self._render_trait(arch, dtype, hdim, pos, trait, max_bm0)
                        for pos, trait in enumerate(kept)
                    )
                    hdim_blocks.append(
                        SAGEATTN_FWD_API_PER_HDIM_CASE.format(
                            F_if=if_(hdim_idx),
                            F_hdim=hdim,
                            F_hdim_v=hdim_v,
                            F_inner_dispatch=indent(inner),
                        )
                    )
                dtype_blocks.append(
                    SAGEATTN_FWD_API_PER_DTYPE.format(
                        F_if=if_(dtype_idx),
                        F_dtype=dtype,
                        F_hdim_case=indent("".join(hdim_blocks)),
                    )
                )
            arch_blocks.append(
                SAGEATTN_FWD_API_PER_ARCH.format(
                    F_if=if_(arch_idx),
                    F_arch=arch,
                    F_dtype_case=indent("".join(dtype_blocks)),
                )
            )
        return SAGEATTN_FWD_API_FUNC_TEMPLATE.format(
            F_func_name=func_name, F_dispatch=indent("".join(arch_blocks))
        )
@dataclass
class SageAttnFwdTileSize:
    """Block tile, warp layout and warp tile sizes of one kernel variant."""

    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along qk gemm unroll
    F_bn1: int  # tile size along v head_dim
    F_bk1: int  # tile size along kv gemm unroll
    F_bk0max: int  # total length of K0, used for pipeline that need load Q at once (or repeatedly load Q as a whole tile)
    F_rm0: int  # number of warps for gemm0 along q seqlen
    F_rn0: int  # number of warps for gemm0 along k seqlen
    F_rk0: int  # number of warps for gemm0 along head dim q (not used)
    F_rm1: int  # number of warps for gemm1 along q seqlen
    F_rn1: int  # number of warps for gemm1 along head dim v
    F_rk1: int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0: int  # gemm0 warp size along m
    F_wn0: int  # gemm0 warp size along n
    F_wk0: int  # gemm0 warp size along k
    F_wm1: int  # gemm1 warp size along m
    F_wn1: int  # gemm1 warp size along n
    F_wk1: int  # gemm1 warp size along k
    F_occupancy: int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
    F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())

    @property
    def name(self) -> str:
        """Name fragment ``b..._r..._r..._w..._w...[_o<occupancy>]`` encoding every size."""

        def dims(*values) -> str:
            return "x".join(f"{value}" for value in values)

        block = dims(
            self.F_bm0, self.F_bn0, self.F_bk0, self.F_bn1, self.F_bk1, self.F_bk0max
        )
        warps0 = dims(self.F_rm0, self.F_rn0, self.F_rk0)
        warps1 = dims(self.F_rm1, self.F_rn1, self.F_rk1)
        warp_tile0 = dims(self.F_wm0, self.F_wn0, self.F_wk0)
        warp_tile1 = dims(self.F_wm1, self.F_wn1, self.F_wk1)
        # The occupancy suffix is omitted when the pipeline decides (-1).
        occupancy = "" if self.F_occupancy == -1 else f"_o{self.F_occupancy}"
        return (
            f"b{block}_r{warps0}_r{warps1}_w{warp_tile0}_w{warp_tile1}{occupancy}"
        )
@dataclass
class SageAttnFwdKernel:
    """One kernel instance to generate: architecture, problem and tile/pipeline choice."""

    F_arch: ArchTrait
    F_hdim: int  # hdim
    F_dtype: str  # data type
    F_mode: str  # value from MODE_MAP
    F_tile: SageAttnFwdTileSize
    F_pipeline: SageAttnFwdPipeline
    _KERNEL_HEADER: ClassVar[str] = SAGEATTN_FWD_KERNEL_HEADER
    _KERNEL_BODY_TEMPLATE: ClassVar[str] = SAGEATTN_FWD_KERNEL_BODY_TEMPLATE

    @classmethod
    def _get_cpp_kernel_class_name(cls, pipeline_tag):
        """Return the C++ kernel class template; the same for every pipeline tag."""
        return "ck_tile::SageAttnFwdKernel"

    @classmethod
    def _get_cpp_kargs_creator_func_name(cls, pipeline_tag):
        """Return the C++ kargs/grid creator function; the same for every pipeline tag."""
        return "sageattn_fwd_create_kargs_and_grids"

    def render(self) -> str:
        """Return the complete C++ source text of this kernel instance."""
        tile = self.F_tile
        pipe = self.F_pipeline
        fields = {
            "F_kname": self.name,
            "F_arch": self.F_arch,
            "F_hdim": self.F_hdim,
            "F_dtype": FWD_DTYPE_MAP[self.F_dtype],
            "F_bm0": tile.F_bm0,
            "F_bn0": tile.F_bn0,
            "F_bk0": tile.F_bk0,
            "F_bn1": tile.F_bn1,
            "F_bk1": tile.F_bk1,
            "F_bk0max": tile.F_bk0max,
            "F_rm0": tile.F_rm0,
            "F_rn0": tile.F_rn0,
            "F_rk0": tile.F_rk0,
            "F_rm1": tile.F_rm1,
            "F_rn1": tile.F_rn1,
            "F_rk1": tile.F_rk1,
            "F_wm0": tile.F_wm0,
            "F_wn0": tile.F_wn0,
            "F_wk0": tile.F_wk0,
            "F_wm1": tile.F_wm1,
            "F_wn1": tile.F_wn1,
            "F_wk1": tile.F_wk1,
            "F_vlayout": LAYOUT_MAP[pipe.F_vlayout],
            "F_spad": BOOL_MAP[pipe.F_spad],
            "F_skpad": BOOL_MAP[pipe.F_skpad],
            "F_dpad": BOOL_MAP[pipe.F_dpad],
            "F_dvpad": BOOL_MAP[pipe.F_dvpad],
            "F_qscale": QSCALE_MAP[pipe.F_qscale],
            "F_skip": BOOL_MAP[pipe.F_skip],
            "F_occupancy": tile.F_occupancy,
            "F_pipeline_enum": PIPELINE_ENUM_MAP[pipe.tag],
            "F_mask": get_mask_cpp_type(pipe.F_mask),
            "F_mode": MODE_MAP[self.F_mode],
            "F_pipeline": PIPELINE_MAP[pipe.tag],
            "F_kernel": self._get_cpp_kernel_class_name(pipe.tag),
            "F_kargs_creator": self._get_cpp_kargs_creator_func_name(pipe.tag),
        }
        kernel_cls = type(self)
        body = kernel_cls._KERNEL_BODY_TEMPLATE.format(**fields)
        return kernel_cls._KERNEL_HEADER + body

    @property
    def name(self) -> str:
        """Kernel name built from hdim, dtype, mode, tile name and pipeline name."""
        # TODO: we don't encode idx here
        prefix = f"sageattn_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
        return prefix + self.F_tile.name + "_" + self.F_pipeline.name

    @property
    def filename(self) -> str:
        """Source file name: kernel name plus the architecture's filename suffix."""
        return f"{self.name}{self.F_arch.filename_suffix}.cpp"

    def api_trait(self) -> SageAttnFwdApiTrait:
        """Build the dispatch-side trait describing this kernel instance.

        The trait's constraint is the conjunction of the tile's and the
        pipeline's constraints.
        """
        tile = self.F_tile
        pipe = self.F_pipeline
        return SageAttnFwdApiTrait(
            arch=self.F_arch,
            pipeline_tag=pipe.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=tile.F_bm0,
            bn0=tile.F_bn0,
            bk0=tile.F_bk0,
            bn1=tile.F_bn1,
            bk1=tile.F_bk1,
            bk0max=tile.F_bk0max,
            vlayout=pipe.F_vlayout,
            mask=pipe.F_mask,
            qscale=pipe.F_qscale,
            spad=pipe.F_spad,
            skpad=pipe.F_skpad,
            dpad=pipe.F_dpad,
            dvpad=pipe.F_dvpad,
            skip=pipe.F_skip,
            constraint=tile.F_constraint & pipe.F_constraint,
        )
@dataclass
class ProblemContext:
    """Problem-side inputs handed to compatibility rules."""

    dtype: str  # precision tag, e.g. "bf16" or "fp8bf16"
    mode: str  # key of MODE_MAP ("batch" or "group")
    hdim: int  # head dimension (first element of the tile-dict key)
    hdim_v: int  # V head dimension (second element of the tile-dict key)
@dataclass
class KernelContext:
    """Kernel-side inputs handed to compatibility rules."""

    tile: SageAttnFwdTileSize  # candidate tile configuration
    pipeline: SageAttnFwdPipeline  # candidate pipeline configuration
    mask_impl: str  # mask implementation name, as accepted by get_mask_map()
# A compatibility rule is a predicate over a (problem, kernel) pair; it
# returns True when the kernel configuration may be generated for the problem.
CompatibilityRule = Callable[[ProblemContext, KernelContext], bool]
def is_compatible(
    problem_ctx: ProblemContext,
    kernel_ctx: KernelContext,
    rules: Iterable[CompatibilityRule],
) -> bool:
    """Return True when every rule accepts the (problem, kernel) pair.

    Evaluation stops at the first rejecting rule; an empty ``rules`` iterable
    yields True.
    """
    for rule in rules:
        if not rule(problem_ctx, kernel_ctx):
            return False
    return True
def create_kernel(
    arch: ArchTrait, problem_ctx: ProblemContext, kernel_ctx: KernelContext
) -> SageAttnFwdKernel:
    """Assemble a ``SageAttnFwdKernel`` from an architecture and the two contexts.

    ``problem_ctx.hdim_v`` and ``kernel_ctx.mask_impl`` are not carried over.
    """
    problem_fields = {
        "F_dtype": problem_ctx.dtype,
        "F_mode": problem_ctx.mode,
        "F_hdim": problem_ctx.hdim,
    }
    kernel_fields = {
        "F_tile": kernel_ctx.tile,
        "F_pipeline": kernel_ctx.pipeline,
    }
    return SageAttnFwdKernel(F_arch=arch, **problem_fields, **kernel_fields)
class CompatibilityRuleFactory:
    """Provides the compatibility rules shared by every architecture."""

    @staticmethod
    def get_rules() -> List[CompatibilityRule]:
        """Return the list of architecture-independent rules."""

        def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
            # in group mode, spad/skpad must be true, since we can't predict
            # if seqlen of current batch need pad or not
            if problem_ctx.mode != "group":
                return True
            pipeline = kernel_ctx.pipeline
            return pipeline.F_spad == "t" and pipeline.F_skpad == "t"

        return [check_mode]
class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
    """Compatibility rules for gfx9 targets; currently just the shared rules."""

    # Pipeline tags declared for this architecture family.
    _AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"})

    @classmethod
    def get_rules(cls) -> List[CompatibilityRule]:
        """Return the base rule list unchanged."""
        return CompatibilityRuleFactory.get_rules()
class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9):
    """Compatibility rules for gfx950; inherits the gfx9 rules unchanged."""

    pass
class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
    """Supplies dtypes, tile sizes and pipelines for gfx9 targets other than gfx950."""

    arch = ArchTrait(
        "gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)"
    )
    # Note: fp16 is not supported by SageAttention (only bf16 + fp8/int quantization)
    _DT_BF16 = ("bf16",)
    _DT_FP8BF16 = ("fp8bf16",)
    _DT_I8FP8BF16 = ("i8fp8bf16",)
    _DT_I4FP8BF16 = ("i4fp8bf16",)

    @classmethod
    def supported_dtypes(cls) -> Tuple[str]:
        """Return every precision tag this factory generates kernels for."""
        return (
            *cls._DT_BF16,
            *cls._DT_FP8BF16,
            *cls._DT_I8FP8BF16,
            *cls._DT_I4FP8BF16,
        )

    # TODO: design a more practical way to do it
    # this is current supported tile size per hdim
    @classmethod
    def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
        """Return ``{(hdim, hdim_v): [tiles]}`` for ``dtype``.

        Raises:
            ValueError: If ``dtype`` is not a supported precision tag.
        """
        if dtype in cls._DT_BF16:
            bf16_tile = SageAttnFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)  # fmt: skip
            return {(128, 128): [bf16_tile]}
        quantized = cls._DT_FP8BF16 + cls._DT_I8FP8BF16 + cls._DT_I4FP8BF16
        if dtype in quantized:
            # gfx9 (non-gfx950): only F_bn0=64; F_bn0=128 variant is gfx950-only (see Gfx950).
            quant_tile = SageAttnFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)  # fmt: skip
            return {(128, 128): [quant_tile]}
        raise ValueError(f"unsupported dtype={dtype}")

    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    # support this in future
    @classmethod
    def get_pipelines(
        cls, dtype, hdim, hdim_v, receipt, mask_impl
    ) -> List[SageAttnFwdPipeline]:
        """Return the candidate pipelines for one dtype and head-dim pair.

        ``receipt`` is accepted but not used here. An unrecognised ``dtype``
        yields an empty list.
        """
        # TODO: the order of the list matters! Entries later in the list are also checked later
        # TODO: currently for qr pipeline, let "t" padding to appear later!!
        # TODO: how to design this more generic?
        skip = "f"  # skip: only false
        layouts = ["row", "col"]  # support both row and col major layouts

        def make(tag, vlayout, pads, qscale, mask):
            # ``pads`` spells the four flags in order: spad, skpad, dpad, dvpad.
            spad, skpad, dpad, dvpad = pads
            return SageAttnFwdPipeline(tag, vlayout, spad, skpad, dpad, dvpad, qscale, mask, skip)  # fmt: skip

        is_int4 = dtype in cls._DT_I4FP8BF16
        quantized = cls._DT_FP8BF16 + cls._DT_I8FP8BF16 + cls._DT_I4FP8BF16
        pipelines = []
        if dtype in cls._DT_BF16:
            if hdim == 256 and hdim_v == 256:
                # the last two are used for hdim vectorize load
                variants = [("qr", "ffff"), ("qr", "ttff"), ("qr", "tttt")]
            else:
                variants = [("qr_async", "tftt"), ("qr_async", "tttt")]
            for mask, vlayout in itertools.product(
                get_mask_map(mask_impl).keys(), layouts
            ):
                for tag, pads in variants:
                    pipelines.append(make(tag, vlayout, pads, "no", mask))
        elif dtype in quantized:
            # no need lse kernels
            if is_int4 or hdim == 64:
                # int4 only uses sync pipeline (qr), pad_d="f" because packed types
                # require alignment >= PackedSize which conflicts with kPadHeadDimQ=true
                # forcing alignment to 1. Safe since hdim always matches tile size.
                variants = [("qr", "tfff"), ("qr", "ttff")]
            else:
                variants = [("qr_async", "tftt"), ("qr_async", "tttt")]
            qscales = ["no", "pertensor", "blockscale", "perwarp", "perthread"]
            for mask, qscale, vlayout in itertools.product(
                get_mask_map(mask_impl).keys(), qscales, layouts
            ):
                for tag, pads in variants:
                    pipelines.append(make(tag, vlayout, pads, qscale, mask))

        # Packed types (int4) cannot use head-dim padding: the tile_window infrastructure
        # forces alignment=1 when padding is enabled, but packed types need alignment >= PackedSize.
        if is_int4:
            for p in pipelines:
                assert p.F_dpad == "f", (
                    f"int4 dtype '{dtype}' requires pad_d=false, got '{p.F_dpad}'"
                )
                assert p.F_dvpad == "f", (
                    f"int4 dtype '{dtype}' requires pad_dv=false, got '{p.F_dvpad}'"
                )
        return pipelines
class KernelComponentFactoryGfx950(
    KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx950
):
    """gfx950 factory: overrides the architecture and the quantized-dtype tiles."""

    arch = ArchTrait("gfx950")

    @classmethod
    def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
        """Return ``{(hdim, hdim_v): [tiles]}`` for ``dtype`` on gfx950.

        Quantized dtypes get a dedicated tile; everything else is delegated
        to the gfx9 implementation.
        """
        quantized = (*cls._DT_FP8BF16, *cls._DT_I8FP8BF16, *cls._DT_I4FP8BF16)
        if dtype not in quantized:
            return super().get_hdim_tile_size_dict(dtype)
        # gfx950 fp8-family: F_bn0=128 tile only (gfx9 uses F_bn0=64 only).
        quant_tile = SageAttnFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)  # fmt: skip
        return {(128, 128): [quant_tile]}
class CustomFactory(KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9):
    """gfx9 factory variant that adds a smaller bf16 tile for hdim 128."""

    @classmethod
    def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
        """Return the gfx9 tile dict, with an extra leading tile for bf16.

        The extra tile (bm0=64) carries a C++ constraint so that it is only
        selected when ``get_num_blocks(128) < num_cus * min_cu_util_rate``.
        """
        result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
        if dtype not in cls._DT_BF16 or (128, 128) not in result:
            return result
        small_tile = SageAttnFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))  # fmt: skip
        # Insert at the front: tiles are kept ordered by increasing bm0.
        result[(128, 128)].insert(0, small_tile)
        return result
def get_factory(target: str):
    """Pick the kernel component factory class for a device target.

    Setting ``CK_TILE_SAGEATTN_FWD_CUSTOM_FACTORY=1`` in the environment
    overrides the target-based choice and always yields ``CustomFactory``.

    Raises:
        Exception: if ``target`` does not start with ``gfx9``.
    """
    custom_requested = (
        os.environ.get("CK_TILE_SAGEATTN_FWD_CUSTOM_FACTORY", "0") == "1"
    )
    if custom_requested:
        return CustomFactory

    if not target.startswith("gfx9"):
        raise Exception(f"Unsupported device target {target}")

    # "gfx950" shares the "gfx9" prefix, so the more specific architecture
    # has to be recognised before falling back to the generic gfx9 factory.
    if target.startswith("gfx950"):
        return KernelComponentFactoryGfx950
    return KernelComponentFactoryGfx9
@dataclass(frozen=True)
class Product:
    """A named compatibility rule that can be called like the rule itself."""

    # Human-readable label of the product.
    name: str
    # Callable taking (problem_ctx, kernel_ctx) and returning the verdict.
    rule: CompatibilityRule

    def __call__(self, problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
        """Forward both contexts to ``rule`` and return its result."""
        verdict = self.rule(problem_ctx, kernel_ctx)
        return verdict
def get_product(receipt: int) -> Product:
    """Build the ``"All tiles"`` product rule.

    ``receipt`` is accepted but not consulted by the rule built here.
    """

    def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
        # bf16 (no quantization) should not have qscale: reject a bf16 problem
        # paired with a pipeline whose F_qscale is anything but "no".
        is_bf16 = problem_ctx.dtype == "bf16"
        return not (is_bf16 and kernel_ctx.pipeline.F_qscale != "no")

    return Product(name="All tiles", rule=fit)
def get_fwd_blobs(
    targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[SageAttnFwdApiPool, List[SageAttnFwdKernel]]:
    """Enumerate the forward kernel instances to generate for ``targets``.

    For every factory resolved from ``targets`` and every dtype it supports,
    each (tile, pipeline) pair is checked against the factory rules plus the
    receipt product rule. Compatible kernels are created, optionally filtered
    by name, registered in the API pool and collected.

    Args:
        targets: device target names, resolved to factories via ``get_factory``.
        kernel_filter: ``fnmatch`` pattern matched against kernel names;
            ``None`` or ``""`` disables filtering.
        receipt: forwarded to ``factory.get_pipelines`` and ``get_product``.
        optdim_list: head dims to keep; the literal ``[-1]`` keeps all of them.
        mask_impl: forwarded to ``factory.get_pipelines`` and stored in every
            ``KernelContext``.

    Returns:
        ``(api_pool, kernels)``: the populated API pool and the kernels in
        generation order.
    """
    gen = []
    api_pool = SageAttnFwdApiPool()
    factories = get_factories_for_targets(targets, get_factory)
    # The product rule depends only on `receipt`; build it once instead of
    # once per (tile, pipeline) candidate.
    product = get_product(receipt)

    for factory, dtype in ((f, t) for f in factories for t in f.supported_dtypes()):
        d = factory.get_hdim_tile_size_dict(dtype)
        if d is None:
            # The tile table is declared Optional[dict]; nothing to generate.
            continue
        for ((hdim, hdim_v), tiles), mode in itertools.product(
            d.items(), MODE_MAP.keys()
        ):
            if optdim_list != [-1] and hdim not in optdim_list:
                continue
            for tile, next_tile in zip(tiles, tiles[1:]):
                assert next_tile.F_bm0 >= tile.F_bm0, (
                    "Tiles must be ordered by increasing bm0"
                )
            for tile, pipeline in itertools.product(
                tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
            ):
                problem_ctx = ProblemContext(
                    dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v
                )
                kernel_ctx = KernelContext(
                    tile=tile, pipeline=pipeline, mask_impl=mask_impl
                )
                rules = factory.get_rules()
                if not is_compatible(problem_ctx, kernel_ctx, [*rules, product]):
                    continue
                k = create_kernel(factory.arch, problem_ctx, kernel_ctx)
                # Truthiness covers both "" and None; passing None on to
                # fnmatch would raise TypeError.
                if kernel_filter and not fnmatch.fnmatch(k.name, kernel_filter):
                    continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)
    return (api_pool, gen)
def write_single_fwd_kernel(kernel: SageAttnFwdKernel, autogen_dir: Path) -> None:
    """Render ``kernel`` and hand it to ``update_file`` under ``autogen_dir``.

    The destination is ``autogen_dir / kernel.filename``.
    """
    destination = autogen_dir / kernel.filename
    rendered_source = kernel.render()
    update_file(destination, rendered_source)
def write_fwd_api(
    api_pool: SageAttnFwdApiPool,
    autogen_dir: Path,
) -> None:
    """Assemble the forward API source and hand it to ``update_file``.

    The content is the API header, the pool rendered for
    ``"sageattn_fwd_impl"`` and the API footer, concatenated in that order.
    It is stored as ``autogen_dir / SAGEATTN_FWD_API_FILENAME``.
    """
    sections = (
        SAGEATTN_FWD_API_HEADER,
        api_pool.render("sageattn_fwd_impl"),
        SAGEATTN_FWD_API_FOOTER_TEMPLATE,
    )
    api_source = "".join(sections)
    api_path = autogen_dir / SAGEATTN_FWD_API_FILENAME
    update_file(api_path, api_source)
def write_blobs(
    targets: List[str],
    output_dir: Path,
    kernel_filter: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    """Generate all forward kernel sources and the API file in ``output_dir``.

    Kernels are enumerated with ``get_fwd_blobs``; each one is written with
    ``write_single_fwd_kernel``, then the API file with ``write_fwd_api``.
    """
    pool, generated_kernels = get_fwd_blobs(
        targets, kernel_filter, receipt, optdim_list, mask_impl
    )
    for generated in generated_kernels:
        write_single_fwd_kernel(generated, output_dir)
    write_fwd_api(pool, output_dir)
def list_blobs(
    targets: List[str],
    file_path: Path,
    kernel_filter: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    """Append the paths of all files that would be generated to ``file_path``.

    One POSIX-style path per line: every kernel file first, then the API
    file. All paths are rooted at ``file_path.parent / GEN_DIR``.
    """
    # The listing file is opened (append mode) before kernels are enumerated,
    # matching the order in which side effects happen.
    with file_path.open("a") as listing:
        _, kernels = get_fwd_blobs(
            targets, kernel_filter, receipt, optdim_list, mask_impl
        )
        gen_root = file_path.parent / GEN_DIR
        for kernel in kernels:
            kernel_path = gen_root / kernel.filename
            listing.write(kernel_path.as_posix() + "\n")
        api_path = gen_root / SAGEATTN_FWD_API_FILENAME
        listing.write(api_path.as_posix() + "\n")

View File

@@ -0,0 +1,70 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import dataclasses
import os.path as path
import textwrap
def update_file(file_path, content):
    """Write ``content`` to ``file_path`` unless the file already holds it.

    Leaving an up-to-date file untouched avoids triggering rebuilds. A missing
    file is treated as having empty content, so it is not created when
    ``content`` is the empty string.
    """
    if path.exists(file_path):
        with open(file_path, "r") as current:
            up_to_date = current.read() == content
    else:
        up_to_date = "" == content

    if up_to_date:
        return

    with open(file_path, "w") as target:
        target.write(content)
def indent(code: str, indent: str = " ") -> str:
    """Return ``code`` with ``indent`` prepended to its lines via ``textwrap.indent``."""
    prefix = indent
    return textwrap.indent(code, prefix=prefix)
def if_(i: int) -> str:
    """Return ``"if"`` for index 0 and ``"else if"`` for any other index."""
    if i == 0:
        return "if"
    return "else if"
def check_duplicates_and_paddings(traits, trait):
    """Validate ``trait`` against the already registered ``traits``.

    Check
    * if the traits list does not contain a trait with the same parameters;
    * if paddings are consistent: the previous kernel can be incorrectly called
      before the new one, for example, f, _t_, f, t cannot be before f, _f_, f, t.

    Only previous traits whose non-padding fields all equal those of ``trait``
    are compared. A field counts as a padding field when its name contains
    ``"pad"``.

    Args:
        traits: previously registered trait dataclass instances, in order.
        trait: the dataclass instance about to be registered.

    Raises:
        Exception: if an identical trait is already present, or if an earlier
            trait supersedes ``trait`` so that it would never be used.
        AssertionError: if a padding field holds a value that is neither
            ``str`` nor ``int``.
    """
    field_names = [f.name for f in dataclasses.fields(trait)]
    pad_fields = [name for name in field_names if "pad" in name]
    non_pad_fields = [name for name in field_names if "pad" not in name]

    for prev_trait in traits:
        if any(
            getattr(trait, name) != getattr(prev_trait, name)
            for name in non_pad_fields
        ):
            continue
        if all(
            getattr(trait, name) == getattr(prev_trait, name) for name in pad_fields
        ):
            raise Exception(f"Duplicate found {trait}")

        # Check if the previous kernel can be incorrectly used before the
        # current one, for example, f, _t_, f, t cannot be before f, _f_, f, t.
        # Each padding value is mapped to a rank: "f" and 0 become a large
        # sentinel, "t" becomes 1, any other int keeps its value.
        prev_ranks_lower = False
        prev_ranks_higher = False
        for name in pad_fields:
            prev_pad = getattr(prev_trait, name)
            pad = getattr(trait, name)
            if isinstance(prev_pad, str):
                prev_pad = 1000000 if prev_pad == "f" else 1
                pad = 1000000 if pad == "f" else 1
            elif isinstance(prev_pad, int):
                prev_pad = 1000000 if prev_pad == 0 else prev_pad
                pad = 1000000 if pad == 0 else pad
            else:
                # Raised explicitly: a bare `assert False` is removed under
                # `python -O`, letting unsupported values reach the comparison.
                raise AssertionError(
                    f"Unsupported padding value {prev_pad!r} in field '{name}'"
                )
            if prev_pad < pad:
                prev_ranks_lower = True
            elif prev_pad > pad:
                prev_ranks_higher = True

        # The earlier trait ranks lower in at least one padding field and
        # higher in none, so it supersedes the new one.
        if prev_ranks_lower and not prev_ranks_higher:
            raise Exception(
                f"Kernel will never be used because paddings are not ordered correctly:\n{prev_trait} supersedes\n{trait}"
            )