mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[rocm-libraries] ROCm/rocm-libraries#6574 (commit b3db057)
[CK_TILE] Add SageAttention v2 forward kernel with multi-granularity quantization (#6574) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Add a CK_TILE forward kernel implementing [SageAttention v2](https://arxiv.org/abs/2411.10958) — an attention algorithm that applies multi-granularity quantization to Q/K/V before computing attention, trading minimal accuracy loss for higher throughput on low-precision hardware. ### Quantization design | Tensor | Supported data types | Scale granularity options | |--------|---------------------|--------------------------| | Q | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (32 tokens), per-thread (4 tokens) | | K | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (64 tokens), per-thread (16 tokens) | | V | fp8 | per-channel (always) | | O | bf16 | — | Three precision combinations are supported: `fp8/bf16` (QKV fp8, O bf16), `i8/fp8/bf16` (QK int8, V fp8, O bf16), and `i4/fp8/bf16` (QK int4, V fp8, O bf16). ### Architecture support - **gfx9** (CDNA2/3, e.g. gfx90a, gfx942) — full tile set - **gfx950** (CDNA4) — restricted tile set (N-per-block capped at 64 for fp8-family dtypes) ### Implementation - Two pipeline variants: `QRKSVS` (synchronous) and `QRKSVS_ASYNC` (async copy) - Masking support: no mask, causal (top-left / bottom-right), and generic windowed - Batch and group (variable-length) modes - Head dimension: d=128, d_v=128 - Python codegen under `example/ck_tile/49_sageattention/codegen/` generates kernel instances per target/dtype/tile combination - Smoke tests included via `tile_example_sageattn_fwd` ### Test commands \`\`\`bash # fp8 QKV ./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=fp8bf16 -qscale=3 -init=3 # int8 QK, fp8 V ./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=i8fp8bf16 -qscale=3 -init=3 \`\`\` \`-qscale\` values: 1=per-tensor, 2=per-block, 3=per-warp, 4=per-thread
This commit is contained in:
committed by
assistant-librarian[bot]
parent
e8d64ad5c6
commit
de0a61e5c2
2
example/ck_tile/49_sageattention/codegen/__init__.py
Normal file
2
example/ck_tile/49_sageattention/codegen/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
42
example/ck_tile/49_sageattention/codegen/arch.py
Normal file
42
example/ck_tile/49_sageattention/codegen/arch.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Callable
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ArchTrait:
|
||||
name: str
|
||||
preprocessor_check: str = field(default=None)
|
||||
device_name_check: str = field(default=None)
|
||||
tag: str = field(default=None)
|
||||
filename_suffix: str = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.preprocessor_check is None:
|
||||
object.__setattr__(self, "preprocessor_check", f"defined(__{self.name}__)")
|
||||
if self.device_name_check is None:
|
||||
object.__setattr__(
|
||||
self,
|
||||
"device_name_check",
|
||||
f'device_name.compare(0, {len(self.name)}, "{self.name}") == 0',
|
||||
)
|
||||
if self.tag is None:
|
||||
object.__setattr__(self, "tag", f"ck_tile::{self.name}_t")
|
||||
if self.filename_suffix is None:
|
||||
object.__setattr__(self, "filename_suffix", f"_{self.name}")
|
||||
|
||||
|
||||
def get_factories_for_targets(
|
||||
targets: List[str], get_factory: Callable[[str], Any]
|
||||
) -> List[Any]:
|
||||
factories = dict()
|
||||
for target in targets:
|
||||
factory = get_factory(target)
|
||||
factories[factory.arch.name] = factory
|
||||
# Place more specific architectures first
|
||||
factories = sorted(
|
||||
list(factories.values()), key=lambda f: len(f.arch.name), reverse=True
|
||||
)
|
||||
return factories
|
||||
4
example/ck_tile/49_sageattention/codegen/cmake_config.py
Normal file
4
example/ck_tile/49_sageattention/codegen/cmake_config.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
GEN_DIR = "" # in Cmake, have to generate files in same folder
|
||||
103
example/ck_tile/49_sageattention/codegen/cpp_symbol_map.py
Normal file
103
example/ck_tile/49_sageattention/codegen/cpp_symbol_map.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
FWD_DTYPE_MAP = {
|
||||
"fp16": "SageAttentionFwdFp16",
|
||||
"bf16": "SageAttentionFwdBf16",
|
||||
"fp8bf16": "SageAttentionFwdFp8Bf16",
|
||||
"i8fp8bf16": "SageAttentionFwdI8Fp8Bf16",
|
||||
"i4fp8bf16": "SageAttentionFwdI4Fp8Bf16",
|
||||
}
|
||||
|
||||
_MASK_SIMPLIFIED_MAP = {
|
||||
"s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
|
||||
"s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
|
||||
}
|
||||
|
||||
_MASK_MAP = {
|
||||
"no": "SageAttnMasks::NoMask",
|
||||
"causal": "SageAttnMasks::CausalMask",
|
||||
"generic": "SageAttnMasks::GenericMask",
|
||||
}
|
||||
|
||||
|
||||
def get_mask_map(mask_impl: str):
|
||||
if mask_impl == "generic":
|
||||
return _MASK_MAP
|
||||
elif mask_impl == "simplified":
|
||||
return _MASK_SIMPLIFIED_MAP
|
||||
else:
|
||||
assert False
|
||||
return None
|
||||
|
||||
|
||||
def get_mask_impl(mask: str) -> str:
|
||||
return "simplified" if mask.startswith("s_") else "generic"
|
||||
|
||||
|
||||
def get_mask_cpp_type(mask: str) -> str:
|
||||
return get_mask_map(get_mask_impl(mask))[mask]
|
||||
|
||||
|
||||
_MASK_CHECK_MAP = {
|
||||
"no": "t.mask_type == mask_enum::no_mask",
|
||||
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
|
||||
"generic": "t.mask_type == mask_enum::window_generic",
|
||||
}
|
||||
|
||||
_MASK_SIMPLIFIED_CHECK_MAP = {
|
||||
"s_no": "t.mask_type == mask_enum::no_mask",
|
||||
"s_mask": "t.mask_type != mask_enum::no_mask",
|
||||
}
|
||||
|
||||
|
||||
def get_mask_check_map(mask: str):
|
||||
if mask == "generic":
|
||||
return _MASK_CHECK_MAP
|
||||
elif mask == "simplified":
|
||||
return _MASK_SIMPLIFIED_CHECK_MAP
|
||||
else:
|
||||
assert False
|
||||
return None
|
||||
|
||||
|
||||
def get_mask_cpp_check_expr(mask: str) -> str:
|
||||
return get_mask_check_map(get_mask_impl(mask))[mask]
|
||||
|
||||
|
||||
QSCALE_MAP = {
|
||||
"no": "ck_tile::BlockSageAttentionQuantScaleEnum::NO_SCALE",
|
||||
"pertensor": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTENSOR",
|
||||
"blockscale": "ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE",
|
||||
"perwarp": "ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP",
|
||||
"perthread": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD",
|
||||
}
|
||||
|
||||
QSCALE_CHECK_MAP = {
|
||||
"no": "quant_scale_enum::no_scale",
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
"blockscale": "quant_scale_enum::blockscale",
|
||||
"perwarp": "quant_scale_enum::perwarp",
|
||||
"perthread": "quant_scale_enum::perthread",
|
||||
}
|
||||
|
||||
MODE_MAP = {"batch": "false", "group": "true"}
|
||||
|
||||
LAYOUT_MAP = {"row": "true", "col": "false"}
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"qr": "ck_tile::BlockSageAttentionPipelineQRKSVS",
|
||||
"qr_async": "ck_tile::BlockSageAttentionPipelineQRKSVSAsync",
|
||||
}
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS",
|
||||
"qr_async": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS_ASYNC",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
"t": "true",
|
||||
"f": "false",
|
||||
True: "true",
|
||||
False: "false",
|
||||
}
|
||||
2
example/ck_tile/49_sageattention/codegen/ops/__init__.py
Normal file
2
example/ck_tile/49_sageattention/codegen/ops/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
992
example/ck_tile/49_sageattention/codegen/ops/sageattn_fwd.py
Normal file
992
example/ck_tile/49_sageattention/codegen/ops/sageattn_fwd.py
Normal file
@@ -0,0 +1,992 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
import copy
|
||||
import fnmatch
|
||||
import itertools
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable, ClassVar, Iterable, List, Optional, Tuple
|
||||
|
||||
from codegen.arch import ArchTrait, get_factories_for_targets
|
||||
from codegen.cmake_config import GEN_DIR
|
||||
from codegen.cpp_symbol_map import (
|
||||
LAYOUT_MAP,
|
||||
BOOL_MAP,
|
||||
PIPELINE_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
MODE_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
get_mask_map,
|
||||
get_mask_cpp_type,
|
||||
get_mask_cpp_check_expr,
|
||||
QSCALE_CHECK_MAP,
|
||||
QSCALE_MAP,
|
||||
)
|
||||
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8bf16": 8,
|
||||
"i8fp8bf16": 8,
|
||||
"i4fp8bf16": 4,
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32: 32,
|
||||
48: 48,
|
||||
64: 64,
|
||||
80: 96,
|
||||
96: 128,
|
||||
128: 128,
|
||||
192: 192,
|
||||
256: 256,
|
||||
}
|
||||
|
||||
SAGEATTN_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp"
|
||||
#include "sageattn_fwd.hpp"
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_KERNEL_BODY_TEMPLATE = """
|
||||
#include <iostream>
|
||||
|
||||
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
|
||||
|
||||
using sageattn_dtype = {F_dtype};
|
||||
|
||||
using sageattn_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using sageattn_shape = ck_tile::TileSageAttnShape<sageattn_block_tile,
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using sageattn_traits = ck_tile::TileSageAttnTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_qscale},
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
|
||||
using sageattn_variant = ck_tile::ComposedAttention<false * ck_tile::LOGITS_SOFT_CAP, true>;
|
||||
|
||||
using sageattn_mask_type = {F_mask};
|
||||
|
||||
using sageattn_pipeline_problem = ck_tile::BlockSageAttnPipelineProblem<
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::QDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::KDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::VDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::SaccDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::SMPLComputeDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::PDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::OaccDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::ODataType,
|
||||
sageattn_shape,
|
||||
{F_mode},
|
||||
sageattn_variant,
|
||||
sageattn_mask_type,
|
||||
sageattn_traits>;
|
||||
|
||||
using sageattn_pipeline = {F_pipeline}<
|
||||
sageattn_pipeline_problem>;
|
||||
|
||||
using sageattn_epilogue =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename SageAttentionFwdTypeConfig<sageattn_dtype>::OaccDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using sageattn_kernel = {F_kernel}<sageattn_pipeline, sageattn_epilogue>;
|
||||
|
||||
|
||||
using trait = sageattn_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, sageattn_mask_type, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
|
||||
template<>
|
||||
float sageattn_fwd_<trait, {F_arch.tag}>(const ck_tile::stream_config& s, sageattn_fwd_args a)
|
||||
{{
|
||||
using k_ = sageattn_kernel;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", {F_kname}" << std::flush;
|
||||
auto [kargs, grids] = {F_kargs_creator}<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_FILENAME = "sageattn_fwd_api.cpp"
|
||||
SAGEATTN_FWD_API_HEADER = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include <cstdio>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "sageattn_fwd.hpp"
|
||||
|
||||
namespace {
|
||||
bool get_num_cus(unsigned& num_cus) {
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}
|
||||
|
||||
hipDeviceProp_t props{};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}
|
||||
|
||||
num_cus = props.multiProcessorCount;
|
||||
return true;
|
||||
}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}
|
||||
} // namespace
|
||||
"""
|
||||
SAGEATTN_FWD_API_FUNC_TEMPLATE = """
|
||||
namespace {{
|
||||
float {F_func_name}([[maybe_unused]] sageattn_fwd_traits t, [[maybe_unused]] sageattn_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if(!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
}} // namespace
|
||||
"""
|
||||
SAGEATTN_FWD_API_FOOTER_TEMPLATE = """
|
||||
// Public API entry point - unified for SageAttention
|
||||
float sageattn_fwd(sageattn_fwd_traits traits, sageattn_fwd_args args, const ck_tile::stream_config& config) {
|
||||
return sageattn_fwd_impl(traits, args, config);
|
||||
}
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{
|
||||
{F_dtype_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_PER_DTYPE = """{F_if}(t.data_type.compare(\"{F_dtype}\") == 0) {{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = sageattn_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
return sageattn_fwd_<trait_, {F_arch.tag}>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return "true"
|
||||
else:
|
||||
return f"{self.bool_expr}"
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f"({str(self)}) && ({str(other)})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SageAttnFwdApiTrait:
|
||||
arch: ArchTrait
|
||||
pipeline_tag: str
|
||||
# sync with sageattn_fwd_traits<>, to generate fallback calls
|
||||
hdim: str
|
||||
dtype: str # data type
|
||||
mode: str # value from MODE_MAP
|
||||
bm0: int # tile size along q seqlen (block size)
|
||||
bn0: int # tile size along qk seqlen
|
||||
bk0: int # tile size along qk gemm unroll
|
||||
bn1: int # tile size along v head_dim
|
||||
bk1: int # tile size along kv gemm unroll
|
||||
bk0max: int
|
||||
vlayout: str
|
||||
mask: str
|
||||
qscale: str #
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
skip: str
|
||||
constraint: CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.mask}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
|
||||
)
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == "qr_async":
|
||||
return "true"
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
if self.spad == "t":
|
||||
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.seqlen_q % {self.bm0} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
def seqtune(self, max_bm0: int) -> str:
|
||||
if self.bm0 == max_bm0 or self.bm0 == 64:
|
||||
return "true/*fall back to largest tile*/"
|
||||
else:
|
||||
return f"a.seqlen_q <= {self.bm0}"
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == "qr_async":
|
||||
if self.skpad == "t":
|
||||
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
|
||||
else:
|
||||
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
if self.skpad == "t":
|
||||
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.pipeline_tag == "qr_async":
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == "t":
|
||||
return f"a.hdim_q % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == "t":
|
||||
return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.hdim_q % {bk0submax} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.pipeline_tag == "qr_async":
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == "t":
|
||||
return f"a.hdim_v % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == "t":
|
||||
return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
# F_dvpad="f": Causal mask requires hdim_v <= kN1 (num_tile_n1 == 1 for tile reversal)
|
||||
# Non-causal requires hdim_v % kN1 == 0 (epilogue writes full tiles)
|
||||
if self.mask == "causal":
|
||||
return f"(a.hdim_v % {bk0submax} == 0) && (a.hdim_v <= {self.bn1})"
|
||||
else:
|
||||
return (
|
||||
f"(a.hdim_v % {bk0submax} == 0) && (a.hdim_v % {self.bn1} == 0)"
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
@dataclass
|
||||
class SageAttnFwdPipeline:
|
||||
tag: str
|
||||
|
||||
F_vlayout: str # row/col
|
||||
F_spad: str # true/false
|
||||
F_skpad: str #
|
||||
F_dpad: str #
|
||||
F_dvpad: str #
|
||||
F_qscale: str # no/pertensor/blockscale/perwarp/perthread
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_skip: str # true/false
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ""
|
||||
if self.F_spad == "t":
|
||||
n += "s"
|
||||
if self.F_skpad == "t":
|
||||
n += "sk"
|
||||
if self.F_dpad == "t":
|
||||
n += "d"
|
||||
if self.F_dvpad == "t":
|
||||
n += "dv"
|
||||
if n != "":
|
||||
n = "p" + n
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f"{self.tag}_v{self.F_vlayout[0]}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
else:
|
||||
n += "_npad"
|
||||
|
||||
if self.F_mask[0:2] == "s_":
|
||||
if self.F_mask == "s_mask":
|
||||
n += "_mask"
|
||||
else:
|
||||
n += "_nmask"
|
||||
else:
|
||||
if self.F_mask != "no":
|
||||
n += f"_m{self.F_mask[0]}"
|
||||
else:
|
||||
n += "_nmask"
|
||||
|
||||
if self.F_skip == "t":
|
||||
n += "_skip"
|
||||
else:
|
||||
n += "_nskip"
|
||||
|
||||
if self.F_qscale != "no":
|
||||
n += f"_{self.F_qscale}"
|
||||
else:
|
||||
n += "_nqscale"
|
||||
|
||||
return n
|
||||
|
||||
|
||||
class SageAttnFwdApiPool:
|
||||
def __init__(self):
|
||||
self.pool = OrderedDict()
|
||||
|
||||
def register_traits(self, trait: SageAttnFwdApiTrait) -> None:
|
||||
hdim = trait.hdim, trait.bn1
|
||||
ts = (
|
||||
self.pool.setdefault(trait.arch, OrderedDict())
|
||||
.setdefault(trait.dtype, OrderedDict())
|
||||
.setdefault(hdim, [])
|
||||
)
|
||||
check_duplicates_and_paddings(ts, trait)
|
||||
ts.append(copy.copy(trait))
|
||||
|
||||
def get_num_traits(
|
||||
self, filter_fn: Optional[Callable[[SageAttnFwdApiTrait], bool]] = None
|
||||
) -> int:
|
||||
if filter_fn is None:
|
||||
|
||||
def accept_all(trait: SageAttnFwdApiTrait) -> bool:
|
||||
return True
|
||||
|
||||
filter_fn = accept_all
|
||||
|
||||
return sum(
|
||||
sum(1 for trait in pool_by_hdim if filter_fn(trait))
|
||||
for pool_by_arch in self.pool.values()
|
||||
for pool_by_dtype in pool_by_arch.values()
|
||||
for pool_by_hdim in pool_by_dtype.values()
|
||||
)
|
||||
|
||||
def render(
|
||||
self,
|
||||
func_name,
|
||||
filter_fn: Optional[Callable[[SageAttnFwdApiTrait], bool]] = None,
|
||||
) -> str:
|
||||
if filter_fn is None:
|
||||
|
||||
def accept_all(trait: SageAttnFwdApiTrait) -> bool:
|
||||
return True
|
||||
|
||||
filter_fn = accept_all
|
||||
|
||||
def has_traits(node) -> bool:
|
||||
"""Recursively traverse nested OrderedDicts and lists to determine if any SageAttnFwdApiTrait satisfies filter_fn()."""
|
||||
if isinstance(node, list):
|
||||
return any(filter_fn(elem) for elem in node)
|
||||
elif isinstance(node, OrderedDict):
|
||||
return any(has_traits(val) for val in node.values())
|
||||
return False
|
||||
|
||||
per_arch = str()
|
||||
for i_arch, (arch, pool_by_arch) in enumerate(
|
||||
item for item in self.pool.items() if has_traits(item[1])
|
||||
):
|
||||
per_dtypes = str()
|
||||
for i_dtype, (dtype, pool_by_dtype) in enumerate(
|
||||
item for item in pool_by_arch.items() if has_traits(item[1])
|
||||
):
|
||||
per_hdim_case = str()
|
||||
for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate(
|
||||
item for item in pool_by_dtype.items() if has_traits(item[1])
|
||||
):
|
||||
max_bm0 = max(
|
||||
(t.bm0 for t in pool_by_hdim if filter_fn(t)), default=0
|
||||
)
|
||||
inners = str()
|
||||
for i_trait, trait in enumerate(
|
||||
[trait for trait in pool_by_hdim if filter_fn(trait)]
|
||||
):
|
||||
inners += SAGEATTN_FWD_API_INNER_DISPATCH.format(
|
||||
F_if=if_(i_trait),
|
||||
F_arch=arch,
|
||||
F_mode=MODE_MAP[trait.mode],
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
|
||||
F_mask=get_mask_cpp_type(trait.mask),
|
||||
F_mask_check=get_mask_cpp_check_expr(trait.mask),
|
||||
F_skip=BOOL_MAP[trait.skip],
|
||||
F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
|
||||
F_qscale=QSCALE_MAP[trait.qscale],
|
||||
F_scheck=trait.scheck,
|
||||
F_seqtune=trait.seqtune(max_bm0),
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0,
|
||||
F_bn0=trait.bn0,
|
||||
F_bk0=trait.bk0,
|
||||
F_bn1=trait.bn1,
|
||||
F_bk1=trait.bk1,
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
per_hdim_case += SAGEATTN_FWD_API_PER_HDIM_CASE.format(
|
||||
F_if=if_(i_hdim),
|
||||
F_hdim=hdim,
|
||||
F_hdim_v=hdim_v,
|
||||
F_inner_dispatch=indent(inners),
|
||||
)
|
||||
per_dtypes += SAGEATTN_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
|
||||
)
|
||||
per_arch += SAGEATTN_FWD_API_PER_ARCH.format(
|
||||
F_if=if_(i_arch),
|
||||
F_arch=arch,
|
||||
F_dtype_case=indent(per_dtypes),
|
||||
)
|
||||
return SAGEATTN_FWD_API_FUNC_TEMPLATE.format(
|
||||
F_func_name=func_name, F_dispatch=indent(per_arch)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SageAttnFwdTileSize:
|
||||
F_bm0: int # tile size along q seqlen (block size)
|
||||
F_bn0: int # tile size along k seqlen
|
||||
F_bk0: int # tile size along qk gemm unroll
|
||||
F_bn1: int # tile size along v head_dim
|
||||
F_bk1: int # tile size along kv gemm unroll
|
||||
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0: int # number of warps for gemm0 along q seqlen
|
||||
F_rn0: int # number of warps for gemm0 along k seqlen
|
||||
F_rk0: int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1: int # number of warps for gemm1 along q seqlen
|
||||
F_rn1: int # number of warps for gemm1 along head dim v
|
||||
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0: int # gemm0 warp size along m
|
||||
F_wn0: int # gemm0 warp size along n
|
||||
F_wk0: int # gemm0 warp size along k
|
||||
F_wm1: int # gemm1 warp size along m
|
||||
F_wn1: int # gemm1 warp size along n
|
||||
F_wk1: int # gemm1 warp size along k
|
||||
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
|
||||
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
|
||||
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
|
||||
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SageAttnFwdKernel:
|
||||
F_arch: ArchTrait
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_mode: str # value from MODE_MAP
|
||||
F_tile: SageAttnFwdTileSize
|
||||
F_pipeline: SageAttnFwdPipeline
|
||||
|
||||
_KERNEL_HEADER: ClassVar[str] = SAGEATTN_FWD_KERNEL_HEADER
|
||||
_KERNEL_BODY_TEMPLATE: ClassVar[str] = SAGEATTN_FWD_KERNEL_BODY_TEMPLATE
|
||||
|
||||
@classmethod
|
||||
def _get_cpp_kernel_class_name(cls, pipeline_tag):
|
||||
return "ck_tile::SageAttnFwdKernel"
|
||||
|
||||
@classmethod
|
||||
def _get_cpp_kargs_creator_func_name(cls, pipeline_tag):
|
||||
return "sageattn_fwd_create_kargs_and_grids"
|
||||
|
||||
def render(self) -> str:
|
||||
return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format(
|
||||
F_kname=self.name,
|
||||
F_arch=self.F_arch,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0=self.F_tile.F_bm0,
|
||||
F_bn0=self.F_tile.F_bn0,
|
||||
F_bk0=self.F_tile.F_bk0,
|
||||
F_bn1=self.F_tile.F_bn1,
|
||||
F_bk1=self.F_tile.F_bk1,
|
||||
F_bk0max=self.F_tile.F_bk0max,
|
||||
F_rm0=self.F_tile.F_rm0,
|
||||
F_rn0=self.F_tile.F_rn0,
|
||||
F_rk0=self.F_tile.F_rk0,
|
||||
F_rm1=self.F_tile.F_rm1,
|
||||
F_rn1=self.F_tile.F_rn1,
|
||||
F_rk1=self.F_tile.F_rk1,
|
||||
F_wm0=self.F_tile.F_wm0,
|
||||
F_wn0=self.F_tile.F_wn0,
|
||||
F_wk0=self.F_tile.F_wk0,
|
||||
F_wm1=self.F_tile.F_wm1,
|
||||
F_wn1=self.F_tile.F_wn1,
|
||||
F_wk1=self.F_tile.F_wk1,
|
||||
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
|
||||
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_cpp_type(self.F_pipeline.F_mask),
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag),
|
||||
F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag),
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"sageattn_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return f"{self.name}{self.F_arch.filename_suffix}.cpp"
|
||||
|
||||
def api_trait(self) -> SageAttnFwdApiTrait:
|
||||
return SageAttnFwdApiTrait(
|
||||
arch=self.F_arch,
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
qscale=self.F_pipeline.F_qscale,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProblemContext:
|
||||
dtype: str
|
||||
mode: str
|
||||
hdim: int
|
||||
hdim_v: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class KernelContext:
|
||||
tile: SageAttnFwdTileSize
|
||||
pipeline: SageAttnFwdPipeline
|
||||
mask_impl: str
|
||||
|
||||
|
||||
CompatibilityRule = Callable[[ProblemContext, KernelContext], bool]
|
||||
|
||||
|
||||
def is_compatible(
|
||||
problem_ctx: ProblemContext,
|
||||
kernel_ctx: KernelContext,
|
||||
rules: Iterable[CompatibilityRule],
|
||||
) -> bool:
|
||||
return all(rule(problem_ctx, kernel_ctx) for rule in rules)
|
||||
|
||||
|
||||
def create_kernel(
|
||||
arch: ArchTrait, problem_ctx: ProblemContext, kernel_ctx: KernelContext
|
||||
) -> SageAttnFwdKernel:
|
||||
return SageAttnFwdKernel(
|
||||
F_arch=arch,
|
||||
F_dtype=problem_ctx.dtype,
|
||||
F_mode=problem_ctx.mode,
|
||||
F_hdim=problem_ctx.hdim,
|
||||
F_tile=kernel_ctx.tile,
|
||||
F_pipeline=kernel_ctx.pipeline,
|
||||
)
|
||||
|
||||
|
||||
class CompatibilityRuleFactory:
|
||||
@staticmethod
|
||||
def get_rules() -> List[CompatibilityRule]:
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
|
||||
if problem_ctx.mode == "group":
|
||||
if (
|
||||
kernel_ctx.pipeline.F_spad != "t"
|
||||
or kernel_ctx.pipeline.F_skpad != "t"
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
return [check_mode]
|
||||
|
||||
|
||||
class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
|
||||
_AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"})
|
||||
|
||||
@classmethod
|
||||
def get_rules(cls) -> List[CompatibilityRule]:
|
||||
rules = CompatibilityRuleFactory.get_rules()
|
||||
|
||||
return rules
|
||||
|
||||
|
||||
class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9):
|
||||
pass
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
arch = ArchTrait(
|
||||
"gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)"
|
||||
)
|
||||
|
||||
# Note: fp16 is not supported by SageAttention (only bf16 + fp8/int quantization)
|
||||
_DT_BF16 = ("bf16",)
|
||||
_DT_FP8BF16 = ("fp8bf16",)
|
||||
_DT_I8FP8BF16 = ("i8fp8bf16",)
|
||||
_DT_I4FP8BF16 = ("i4fp8bf16",)
|
||||
|
||||
@classmethod
|
||||
def supported_dtypes(cls) -> Tuple[str]:
|
||||
return cls._DT_BF16 + cls._DT_FP8BF16 + cls._DT_I8FP8BF16 + cls._DT_I4FP8BF16
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
@classmethod
|
||||
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
|
||||
if dtype in cls._DT_BF16:
|
||||
return {
|
||||
(128, 128) : [SageAttnFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
} # fmt: skip
|
||||
elif (
|
||||
dtype in cls._DT_FP8BF16
|
||||
or dtype in cls._DT_I8FP8BF16
|
||||
or dtype in cls._DT_I4FP8BF16
|
||||
):
|
||||
# gfx9 (non-gfx950): only F_bn0=64; F_bn0=128 variant is gfx950-only (see Gfx950).
|
||||
return {
|
||||
(128, 128): [
|
||||
SageAttnFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), # fmt: skip
|
||||
],
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"unsupported dtype={dtype}")
|
||||
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
@classmethod
|
||||
def get_pipelines(
|
||||
cls, dtype, hdim, hdim_v, receipt, mask_impl
|
||||
) -> List[SageAttnFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let "t" padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
pipelines = []
|
||||
if dtype in cls._DT_BF16:
|
||||
qscale = "no"
|
||||
skip = "f" # skip: only false
|
||||
for mask, vlayout in itertools.product(
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["row", "col"],
|
||||
):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "f", "f", "f", "f", qscale, mask, skip)) # fmt: skip
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
|
||||
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip
|
||||
else:
|
||||
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "f", "t", "t", qscale, mask, skip)) # fmt: skip
|
||||
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip
|
||||
elif (
|
||||
dtype in cls._DT_FP8BF16
|
||||
or dtype in cls._DT_I8FP8BF16
|
||||
or dtype in cls._DT_I4FP8BF16
|
||||
):
|
||||
# no need lse kernels
|
||||
skip = "f" # skip: only false
|
||||
for mask, qscale, vlayout in itertools.product(
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no", "pertensor", "blockscale", "perwarp", "perthread"],
|
||||
["row", "col"], # Support both row and col major layouts
|
||||
):
|
||||
if dtype in cls._DT_I4FP8BF16:
|
||||
# int4 only uses sync pipeline (qr), pad_d="f" because packed types
|
||||
# require alignment >= PackedSize which conflicts with kPadHeadDimQ=true
|
||||
# forcing alignment to 1. Safe since hdim always matches tile size.
|
||||
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "f", "f", "f", qscale, mask, skip)) # fmt: skip
|
||||
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
|
||||
elif hdim == 64:
|
||||
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "f", "f", "f", qscale, mask, skip)) # fmt: skip
|
||||
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
|
||||
else:
|
||||
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "f", "t", "t", qscale, mask, skip)) # fmt: skip
|
||||
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip
|
||||
|
||||
# Packed types (int4) cannot use head-dim padding: the tile_window infrastructure
|
||||
# forces alignment=1 when padding is enabled, but packed types need alignment >= PackedSize.
|
||||
if dtype in cls._DT_I4FP8BF16:
|
||||
for p in pipelines:
|
||||
assert p.F_dpad == "f", (
|
||||
f"int4 dtype '{dtype}' requires pad_d=false, got '{p.F_dpad}'"
|
||||
)
|
||||
assert p.F_dvpad == "f", (
|
||||
f"int4 dtype '{dtype}' requires pad_dv=false, got '{p.F_dvpad}'"
|
||||
)
|
||||
|
||||
return pipelines
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx950(
|
||||
KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx950
|
||||
):
|
||||
arch = ArchTrait("gfx950")
|
||||
|
||||
@classmethod
|
||||
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
|
||||
if (
|
||||
dtype in cls._DT_FP8BF16
|
||||
or dtype in cls._DT_I8FP8BF16
|
||||
or dtype in cls._DT_I4FP8BF16
|
||||
):
|
||||
# gfx950 fp8-family: F_bn0=128 tile only (gfx9 uses F_bn0=64 only).
|
||||
return {
|
||||
(128, 128): [
|
||||
SageAttnFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), # fmt: skip
|
||||
],
|
||||
}
|
||||
return super().get_hdim_tile_size_dict(dtype)
|
||||
|
||||
|
||||
class CustomFactory(KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9):
|
||||
@classmethod
|
||||
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
|
||||
result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
|
||||
if dtype in cls._DT_BF16:
|
||||
if (128, 128) in result.keys():
|
||||
result[(128, 128)].insert(0, SageAttnFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
|
||||
return result
|
||||
|
||||
|
||||
def get_factory(target: str):
|
||||
if os.environ.get("CK_TILE_SAGEATTN_FWD_CUSTOM_FACTORY", "0") == "1":
|
||||
return CustomFactory
|
||||
|
||||
# Place more specific architectures first
|
||||
|
||||
if target.startswith("gfx950"):
|
||||
return KernelComponentFactoryGfx950
|
||||
if target.startswith("gfx9"):
|
||||
return KernelComponentFactoryGfx9
|
||||
|
||||
raise Exception(f"Unsupported device target {target}")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Product:
|
||||
name: str
|
||||
rule: CompatibilityRule
|
||||
|
||||
def __call__(self, problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
|
||||
return self.rule(problem_ctx, kernel_ctx)
|
||||
|
||||
|
||||
def get_product(receipt: int) -> Product:
|
||||
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
|
||||
# bf16 (no quantization) should not have qscale
|
||||
if problem_ctx.dtype == "bf16":
|
||||
if kernel_ctx.pipeline.F_qscale != "no":
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
return Product(name="All tiles", rule=fit)
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
) -> Tuple[SageAttnFwdApiPool, List[SageAttnFwdKernel]]:
|
||||
gen = list()
|
||||
api_pool = SageAttnFwdApiPool()
|
||||
|
||||
factories = get_factories_for_targets(targets, get_factory)
|
||||
|
||||
for factory, dtype in ((f, t) for f in factories for t in f.supported_dtypes()):
|
||||
d = factory.get_hdim_tile_size_dict(dtype)
|
||||
# for hdim_str, mode, mask, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for ((hdim, hdim_v), tiles), mode in itertools.product(
|
||||
d.items(), MODE_MAP.keys()
|
||||
):
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
for tile, next_tile in zip(tiles, tiles[1:]):
|
||||
assert next_tile.F_bm0 >= tile.F_bm0, (
|
||||
"Tiles must be ordered by increasing bm0"
|
||||
)
|
||||
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
|
||||
):
|
||||
problem_ctx = ProblemContext(
|
||||
dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v
|
||||
)
|
||||
kernel_ctx = KernelContext(
|
||||
tile=tile, pipeline=pipeline, mask_impl=mask_impl
|
||||
)
|
||||
rules = factory.get_rules()
|
||||
product = get_product(receipt)
|
||||
|
||||
if not is_compatible(problem_ctx, kernel_ctx, [*rules, product]):
|
||||
continue
|
||||
|
||||
k = create_kernel(factory.arch, problem_ctx, kernel_ctx)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: SageAttnFwdKernel, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / kernel.filename, kernel.render())
|
||||
|
||||
|
||||
def write_fwd_api(
|
||||
api_pool: SageAttnFwdApiPool,
|
||||
autogen_dir: Path,
|
||||
) -> None:
|
||||
content = "".join(
|
||||
[
|
||||
SAGEATTN_FWD_API_HEADER,
|
||||
api_pool.render("sageattn_fwd_impl"),
|
||||
SAGEATTN_FWD_API_FOOTER_TEMPLATE,
|
||||
]
|
||||
)
|
||||
update_file(autogen_dir / SAGEATTN_FWD_API_FILENAME, content)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
targets: List[str],
|
||||
output_dir: Path,
|
||||
kernel_filter: str,
|
||||
receipt,
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(
|
||||
targets, kernel_filter, receipt, optdim_list, mask_impl
|
||||
)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
|
||||
def list_blobs(
|
||||
targets: List[str],
|
||||
file_path: Path,
|
||||
kernel_filter: str,
|
||||
receipt,
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(
|
||||
targets, kernel_filter, receipt, optdim_list, mask_impl
|
||||
)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write(
|
||||
(file_path.parent / GEN_DIR / SAGEATTN_FWD_API_FILENAME).as_posix() + "\n"
|
||||
)
|
||||
70
example/ck_tile/49_sageattention/codegen/utils.py
Normal file
70
example/ck_tile/49_sageattention/codegen/utils.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
import dataclasses
|
||||
import os.path as path
|
||||
import textwrap
|
||||
|
||||
|
||||
def update_file(file_path, content):
|
||||
"""Update the file at file_path with the given content if it differs from the existing content.
|
||||
|
||||
It avoids unnecessary touching of the file which triggers rebuilds
|
||||
"""
|
||||
|
||||
existing_content = ""
|
||||
if path.exists(file_path):
|
||||
with open(file_path, "r") as file:
|
||||
existing_content = file.read()
|
||||
if existing_content == content:
|
||||
return
|
||||
with open(file_path, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
|
||||
def indent(code: str, indent: str = " ") -> str:
|
||||
return textwrap.indent(code, indent)
|
||||
|
||||
|
||||
def if_(i: int) -> str:
|
||||
return "if" if i == 0 else "else if"
|
||||
|
||||
|
||||
def check_duplicates_and_paddings(traits, trait):
|
||||
"""Check
|
||||
* if the traits list does not contain a trait with the same parameters;
|
||||
* if paddings are consitent: the previous kernel can be incorrectly called before the new one,
|
||||
for example, f, _t_, f, t cannot be before f, _f_, f, t.
|
||||
"""
|
||||
|
||||
fields = [f.name for f in dataclasses.fields(trait)]
|
||||
pad_fields = [f for f in fields if "pad" in f]
|
||||
non_pad_fields = [f for f in fields if "pad" not in f]
|
||||
for prev_trait in traits:
|
||||
if any(getattr(trait, f) != getattr(prev_trait, f) for f in non_pad_fields):
|
||||
continue
|
||||
if all(getattr(trait, f) == getattr(prev_trait, f) for f in pad_fields):
|
||||
raise Exception(f"Duplicate found {trait}")
|
||||
# Check if the previous kernel can be incorrectly used before the current one
|
||||
# for example, f, _t_, f, t cannot be before f, _f_, f, t
|
||||
is_prev_more_restrictive = False
|
||||
is_curr_more_restrictive = False
|
||||
for f in pad_fields:
|
||||
prev_pad = getattr(prev_trait, f)
|
||||
pad = getattr(trait, f)
|
||||
if isinstance(prev_pad, str):
|
||||
prev_pad = 1000000 if prev_pad == "f" else 1
|
||||
pad = 1000000 if pad == "f" else 1
|
||||
elif isinstance(prev_pad, int):
|
||||
prev_pad = 1000000 if prev_pad == 0 else prev_pad
|
||||
pad = 1000000 if pad == 0 else pad
|
||||
else:
|
||||
assert False
|
||||
if prev_pad < pad:
|
||||
is_prev_more_restrictive = True
|
||||
elif prev_pad > pad:
|
||||
is_curr_more_restrictive = True
|
||||
if is_prev_more_restrictive and not is_curr_more_restrictive:
|
||||
raise Exception(
|
||||
f"Kernel will never be used because paddings are not ordered correctly:\n{prev_trait} supersedes\n{trait}"
|
||||
)
|
||||
Reference in New Issue
Block a user