[rocm-libraries] ROCm/rocm-libraries#6574 (commit b3db057)

[CK_TILE] Add SageAttention v2 forward kernel with
 multi-granularity quantization (#6574)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

Add a CK_TILE forward kernel implementing [SageAttention
v2](https://arxiv.org/abs/2411.10958) — an attention algorithm that
applies multi-granularity quantization to Q/K/V before computing
attention, trading minimal accuracy loss for higher throughput on
low-precision hardware.

### Quantization design

| Tensor | Supported data types | Scale granularity options |
|--------|---------------------|--------------------------|
| Q | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (32 tokens), per-thread (4 tokens) |
| K | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (64 tokens), per-thread (16 tokens) |
| V | fp8 | per-channel (always) |
| O | bf16 | — |

Three precision combinations are supported: `fp8/bf16` (QKV fp8, O
bf16), `i8/fp8/bf16` (QK int8, V fp8, O bf16), and `i4/fp8/bf16` (QK
int4, V fp8, O bf16).

### Architecture support

- **gfx9** (CDNA2/3, e.g. gfx90a, gfx942) — full tile set
- **gfx950** (CDNA4) — restricted tile set (N-per-block capped at 64 for
fp8-family dtypes)

### Implementation

- Two pipeline variants: `QRKSVS` (synchronous) and `QRKSVS_ASYNC`
(async copy)
- Masking support: no mask, causal (top-left / bottom-right), and
generic windowed
- Batch and group (variable-length) modes
- Head dimension: d=128, d_v=128
- Python codegen under `example/ck_tile/49_sageattention/codegen/`
generates kernel instances per target/dtype/tile combination
- Smoke tests included via `tile_example_sageattn_fwd`

### Test commands

```bash
# fp8 QKV
./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=fp8bf16 -qscale=3 -init=3

# int8 QK, fp8 V
./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=i8fp8bf16 -qscale=3 -init=3
```

`-qscale` values: 1=per-tensor, 2=per-block, 3=per-warp, 4=per-thread
This commit is contained in:
ltqin
2026-04-30 18:33:36 +00:00
committed by assistant-librarian[bot]
parent e8d64ad5c6
commit de0a61e5c2
30 changed files with 7809 additions and 0 deletions

View File

@@ -0,0 +1,95 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# Currently only gfx9 arch is supported
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
if(NOT INST_TARGETS)
    message(WARNING "Skipping SageAttention compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
    return()
endif()

# ====================================================================
# SageAttention codegen - only FWD API, minimal instances
# ====================================================================
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
    ${CMAKE_CURRENT_LIST_DIR}/generate.py
    ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
)
# APPEND rather than set_directory_properties(): the latter overwrites any
# CMAKE_CONFIGURE_DEPENDS entries already registered for this directory.
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")

# Quote the separator so it is unambiguously a literal comma.
list(JOIN INST_TARGETS "," SAGEATTN_TARGETS_ARG)

# Only generate FWD API, only supported head dimension (128)
# Note: Only d=128, d_v=128 has kernel tile definitions in sageattn_fwd.py
set(SAGEATTN_FWD_CODE_GEN_COMMON_ARGS
    ${CMAKE_CURRENT_LIST_DIR}/generate.py
    --targets ${SAGEATTN_TARGETS_ARG}
    --api fwd
    --optdim 128
)

# Ask the generator (at configure time) which kernel instance files it will emit.
execute_process(
    COMMAND ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS}
            --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt
    RESULT_VARIABLE ret
)
# NOT EQUAL 0 also catches the "could not launch" case where ret is an
# error string instead of an exit code.
if(NOT ret EQUAL 0)
    message(FATAL_ERROR "SageAttention FAILED to generate kernel list via Python (result: ${ret}).")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt SAGEATTN_FWD_GEN_BLOBS)

# Generate the kernel instance files at build time.
add_custom_command(
    OUTPUT ${SAGEATTN_FWD_GEN_BLOBS}
    COMMAND ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS}
            --output_dir ${CMAKE_CURRENT_BINARY_DIR}
    DEPENDS ${CODE_GEN_SCRIPTS}
    COMMENT "Generate SageAttention FWD kernels"
    VERBATIM
)

# Compile options shared by the instance library and the example executable
# (previously duplicated in two near-identical lists).
set(SAGEATTN_FWD_COMMON_COMPILE_OPTIONS
    -Wno-undefined-func-template
    -Wno-float-equal
    -fgpu-flush-denormals-to-zero
)
if(CK_USE_OCP_FP8)
    list(APPEND SAGEATTN_FWD_COMMON_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()

# Build the kernel instances library
add_library(tile_sageattn_fwd_instances OBJECT EXCLUDE_FROM_ALL ${SAGEATTN_FWD_GEN_BLOBS})
target_include_directories(tile_sageattn_fwd_instances PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(tile_sageattn_fwd_instances PRIVATE ${SAGEATTN_FWD_COMMON_COMPILE_OPTIONS})
set_target_properties(tile_sageattn_fwd_instances PROPERTIES
    HIP_ARCHITECTURES "${INST_TARGETS}"
    POSITION_INDEPENDENT_CODE ON
)

# ====================================================================
# SageAttention FWD Example
# ====================================================================
set(EXAMPLE_SAGEATTN_FWD "tile_example_sageattn_fwd")
message(DEBUG "adding example ${EXAMPLE_SAGEATTN_FWD}")
add_executable(${EXAMPLE_SAGEATTN_FWD} EXCLUDE_FROM_ALL example_sageattn_fwd.cpp)
target_include_directories(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# Link with our own minimal instances library (INDEPENDENT from FMHA!)
target_link_libraries(${EXAMPLE_SAGEATTN_FWD} PRIVATE tile_sageattn_fwd_instances)
target_compile_options(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${SAGEATTN_FWD_COMMON_COMPILE_OPTIONS})
set_property(TARGET ${EXAMPLE_SAGEATTN_FWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

View File

@@ -0,0 +1,2 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,42 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional
@dataclass(frozen=True)
class ArchTrait:
    """Per-architecture strings used by the kernel code generator.

    Every field other than ``name`` is derived from ``name`` in
    ``__post_init__`` when left as ``None``, but may be supplied
    explicitly to override the derived value.
    """

    # Architecture family name, e.g. "gfx9" or "gfx950".
    name: str
    # C preprocessor guard wrapped around generated kernel bodies.
    preprocessor_check: Optional[str] = None
    # C++ boolean expression matching a runtime device-name string prefix.
    device_name_check: Optional[str] = None
    # C++ tag type used to select the kernel specialization.
    tag: Optional[str] = None
    # Suffix appended to generated instance file names.
    filename_suffix: Optional[str] = None

    def __post_init__(self) -> None:
        # object.__setattr__ is required because the dataclass is frozen.
        if self.preprocessor_check is None:
            object.__setattr__(self, "preprocessor_check", f"defined(__{self.name}__)")
        if self.device_name_check is None:
            object.__setattr__(
                self,
                "device_name_check",
                f'device_name.compare(0, {len(self.name)}, "{self.name}") == 0',
            )
        if self.tag is None:
            object.__setattr__(self, "tag", f"ck_tile::{self.name}_t")
        if self.filename_suffix is None:
            object.__setattr__(self, "filename_suffix", f"_{self.name}")
def get_factories_for_targets(
    targets: List[str], get_factory: Callable[[str], Any]
) -> List[Any]:
    """Resolve one factory per architecture for the requested targets.

    Duplicate architectures collapse into a single entry (the factory
    produced last wins), and the result is ordered so that more specific
    (longer) architecture names come before generic ones.
    """
    by_arch_name = {}
    for target_name in targets:
        candidate = get_factory(target_name)
        by_arch_name[candidate.arch.name] = candidate
    # Place more specific architectures first.
    return sorted(by_arch_name.values(), key=lambda f: len(f.arch.name), reverse=True)

View File

@@ -0,0 +1,4 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Generated kernel instance files are split out to speed up compilation.
# Empty string: the CMake build passes an explicit --output_dir to generate.py,
# so generated files land directly in that folder with no extra subdirectory.
GEN_DIR = ""  # in CMake, have to generate files in same folder

View File

@@ -0,0 +1,103 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
# Short CLI dtype name -> generated C++ trait/dispatch type name.
FWD_DTYPE_MAP = {
    "fp16": "SageAttentionFwdFp16",
    "bf16": "SageAttentionFwdBf16",
    "fp8bf16": "SageAttentionFwdFp8Bf16",
    "i8fp8bf16": "SageAttentionFwdI8Fp8Bf16",
    "i4fp8bf16": "SageAttentionFwdI4Fp8Bf16",
}

# Simplified mask implementation: a single boolean template parameter.
_MASK_SIMPLIFIED_MAP = {
    "s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
    "s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
}

# Generic mask implementation: a distinct C++ type per mask kind.
_MASK_MAP = {
    "no": "SageAttnMasks::NoMask",
    "causal": "SageAttnMasks::CausalMask",
    "generic": "SageAttnMasks::GenericMask",
}


def get_mask_map(mask_impl: str) -> dict:
    """Return the mask-name -> C++ mask type map for ``mask_impl``.

    Raises:
        ValueError: for an unknown implementation name. (Previously this was
        a bare ``assert False``, which is stripped under ``python -O`` and
        would silently return None.)
    """
    if mask_impl == "generic":
        return _MASK_MAP
    if mask_impl == "simplified":
        return _MASK_SIMPLIFIED_MAP
    raise ValueError(f"unknown mask implementation: {mask_impl!r}")


def get_mask_impl(mask: str) -> str:
    """Infer the implementation family from a mask name ("s_" prefix = simplified)."""
    return "simplified" if mask.startswith("s_") else "generic"


def get_mask_cpp_type(mask: str) -> str:
    """Return the C++ mask type for a mask name from either family."""
    return get_mask_map(get_mask_impl(mask))[mask]


# Runtime dispatch conditions for the generic mask implementation.
_MASK_CHECK_MAP = {
    "no": "t.mask_type == mask_enum::no_mask",
    "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
    "generic": "t.mask_type == mask_enum::window_generic",
}

# Runtime dispatch conditions for the simplified mask implementation.
_MASK_SIMPLIFIED_CHECK_MAP = {
    "s_no": "t.mask_type == mask_enum::no_mask",
    "s_mask": "t.mask_type != mask_enum::no_mask",
}


def get_mask_check_map(mask: str) -> dict:
    """Return the mask-name -> C++ check-expression map.

    Note: ``mask`` here receives the implementation name
    ("generic"/"simplified"), not a mask name; the parameter name is kept
    for interface compatibility.

    Raises:
        ValueError: for an unknown implementation name.
    """
    if mask == "generic":
        return _MASK_CHECK_MAP
    if mask == "simplified":
        return _MASK_SIMPLIFIED_CHECK_MAP
    raise ValueError(f"unknown mask implementation: {mask!r}")


def get_mask_cpp_check_expr(mask: str) -> str:
    """Return the C++ boolean dispatch expression for a mask name."""
    return get_mask_check_map(get_mask_impl(mask))[mask]
# Quant-scale granularity key -> C++ enum used in generated kernel traits.
QSCALE_MAP = {
    "no": "ck_tile::BlockSageAttentionQuantScaleEnum::NO_SCALE",
    "pertensor": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTENSOR",
    "blockscale": "ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE",
    "perwarp": "ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP",
    "perthread": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD",
}
# Quant-scale granularity key -> runtime enum value compared against
# t.qscale_type in the generated dispatch code.
QSCALE_CHECK_MAP = {
    "no": "quant_scale_enum::no_scale",
    "pertensor": "quant_scale_enum::pertensor",
    "blockscale": "quant_scale_enum::blockscale",
    "perwarp": "quant_scale_enum::perwarp",
    "perthread": "quant_scale_enum::perthread",
}
# is_group_mode flag: batch mode -> "false", group (variable-length) -> "true".
MODE_MAP = {"batch": "false", "group": "true"}
# is_v_rowmajor flag for the V tensor layout.
LAYOUT_MAP = {"row": "true", "col": "false"}
# Pipeline tag -> C++ pipeline class template name.
PIPELINE_MAP = {
    "qr": "ck_tile::BlockSageAttentionPipelineQRKSVS",
    "qr_async": "ck_tile::BlockSageAttentionPipelineQRKSVSAsync",
}
# Pipeline tag -> C++ pipeline enum used in trait template arguments.
PIPELINE_ENUM_MAP = {
    "qr": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS",
    "qr_async": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS_ASYNC",
}
# Bool-ish keys ("t"/"f" strings or Python bools) -> C++ boolean literals.
BOOL_MAP = {
    "t": "true",
    "f": "false",
    True: "true",
    False: "false",
}

View File

@@ -0,0 +1,2 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,992 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import copy
import fnmatch
import itertools
import os
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, ClassVar, Iterable, List, Optional, Tuple
from codegen.arch import ArchTrait, get_factories_for_targets
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
LAYOUT_MAP,
BOOL_MAP,
PIPELINE_MAP,
PIPELINE_ENUM_MAP,
MODE_MAP,
FWD_DTYPE_MAP,
get_mask_map,
get_mask_cpp_type,
get_mask_cpp_check_expr,
QSCALE_CHECK_MAP,
QSCALE_MAP,
)
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
# Storage width in bits of the Q/K element type per dtype key; used to derive
# vectorized-load widths in the hdim divisibility checks (dcheck/dvcheck).
DTYPE_BITS = {
    "fp16": 16,
    "bf16": 16,
    "fp8bf16": 8,
    "i8fp8bf16": 8,
    "i4fp8bf16": 4,
}
# bk0max (total K0 length) -> K0 sub-tile size used by the qr/qs pipelines'
# head-dim divisibility checks. NOTE(review): exact sub-tile semantics come
# from the C++ pipeline; confirm against TileSageAttnShape if changing.
K0_MAX_SUBMAX_MAP = {
    32: 32,
    48: 48,
    64: 64,
    80: 96,
    96: 128,
    128: 128,
    192: 192,
    256: 256,
}
# Header prepended verbatim to every generated kernel-instance .cpp file.
SAGEATTN_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp"
#include "sageattn_fwd.hpp"
"""
# Body of each generated instance: instantiates the shape/trait/pipeline types
# for one tile combination and defines the sageattn_fwd_<trait, arch>()
# specialization that launches the kernel. {F_*} placeholders are filled by
# SageAttnFwdKernel.render(); literal C++ braces are doubled for str.format().
SAGEATTN_FWD_KERNEL_BODY_TEMPLATE = """
#include <iostream>
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
using sageattn_dtype = {F_dtype};
using sageattn_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using sageattn_shape = ck_tile::TileSageAttnShape<sageattn_block_tile,
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using sageattn_traits = ck_tile::TileSageAttnTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_qscale},
{F_occupancy},
{F_skip}>;
using sageattn_variant = ck_tile::ComposedAttention<false * ck_tile::LOGITS_SOFT_CAP, true>;
using sageattn_mask_type = {F_mask};
using sageattn_pipeline_problem = ck_tile::BlockSageAttnPipelineProblem<
typename SageAttentionFwdTypeConfig<sageattn_dtype>::QDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::KDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::VDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::SaccDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::SMPLComputeDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::PDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::OaccDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::ODataType,
sageattn_shape,
{F_mode},
sageattn_variant,
sageattn_mask_type,
sageattn_traits>;
using sageattn_pipeline = {F_pipeline}<
sageattn_pipeline_problem>;
using sageattn_epilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename SageAttentionFwdTypeConfig<sageattn_dtype>::OaccDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::ODataType,
{F_spad}, {F_dvpad}>>;
using sageattn_kernel = {F_kernel}<sageattn_pipeline, sageattn_epilogue>;
using trait = sageattn_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, sageattn_mask_type, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
template<>
float sageattn_fwd_<trait, {F_arch.tag}>(const ck_tile::stream_config& s, sageattn_fwd_args a)
{{
using k_ = sageattn_kernel;
if(s.log_level_ > 0)
std::cout << ", {F_kname}" << std::flush;
auto [kargs, grids] = {F_kargs_creator}<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
}}
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""
# Name of the single generated dispatch translation unit.
SAGEATTN_FWD_API_FILENAME = "sageattn_fwd_api.cpp"
# Preamble of the dispatch file: helpers that query the device CU count and
# estimate the launched thread-block count (available to dispatch heuristics).
SAGEATTN_FWD_API_HEADER = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include <cstdio>
#include <hip/hip_runtime.h>
#include "sageattn_fwd.hpp"
namespace {
bool get_num_cus(unsigned& num_cus) {
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {
fprintf(stderr, "failed to get device");
return false;
}
hipDeviceProp_t props{};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {
fprintf(stderr, "failed to get device properties");
return false;
}
num_cus = props.multiProcessorCount;
return true;
}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}
} // namespace
"""
# Dispatch function wrapper; {F_dispatch} is filled with the nested if/else
# chain rendered by SageAttnFwdApiPool.render().
SAGEATTN_FWD_API_FUNC_TEMPLATE = """
namespace {{
float {F_func_name}([[maybe_unused]] sageattn_fwd_traits t, [[maybe_unused]] sageattn_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if(!get_num_cus(num_cus)) {{
return r;
}}
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
{F_dispatch}
return r;
}}
}} // namespace
"""
# Public entry point appended after all dispatch helpers.
SAGEATTN_FWD_API_FOOTER_TEMPLATE = """
// Public API entry point - unified for SageAttention
float sageattn_fwd(sageattn_fwd_traits traits, sageattn_fwd_args args, const ck_tile::stream_config& config) {
return sageattn_fwd_impl(traits, args, config);
}
"""
# One branch per architecture, keyed on the runtime device-name check.
SAGEATTN_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{
{F_dtype_case}
}}
"""
# One branch per data type string from the user-facing traits.
SAGEATTN_FWD_API_PER_DTYPE = """{F_if}(t.data_type.compare(\"{F_dtype}\") == 0) {{
{F_hdim_case}
}}
"""
# One branch per supported (hdim_q, hdim_v) upper bound.
SAGEATTN_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
# Innermost branch: every runtime condition for selecting one kernel instance.
SAGEATTN_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = sageattn_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
return sageattn_fwd_<trait_, {F_arch.tag}>(s, a);
}}
"""
@dataclass
class CppConstraint:
    """A C++ boolean expression gating kernel dispatch at runtime.

    ``None`` represents the always-true constraint (rendered as the C++
    literal "true"). Constraints compose with ``&`` into a parenthesized
    C++ ``&&`` expression.
    """

    # Was annotated `str = None`; Optional[str] reflects the actual contract.
    bool_expr: Optional[str] = None

    def __str__(self) -> str:
        return "true" if self.bool_expr is None else self.bool_expr

    def __and__(self, other: "CppConstraint") -> "CppConstraint":
        # Parenthesize both sides so C++ operator precedence cannot change
        # the combined expression's meaning.
        return CppConstraint(f"({self}) && ({other})")
@dataclass
class SageAttnFwdApiTrait:
    """One generated kernel instance as seen by the dispatch API.

    Mirrors the sageattn_fwd_traits_<> template arguments. The *check
    properties render C++ boolean expressions (over runtime args ``a``)
    deciding whether this instance may serve a given problem.
    """

    arch: ArchTrait
    pipeline_tag: str
    # sync with sageattn_fwd_traits<>, to generate fallback calls
    hdim: str
    dtype: str  # data type
    mode: str  # value from MODE_MAP
    bm0: int  # tile size along q seqlen (block size)
    bn0: int  # tile size along qk seqlen
    bk0: int  # tile size along qk gemm unroll
    bn1: int  # tile size along v head_dim
    bk1: int  # tile size along kv gemm unroll
    bk0max: int  # total K0 length (key into K0_MAX_SUBMAX_MAP)
    vlayout: str  # "row"/"col" V layout key
    mask: str  # mask name (key into the mask maps)
    qscale: str  # quant-scale granularity key (see QSCALE_MAP)
    spad: str  # "t"/"f": seqlen_q padding
    skpad: str  # "t"/"f": seqlen_k padding
    dpad: str  # "t"/"f": hdim_q padding
    dvpad: str  # "t"/"f": hdim_v padding
    skip: str  # "t"/"f": skip_min_seqlen_q flag
    constraint: CppConstraint  # extra C++ condition ANDed into dispatch

    @property
    def name(self) -> str:
        # Unique key used for duplicate detection across generated instances.
        return (
            f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
            + f"{self.vlayout}-{self.mask}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
        )

    @property
    def scheck(self) -> str:
        """C++ condition on a.seqlen_q for the instance's spad setting."""
        if self.mode == "group":
            return "true/*group mode spad always true*/"  # group mode only generate spad/skpad == true
        if self.pipeline_tag == "qr_async":
            return "true"
        elif self.pipeline_tag in ["qr", "qs"]:
            if self.spad == "t":
                return f"true /*a.seqlen_q % {self.bm0} != 0*/"  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f"a.seqlen_q % {self.bm0} == 0"
        else:
            assert False

    def seqtune(self, max_bm0: int) -> str:
        """C++ condition preferring the smallest tile covering seqlen_q;
        the largest tile (or bm0 == 64) is the unconditional fallback."""
        if self.bm0 == max_bm0 or self.bm0 == 64:
            return "true/*fall back to largest tile*/"
        else:
            return f"a.seqlen_q <= {self.bm0}"

    @property
    def skcheck(self) -> str:
        """C++ condition on a.seqlen_k (and cu_seqlen_k_ptr) for skpad."""
        if self.mode == "group":
            return "true/*group mode skpad always true*/"  # group mode only generate spad/skpad == true
        if self.pipeline_tag == "qr_async":
            if self.skpad == "t":
                return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
            else:
                return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
        elif self.pipeline_tag in ["qr", "qs"]:
            if self.skpad == "t":
                return f"true /*a.seqlen_k % {self.bn0} != 0*/"  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
        else:
            assert False

    @property
    def dcheck(self) -> str:
        """C++ condition on a.hdim_q for the dpad setting."""
        if self.pipeline_tag == "qr_async":
            # 128-bit loads divided by element bits -> elements per vector load.
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dpad == "t":
                return f"a.hdim_q % {vec} == 0"
            else:
                # qr_async instances are only generated with dpad == "t".
                assert False
        elif self.pipeline_tag in ["qr", "qs"]:
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dpad == "t":
                return f"true /*a.hdim_q % {bk0submax} != 0*/"  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f"a.hdim_q % {bk0submax} == 0"
        else:
            assert False

    @property
    def dvcheck(self) -> str:
        """C++ condition on a.hdim_v for the dvpad setting."""
        if self.pipeline_tag == "qr_async":
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dvpad == "t":
                return f"a.hdim_v % {vec} == 0"
            else:
                # qr_async instances are only generated with dvpad == "t".
                assert False
        elif self.pipeline_tag in ["qr", "qs"]:
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dvpad == "t":
                return f"true /*a.hdim_v % {bk0submax} != 0*/"  # TODO: order of get_pipelines() matters! (ugly)
            else:
                # F_dvpad="f": Causal mask requires hdim_v <= kN1 (num_tile_n1 == 1 for tile reversal)
                # Non-causal requires hdim_v % kN1 == 0 (epilogue writes full tiles)
                if self.mask == "causal":
                    return f"(a.hdim_v % {bk0submax} == 0) && (a.hdim_v <= {self.bn1})"
                else:
                    return (
                        f"(a.hdim_v % {bk0submax} == 0) && (a.hdim_v % {self.bn1} == 0)"
                    )
        else:
            assert False
@dataclass
class SageAttnFwdPipeline:
    """Pipeline-side configuration of one kernel instance: paddings, mask,
    quant-scale granularity, V layout, plus an optional extra C++ constraint."""

    tag: str  # "qr" or "qr_async" (see PIPELINE_MAP)
    F_vlayout: str  # row/col
    F_spad: str  # true/false
    F_skpad: str  # "t"/"f": seqlen_k padding
    F_dpad: str  # "t"/"f": hdim_q padding
    F_dvpad: str  # "t"/"f": hdim_v padding
    F_qscale: str  # no/pertensor/blockscale/perwarp/perthread
    F_mask: str  # value from MASK_MAP
    F_skip: str  # true/false
    F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())

    @property
    def name(self) -> str:
        """Short mnemonic encoding of this pipeline used in kernel names."""

        # "p" plus one letter-group per enabled padding dimension; "" when none.
        def pad_name() -> str:
            n = ""
            if self.F_spad == "t":
                n += "s"
            if self.F_skpad == "t":
                n += "sk"
            if self.F_dpad == "t":
                n += "d"
            if self.F_dvpad == "t":
                n += "dv"
            if n != "":
                n = "p" + n
            return n

        pn = pad_name()
        n = f"{self.tag}_v{self.F_vlayout[0]}"
        if pn != "":
            n += f"_{pn}"
        else:
            n += "_npad"
        # Simplified masks carry an "s_" prefix; generic masks encode their
        # first letter (e.g. "_mc" for causal).
        if self.F_mask[0:2] == "s_":
            if self.F_mask == "s_mask":
                n += "_mask"
            else:
                n += "_nmask"
        else:
            if self.F_mask != "no":
                n += f"_m{self.F_mask[0]}"
            else:
                n += "_nmask"
        if self.F_skip == "t":
            n += "_skip"
        else:
            n += "_nskip"
        if self.F_qscale != "no":
            n += f"_{self.F_qscale}"
        else:
            n += "_nqscale"
        return n
class SageAttnFwdApiPool:
    """Collects the API trait of every generated kernel and renders the
    nested C++ dispatch chain: arch -> dtype -> hdim -> per-instance checks."""

    def __init__(self):
        # arch -> dtype -> (hdim, bn1) -> [SageAttnFwdApiTrait]; OrderedDicts
        # preserve registration order, which determines dispatch order.
        self.pool = OrderedDict()

    def register_traits(self, trait: SageAttnFwdApiTrait) -> None:
        """Add one trait to the pool, rejecting duplicate/conflicting entries."""
        # Keyed by (hdim, bn1); bn1 serves as the hdim_v bound in render().
        hdim = trait.hdim, trait.bn1
        ts = (
            self.pool.setdefault(trait.arch, OrderedDict())
            .setdefault(trait.dtype, OrderedDict())
            .setdefault(hdim, [])
        )
        check_duplicates_and_paddings(ts, trait)
        ts.append(copy.copy(trait))

    def get_num_traits(
        self, filter_fn: Optional[Callable[[SageAttnFwdApiTrait], bool]] = None
    ) -> int:
        """Count pooled traits accepted by ``filter_fn`` (all when None)."""
        if filter_fn is None:

            def accept_all(trait: SageAttnFwdApiTrait) -> bool:
                return True

            filter_fn = accept_all
        return sum(
            sum(1 for trait in pool_by_hdim if filter_fn(trait))
            for pool_by_arch in self.pool.values()
            for pool_by_dtype in pool_by_arch.values()
            for pool_by_hdim in pool_by_dtype.values()
        )

    def render(
        self,
        func_name,
        filter_fn: Optional[Callable[[SageAttnFwdApiTrait], bool]] = None,
    ) -> str:
        """Render the dispatch function ``func_name`` covering every pooled
        trait accepted by ``filter_fn`` (all when None)."""
        if filter_fn is None:

            def accept_all(trait: SageAttnFwdApiTrait) -> bool:
                return True

            filter_fn = accept_all

        def has_traits(node) -> bool:
            """Recursively traverse nested OrderedDicts and lists to determine if any SageAttnFwdApiTrait satisfies filter_fn()."""
            if isinstance(node, list):
                return any(filter_fn(elem) for elem in node)
            elif isinstance(node, OrderedDict):
                return any(has_traits(val) for val in node.values())
            return False

        per_arch = str()
        # Empty subtrees are skipped entirely so no dangling "else if" branches
        # are emitted; if_() turns index 0 into "if" and later ones into "else if".
        for i_arch, (arch, pool_by_arch) in enumerate(
            item for item in self.pool.items() if has_traits(item[1])
        ):
            per_dtypes = str()
            for i_dtype, (dtype, pool_by_dtype) in enumerate(
                item for item in pool_by_arch.items() if has_traits(item[1])
            ):
                per_hdim_case = str()
                for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate(
                    item for item in pool_by_dtype.items() if has_traits(item[1])
                ):
                    # Largest bm0 among surviving traits acts as the seqtune fallback.
                    max_bm0 = max(
                        (t.bm0 for t in pool_by_hdim if filter_fn(t)), default=0
                    )
                    inners = str()
                    for i_trait, trait in enumerate(
                        [trait for trait in pool_by_hdim if filter_fn(trait)]
                    ):
                        inners += SAGEATTN_FWD_API_INNER_DISPATCH.format(
                            F_if=if_(i_trait),
                            F_arch=arch,
                            F_mode=MODE_MAP[trait.mode],
                            F_vlayout=LAYOUT_MAP[trait.vlayout],
                            F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                            F_mask=get_mask_cpp_type(trait.mask),
                            F_mask_check=get_mask_cpp_check_expr(trait.mask),
                            F_skip=BOOL_MAP[trait.skip],
                            F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
                            F_qscale=QSCALE_MAP[trait.qscale],
                            F_scheck=trait.scheck,
                            F_seqtune=trait.seqtune(max_bm0),
                            F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck,
                            F_dvcheck=trait.dvcheck,
                            F_constraint=trait.constraint,
                            F_spad=BOOL_MAP[trait.spad],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad],
                            F_dvpad=BOOL_MAP[trait.dvpad],
                            F_bm0=trait.bm0,
                            F_bn0=trait.bn0,
                            F_bk0=trait.bk0,
                            F_bn1=trait.bn1,
                            F_bk1=trait.bk1,
                            F_bk0max=trait.bk0max,
                            F_hdim=hdim,
                            F_dtype=FWD_DTYPE_MAP[dtype],
                        )
                    per_hdim_case += SAGEATTN_FWD_API_PER_HDIM_CASE.format(
                        F_if=if_(i_hdim),
                        F_hdim=hdim,
                        F_hdim_v=hdim_v,
                        F_inner_dispatch=indent(inners),
                    )
                per_dtypes += SAGEATTN_FWD_API_PER_DTYPE.format(
                    F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
                )
            per_arch += SAGEATTN_FWD_API_PER_ARCH.format(
                F_if=if_(i_arch),
                F_arch=arch,
                F_dtype_case=indent(per_dtypes),
            )
        return SAGEATTN_FWD_API_FUNC_TEMPLATE.format(
            F_func_name=func_name, F_dispatch=indent(per_arch)
        )
@dataclass
class SageAttnFwdTileSize:
    """Block/warp tiling configuration of one kernel instance."""

    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along qk gemm unroll
    F_bn1: int  # tile size along v head_dim
    F_bk1: int  # tile size along kv gemm unroll
    F_bk0max: int  # total length of K0, used for pipeline that need load Q at once (or repeatedly load Q as a whole tile)
    F_rm0: int  # number of warps for gemm0 along q seqlen
    F_rn0: int  # number of warps for gemm0 along k seqlen
    F_rk0: int  # number of warps for gemm0 along head dim q (not used)
    F_rm1: int  # number of warps for gemm1 along q seqlen
    F_rn1: int  # number of warps for gemm1 along head dim v
    F_rk1: int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0: int  # gemm0 warp size along m
    F_wn0: int  # gemm0 warp size along n
    F_wk0: int  # gemm0 warp size along k
    F_wm1: int  # gemm1 warp size along m
    F_wn1: int  # gemm1 warp size along n
    F_wk1: int  # gemm1 warp size along k
    F_occupancy: int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
    F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())

    @property
    def name(self) -> str:
        # Compact b/r/w encoding of the tile used inside kernel file names.
        return (
            f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
            + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
            + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
            + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
        )
@dataclass
class SageAttnFwdKernel:
    """One concrete kernel instance, rendered into its own .cpp file."""

    F_arch: ArchTrait
    F_hdim: int  # hdim
    F_dtype: str  # data type
    F_mode: str  # value from MODE_MAP
    F_tile: SageAttnFwdTileSize
    F_pipeline: SageAttnFwdPipeline

    # Templates used by render(); class-level so subclasses could override.
    _KERNEL_HEADER: ClassVar[str] = SAGEATTN_FWD_KERNEL_HEADER
    _KERNEL_BODY_TEMPLATE: ClassVar[str] = SAGEATTN_FWD_KERNEL_BODY_TEMPLATE

    @classmethod
    def _get_cpp_kernel_class_name(cls, pipeline_tag):
        # All pipeline variants currently share a single C++ kernel class.
        return "ck_tile::SageAttnFwdKernel"

    @classmethod
    def _get_cpp_kargs_creator_func_name(cls, pipeline_tag):
        # All pipeline variants currently share one kargs/grids creator.
        return "sageattn_fwd_create_kargs_and_grids"

    def render(self) -> str:
        """Render the complete .cpp source for this kernel instance."""
        return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format(
            F_kname=self.name,
            F_arch=self.F_arch,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bm0=self.F_tile.F_bm0,
            F_bn0=self.F_tile.F_bn0,
            F_bk0=self.F_tile.F_bk0,
            F_bn1=self.F_tile.F_bn1,
            F_bk1=self.F_tile.F_bk1,
            F_bk0max=self.F_tile.F_bk0max,
            F_rm0=self.F_tile.F_rm0,
            F_rn0=self.F_tile.F_rn0,
            F_rk0=self.F_tile.F_rk0,
            F_rm1=self.F_tile.F_rm1,
            F_rn1=self.F_tile.F_rn1,
            F_rk1=self.F_tile.F_rk1,
            F_wm0=self.F_tile.F_wm0,
            F_wn0=self.F_tile.F_wn0,
            F_wk0=self.F_tile.F_wk0,
            F_wm1=self.F_tile.F_wm1,
            F_wn1=self.F_tile.F_wn1,
            F_wk1=self.F_tile.F_wk1,
            F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
            F_spad=BOOL_MAP[self.F_pipeline.F_spad],
            F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
            F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
            F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
            F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
            F_skip=BOOL_MAP[self.F_pipeline.F_skip],
            F_occupancy=self.F_tile.F_occupancy,
            F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
            F_mask=get_mask_cpp_type(self.F_pipeline.F_mask),
            F_mode=MODE_MAP[self.F_mode],
            F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
            F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag),
            F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag),
        )

    @property
    def name(self) -> str:
        # TODO: we don't encode idx here
        return (
            f"sageattn_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
            + self.F_tile.name
            + "_"
            + self.F_pipeline.name
        )

    @property
    def filename(self) -> str:
        # The arch suffix keeps per-arch variants of the same kernel distinct.
        return f"{self.name}{self.F_arch.filename_suffix}.cpp"

    def api_trait(self) -> SageAttnFwdApiTrait:
        """Project this kernel into the trait record used by dispatch codegen."""
        return SageAttnFwdApiTrait(
            arch=self.F_arch,
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=self.F_tile.F_bm0,
            bn0=self.F_tile.F_bn0,
            bk0=self.F_tile.F_bk0,
            bn1=self.F_tile.F_bn1,
            bk1=self.F_tile.F_bk1,
            bk0max=self.F_tile.F_bk0max,
            vlayout=self.F_pipeline.F_vlayout,
            mask=self.F_pipeline.F_mask,
            qscale=self.F_pipeline.F_qscale,
            spad=self.F_pipeline.F_spad,
            skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad,
            dvpad=self.F_pipeline.F_dvpad,
            skip=self.F_pipeline.F_skip,
            # Tile- and pipeline-level constraints are ANDed into one expression.
            constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
        )
@dataclass
class ProblemContext:
    """Problem-side parameters a kernel must be able to serve."""

    dtype: str  # dtype key, e.g. "fp8bf16"
    mode: str  # "batch" or "group"
    hdim: int  # head dim of Q/K
    hdim_v: int  # head dim of V


@dataclass
class KernelContext:
    """Kernel-side configuration under compatibility consideration."""

    tile: SageAttnFwdTileSize
    pipeline: SageAttnFwdPipeline
    mask_impl: str


# A predicate deciding whether a kernel configuration can serve a problem.
CompatibilityRule = Callable[[ProblemContext, KernelContext], bool]


def is_compatible(
    problem_ctx: ProblemContext,
    kernel_ctx: KernelContext,
    rules: Iterable[CompatibilityRule],
) -> bool:
    """Return True iff every rule accepts the (problem, kernel) pair."""
    for rule in rules:
        if not rule(problem_ctx, kernel_ctx):
            return False
    return True
def create_kernel(
    arch: ArchTrait, problem_ctx: ProblemContext, kernel_ctx: KernelContext
) -> SageAttnFwdKernel:
    """Assemble one kernel description from arch, problem and kernel contexts."""
    kernel = SageAttnFwdKernel(
        F_arch=arch,
        F_hdim=problem_ctx.hdim,
        F_dtype=problem_ctx.dtype,
        F_mode=problem_ctx.mode,
        F_tile=kernel_ctx.tile,
        F_pipeline=kernel_ctx.pipeline,
    )
    return kernel
class CompatibilityRuleFactory:
    """Produces the architecture-independent compatibility rules."""

    @staticmethod
    def get_rules() -> List[CompatibilityRule]:
        def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
            # In group mode we cannot predict whether a given batch's seqlen
            # needs padding, so both seqlen paddings must be enabled.
            if problem_ctx.mode != "group":
                return True
            pipe = kernel_ctx.pipeline
            return pipe.F_spad == "t" and pipe.F_skpad == "t"

        return [check_mode]
class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
    """Gfx9-family rules; currently identical to the base rule set."""

    _AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"})

    @classmethod
    def get_rules(cls) -> List[CompatibilityRule]:
        # No gfx9-specific restrictions yet - reuse the base rules unchanged.
        return CompatibilityRuleFactory.get_rules()
class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9):
    # gfx950 currently shares the exact rule set of the gfx9 family.
    pass
class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
arch = ArchTrait(
"gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)"
)
# Note: fp16 is not supported by SageAttention (only bf16 + fp8/int quantization)
_DT_BF16 = ("bf16",)
_DT_FP8BF16 = ("fp8bf16",)
_DT_I8FP8BF16 = ("i8fp8bf16",)
_DT_I4FP8BF16 = ("i4fp8bf16",)
@classmethod
def supported_dtypes(cls) -> Tuple[str]:
    # All SageAttention dtypes on gfx9: plain bf16 plus the quantized combos.
    return cls._DT_BF16 + cls._DT_FP8BF16 + cls._DT_I8FP8BF16 + cls._DT_I4FP8BF16
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
@classmethod
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
    """Return {(hdim_q, hdim_v): [tile sizes]} supported for ``dtype`` on gfx9."""
    if dtype in cls._DT_BF16:
        return {
            (128, 128) : [SageAttnFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
        } # fmt: skip
    elif (
        dtype in cls._DT_FP8BF16
        or dtype in cls._DT_I8FP8BF16
        or dtype in cls._DT_I4FP8BF16
    ):
        # gfx9 (non-gfx950): only F_bn0=64; F_bn0=128 variant is gfx950-only (see Gfx950).
        return {
            (128, 128): [
                SageAttnFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), # fmt: skip
            ],
        }
    else:
        raise ValueError(f"unsupported dtype={dtype}")
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
@classmethod
def get_pipelines(
cls, dtype, hdim, hdim_v, receipt, mask_impl
) -> List[SageAttnFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let "t" padding to appear later!!
# TODO: how to design this more generic?
pipelines = []
if dtype in cls._DT_BF16:
qscale = "no"
skip = "f" # skip: only false
for mask, vlayout in itertools.product(
get_mask_map(mask_impl).keys(),
["row", "col"],
):
if hdim == 256 and hdim_v == 256:
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "f", "f", "f", "f", qscale, mask, skip)) # fmt: skip
# the below two is used for hdim vectorize load
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip
else:
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "f", "t", "t", qscale, mask, skip)) # fmt: skip
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip
elif (
dtype in cls._DT_FP8BF16
or dtype in cls._DT_I8FP8BF16
or dtype in cls._DT_I4FP8BF16
):
# no need lse kernels
skip = "f" # skip: only false
for mask, qscale, vlayout in itertools.product(
get_mask_map(mask_impl).keys(),
["no", "pertensor", "blockscale", "perwarp", "perthread"],
["row", "col"], # Support both row and col major layouts
):
if dtype in cls._DT_I4FP8BF16:
# int4 only uses sync pipeline (qr), pad_d="f" because packed types
# require alignment >= PackedSize which conflicts with kPadHeadDimQ=true
# forcing alignment to 1. Safe since hdim always matches tile size.
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "f", "f", "f", qscale, mask, skip)) # fmt: skip
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
elif hdim == 64:
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "f", "f", "f", qscale, mask, skip)) # fmt: skip
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
else:
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "f", "t", "t", qscale, mask, skip)) # fmt: skip
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip
# Packed types (int4) cannot use head-dim padding: the tile_window infrastructure
# forces alignment=1 when padding is enabled, but packed types need alignment >= PackedSize.
if dtype in cls._DT_I4FP8BF16:
for p in pipelines:
assert p.F_dpad == "f", (
f"int4 dtype '{dtype}' requires pad_d=false, got '{p.F_dpad}'"
)
assert p.F_dvpad == "f", (
f"int4 dtype '{dtype}' requires pad_dv=false, got '{p.F_dvpad}'"
)
return pipelines
class KernelComponentFactoryGfx950(
    KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx950
):
    """Kernel components for gfx950 (CDNA4); overrides the fp8-family tiles."""

    arch = ArchTrait("gfx950")

    @classmethod
    def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
        quantized_dtypes = cls._DT_FP8BF16 + cls._DT_I8FP8BF16 + cls._DT_I4FP8BF16
        if dtype in quantized_dtypes:
            # gfx950 fp8-family: F_bn0=128 tile only (gfx9 uses F_bn0=64 only).
            return {
                (128, 128): [
                    SageAttnFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), # fmt: skip
                ],
            }
        # bf16 (and anything else) falls back to the gfx9 tile table.
        return super().get_hdim_tile_size_dict(dtype)
class CustomFactory(KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9):
    """Experimental factory: prepends an extra bf16 tile guarded by a C++ constraint."""

    @classmethod
    def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
        tile_dict = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
        if dtype in cls._DT_BF16 and (128, 128) in tile_dict:
            # The smaller-bm0 tile is tried first; the CppConstraint limits it to
            # cases where the big tile would under-utilize the CUs.
            tile_dict[(128, 128)].insert(0, SageAttnFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
        return tile_dict
def get_factory(target: str):
    """Select the kernel-component factory for a GPU target string.

    Setting the CK_TILE_SAGEATTN_FWD_CUSTOM_FACTORY=1 environment variable
    overrides target-based selection with the experimental CustomFactory.

    Raises:
        ValueError: if the target is not a supported gfx9-family device.
    """
    if os.environ.get("CK_TILE_SAGEATTN_FWD_CUSTOM_FACTORY", "0") == "1":
        return CustomFactory
    # Place more specific architectures first
    if target.startswith("gfx950"):
        return KernelComponentFactoryGfx950
    if target.startswith("gfx9"):
        return KernelComponentFactoryGfx9
    # ValueError (an Exception subclass) keeps existing `except Exception`
    # callers working while being more precise than a bare Exception.
    raise ValueError(f"Unsupported device target {target}")
@dataclass(frozen=True)
class Product:
    """A named compatibility rule; calling the Product delegates to its rule."""

    # Human-readable label for this product (e.g. "All tiles").
    name: str
    # Predicate deciding whether a (problem, kernel) pair belongs to the product.
    rule: CompatibilityRule

    def __call__(self, problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
        # True when the pair is accepted by this product's rule.
        return self.rule(problem_ctx, kernel_ctx)
def get_product(receipt: int) -> Product:
    """Build the receipt-level kernel filter (receipt is currently unused)."""

    def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
        # bf16 (no quantization) should not have qscale
        if problem_ctx.dtype == "bf16" and kernel_ctx.pipeline.F_qscale != "no":
            return False
        return True

    return Product(name="All tiles", rule=fit)
def get_fwd_blobs(
    targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[SageAttnFwdApiPool, List[SageAttnFwdKernel]]:
    """Enumerate all kernel instances for the given targets.

    Returns the API pool (used to render the dispatch file) and the list of
    kernels to generate. `kernel_filter` is an fnmatch pattern applied to the
    kernel name; None or "" keeps everything.
    """
    gen = list()
    api_pool = SageAttnFwdApiPool()
    factories = get_factories_for_targets(targets, get_factory)
    product = get_product(receipt)  # receipt-level rule, loop-invariant
    for factory, dtype in ((f, t) for f in factories for t in f.supported_dtypes()):
        d = factory.get_hdim_tile_size_dict(dtype)
        rules = factory.get_rules()  # invariant over tiles/pipelines of this factory
        for ((hdim, hdim_v), tiles), mode in itertools.product(
            d.items(), MODE_MAP.keys()
        ):
            # optdim_list == [-1] means "all head dims".
            if optdim_list != [-1] and hdim not in optdim_list:
                continue
            # Tile order matters for runtime dispatch: smaller bm0 must come first.
            for tile, next_tile in zip(tiles, tiles[1:]):
                assert next_tile.F_bm0 >= tile.F_bm0, (
                    "Tiles must be ordered by increasing bm0"
                )
            for tile, pipeline in itertools.product(
                tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
            ):
                problem_ctx = ProblemContext(
                    dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v
                )
                kernel_ctx = KernelContext(
                    tile=tile, pipeline=pipeline, mask_impl=mask_impl
                )
                if not is_compatible(problem_ctx, kernel_ctx, [*rules, product]):
                    continue
                k = create_kernel(factory.arch, problem_ctx, kernel_ctx)
                # Treat None the same as "" (no filter); the previous `!= ""` check
                # passed None straight into fnmatch, which raises.
                if kernel_filter and not fnmatch.fnmatch(k.name, kernel_filter):
                    continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)
    return (api_pool, gen)
def write_single_fwd_kernel(kernel: SageAttnFwdKernel, autogen_dir: Path) -> None:
    # Render one kernel instance; update_file only rewrites on content change,
    # so unchanged kernels do not trigger rebuilds.
    update_file(autogen_dir / kernel.filename, kernel.render())
def write_fwd_api(
    api_pool: SageAttnFwdApiPool,
    autogen_dir: Path,
) -> None:
    """Emit the dispatch API source assembled from header, body and footer."""
    content = (
        SAGEATTN_FWD_API_HEADER
        + api_pool.render("sageattn_fwd_impl")
        + SAGEATTN_FWD_API_FOOTER_TEMPLATE
    )
    update_file(autogen_dir / SAGEATTN_FWD_API_FILENAME, content)
def write_blobs(
    targets: List[str],
    output_dir: Path,
    kernel_filter: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    """Generate all kernel sources plus the dispatch API file into output_dir."""
    api_pool, kernels = get_fwd_blobs(
        targets, kernel_filter, receipt, optdim_list, mask_impl
    )
    for k in kernels:
        write_single_fwd_kernel(k, output_dir)
    write_fwd_api(api_pool, output_dir)
def list_blobs(
    targets: List[str],
    file_path: Path,
    kernel_filter: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    """Append the paths of every file that would be generated to file_path."""
    _, kernels = get_fwd_blobs(
        targets, kernel_filter, receipt, optdim_list, mask_impl
    )
    gen_root = file_path.parent / GEN_DIR
    lines = [(gen_root / k.filename).as_posix() for k in kernels]
    lines.append((gen_root / SAGEATTN_FWD_API_FILENAME).as_posix())
    # Append mode: the caller truncates the file once before iterating ops.
    with file_path.open("a") as f:
        f.write("\n".join(lines) + "\n")

View File

@@ -0,0 +1,70 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import dataclasses
import os.path as path
import textwrap
def update_file(file_path, content):
    """Write content to file_path only when it differs from what's on disk.

    Skipping identical writes keeps mtimes stable and avoids triggering
    unnecessary rebuilds.
    """
    if path.exists(file_path):
        with open(file_path, "r") as existing:
            if existing.read() == content:
                return
    with open(file_path, "w") as out:
        out.write(content)
def indent(code: str, indent: str = " ") -> str:
    # Thin wrapper over textwrap.indent (prefixes non-empty lines only).
    # NOTE(review): the parameter `indent` shadows the function name; keep in
    # mind when reading call sites.
    return textwrap.indent(code, indent)
def if_(i: int) -> str:
    """Return the C++ chain keyword: "if" for branch 0, "else if" afterwards."""
    if i == 0:
        return "if"
    return "else if"
def check_duplicates_and_paddings(traits, trait):
    """Validate `trait` against previously collected `traits`.

    Raises if:
    * an entry with identical parameters (including paddings) already exists;
    * padding flags are ordered so that an earlier, more restrictive kernel
      would shadow the new one (e.g. f, _t_, f, t listed before f, _f_, f, t).
    """
    field_names = [f.name for f in dataclasses.fields(trait)]
    pad_fields = [f for f in field_names if "pad" in f]
    other_fields = [f for f in field_names if "pad" not in f]

    def capacity(value):
        # Map a padding flag to a comparable "capacity": "f" / 0 means
        # unrestricted (huge), "t" means fully restricted (1), a non-zero int
        # is its own granularity.
        if isinstance(value, str):
            return 1000000 if value == "f" else 1
        if isinstance(value, int):
            return 1000000 if value == 0 else value
        assert False

    for prev_trait in traits:
        if any(getattr(trait, f) != getattr(prev_trait, f) for f in other_fields):
            continue  # different kernel parameters; cannot conflict
        if all(getattr(trait, f) == getattr(prev_trait, f) for f in pad_fields):
            raise Exception(f"Duplicate found {trait}")
        # Check if the previous kernel can be incorrectly used before the current one.
        prev_tighter = False
        curr_tighter = False
        for f in pad_fields:
            prev_cap = capacity(getattr(prev_trait, f))
            curr_cap = capacity(getattr(trait, f))
            if prev_cap < curr_cap:
                prev_tighter = True
            elif prev_cap > curr_cap:
                curr_tighter = True
        if prev_tighter and not curr_tighter:
            raise Exception(
                f"Kernel will never be used because paddings are not ordered correctly:\n{prev_trait} supersedes\n{trait}"
            )

View File

@@ -0,0 +1,202 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "sageattn_fwd.hpp"
#include "sageattn_fwd_runner.hpp"
#include <string>
// Build the command-line parser for the SageAttention forward example and
// parse argv. Returns (parse-succeeded flag, populated parser).
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
        .insert("mode", "0", "kernel mode. 0:batch, 1:group")
        .insert("b", "2", "batch size")
        .insert("h", "8", "num of head, for q")
        .insert("h_k",
                "-1",
                "num of head, for k/v, -1 means equal to h\n"
                "if not equal to h, then this is GQA/MQA case")
        .insert("s",
                "3328",
                "seqlen_q. if group-mode, means the average value of seqlen_q\n"
                "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
                "also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
                "(group mode)")
        .insert("s_k",
                "-1",
                "seqlen_k (including new key/value), -1 means equal to s\n"
                "also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
                "(group mode)")
        .insert("s_qpad",
                "-1",
                "seqlen_q stride between 2 batches (group-mode optional).\n"
                "Provide positive strides per-batch to simulate physical padding on Q.")
        .insert("s_kpad",
                "-1",
                "seqlen_k stride between 2 batches, currently used in group-mode only\n"
                "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
                "along seqlen, instead of packed, same as xformer kv_padding,\n"
                "must be greater than or equal to s_k")
        .insert("d", "128", "head dim for q, k")
        .insert("d_v", "-1", "head dim for v, -1 means equal to d")
        .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
        // SageAttention quantization granularity selector; block sizes per
        // tensor are listed in the help text below.
        .insert("qscale",
                "n",
                "n or 0, no scale\n"
                "pt or 1, per-tensor scale\n"
                "bs or 2, block scale (Q:128, KV:128)\n"
                "pw or 3, per-warp scale (Q:32, KV:64)\n"
                "pth or 4, per-thread scale (Q:4, KV:16)\n")
        .insert("iperm",
                "1",
                "permute input\n"
                "if true, will be b*h*s*d, else b*s*h*d")
        .insert("operm", "1", "permute output")
        .insert("prec",
                "fp8bf16",
                "Primary: fp8bf16, i8fp8bf16, i4fp8bf16. Also bf16 (keep): pipeline validation "
                "with qscale=n (no quant); not the quantized Sage product path.")
        .insert("mask",
                "0",
                "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
                "'t', top-left causal mask, 'b', bottom-r causal mask\n"
                "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
                "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
                "'xt:window_size', xformer style masking from top-left, window_size negative is "
                "causal, positive is swa\n"
                "'xb:window_size', xformer style masking from bottom-r, window_size negative is "
                "causal, positive is swa\n"
                "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
                "now)")
        .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
        .insert("kname", "0", "if set to 1 will print kernel name")
        // NOTE(review): the label for init=3 below repeats "tf" (already used
        // for init=2); it likely should be a distinct token — confirm against
        // the runner's init-method parsing.
        .insert("init",
                "uf",
                "init method:\n ui or 0 - uniform random int\n ni - normalized random int"
                "\n uf or 1 - uniform random float\n nf - normalized random float"
                "\n tf or 2 - trig float"
                "\n tf or 3 - uniform random float, min max is the max of the type\n")
        .insert("seed",
                "11939",
                "random seed used for initializing input tensors. 0 for "
                "non-deterministic seed")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert("warmup", "5", "number of iterations before benchmark the kernel")
        .insert("repeat", "20", "number of iterations to benchmark the kernel")
        .insert("json", "0", "0: No Json, 1: Dump Results in Json format")
        .insert("jsonfile", "sageattn_fwd.json", "json file name to dump results")
        .insert("q_eff_lens",
                "",
                "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
                "Comma-separated list of length 'b'. If empty, no override.")
        .insert("kv_eff_lens",
                "",
                "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
                "Comma-separated list of length 'b'. If empty, no override.");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// Extract all runtime parameters from the parsed arguments and invoke the
// typed runner. DataTypeConfig selects the Q/K/V/O precision combination
// (see SageAttentionFwdTypeConfig specializations).
template <typename DataTypeConfig>
auto run(const ck_tile::ArgParser& arg_parser)
{
    int do_validation   = arg_parser.get_int("v");
    mode_enum mode      = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
    ck_tile::index_t batch   = arg_parser.get_int("b");
    ck_tile::index_t nhead   = arg_parser.get_int("h");
    ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
    // "s"/"s_k" accept comma-separated per-batch values (group mode).
    auto seqlen_qs           = arg_parser.get_int_vec("s");
    auto seqlen_ks           = arg_parser.get_int_vec("s_k");
    ck_tile::index_t hdim_q  = arg_parser.get_int("d");
    ck_tile::index_t hdim_v  = arg_parser.get_int("d_v");
    auto seqlen_kpads        = arg_parser.get_int_vec("s_kpad");
    auto seqlen_qpads        = arg_parser.get_int_vec("s_qpad");
    // Batch-mode per-batch effective lengths (exclude PAD); empty = no override.
    auto q_eff_lens_per_batch  = arg_parser.get_int_vec("q_eff_lens");
    auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens");
    bool i_perm  = arg_parser.get_bool("iperm");
    bool o_perm  = arg_parser.get_bool("operm");
    float scale_s = arg_parser.get_float("scale_s");
    bool is_v_rowmajor     = arg_parser.get_str("vlayout") == "r";
    std::string qscale_str = arg_parser.get_str("qscale");
    std::string mask_str   = arg_parser.get_str("mask");
    std::string init_method = arg_parser.get_str("init");
    uint32_t seed           = arg_parser.get_uint32("seed");

    // kname=1 bumps the log level so the selected kernel name is printed.
    ck_tile::stream_config stream_config{nullptr,
                                         true,
                                         /* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0),
                                         arg_parser.get_int("warmup"),
                                         arg_parser.get_int("repeat"),
                                         arg_parser.get_str("timer") == std::string("gpu")};

    // Only dump JSON results when explicitly requested.
    auto json = arg_parser.get_int("json") == 1
                    ? std::optional<std::string>{arg_parser.get_str("jsonfile")}
                    : std::nullopt;

    return sageattn_fwd_run<DataTypeConfig>(mode,
                                            batch,
                                            nhead,
                                            nhead_k,
                                            seqlen_qs,
                                            seqlen_ks,
                                            hdim_q,
                                            hdim_v,
                                            seqlen_qpads,
                                            seqlen_kpads,
                                            q_eff_lens_per_batch,
                                            kv_eff_lens_per_batch,
                                            i_perm,
                                            o_perm,
                                            scale_s,
                                            is_v_rowmajor,
                                            mask_str,
                                            qscale_str,
                                            init_method,
                                            seed,
                                            do_validation,
                                            stream_config,
                                            json);
}
int main(int argc, char* argv[])
{
    try
    {
        auto [result, arg_parser] = create_args(argc, argv);
        if(!result)
            return -1;

        // Map the runner result to the process exit code:
        // 0 = success, -2 = kernel failure, -1 = bad arguments.
        const auto exit_code = [](auto r) { return r == fwd_result::success ? 0 : -2; };

        const std::string data_type = arg_parser.get_str("prec");
        if(data_type == "bf16")
            return exit_code(run<SageAttentionFwdBf16>(arg_parser));
        if(data_type == "fp8bf16")
            return exit_code(run<SageAttentionFwdFp8Bf16>(arg_parser));
        if(data_type == "i8fp8bf16")
            return exit_code(run<SageAttentionFwdI8Fp8Bf16>(arg_parser));
        if(data_type == "i4fp8bf16")
            return exit_code(run<SageAttentionFwdI4Fp8Bf16>(arg_parser));

        std::cerr << "Unsupported precision: " << data_type << std::endl;
        return -1;
    }
    catch(const std::invalid_argument& e)
    {
        std::cerr << "Invalid argument: " << e.what() << std::endl;
        return -1;
    }
    catch(const std::exception& e)
    {
        std::cerr << "Error: " << e.what() << std::endl;
        return -2;
    }
}

View File

@@ -0,0 +1,173 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# generate kernel instances to speed up compilation
import argparse
import importlib
import pkgutil
from enum import IntEnum
from pathlib import Path
from typing import List, Optional

import codegen.ops
from codegen.cmake_config import GEN_DIR
class HandlerId(IntEnum):
    # Index into the per-op handler tuple: (list_blobs, write_blobs).
    LIST_BLOBS = 0
    WRITE_BLOBS = 1
# inspect all modules under 'codegen.ops' and register API handlers
ops = []
for _importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
    full_module_name = "%s.%s" % (codegen.ops.__name__, module_name)
    # importlib.import_module replaces the deprecated Loader.load_module()
    # and registers the module under its fully qualified name.
    ops.append(importlib.import_module(full_module_name))

# Strip "sageattn_" so module sageattn_fwd registers as CLI key "fwd".
unwanted_prefix = "sageattn_"


def _cli_key(qualified_name: str) -> str:
    # Use the last path component so both "sageattn_fwd" and
    # "codegen.ops.sageattn_fwd" map to the same key.
    short = qualified_name.rsplit(".", 1)[-1]
    return short[len(unwanted_prefix):] if short.startswith(unwanted_prefix) else short


handlers = {_cli_key(op.__name__): (op.list_blobs, op.write_blobs) for op in ops}
assert 0 < len(handlers)
def write_blobs(
    targets: List[str],
    output_dir: Optional[str],
    api_list: List[str],
    filters_list: List[str],
    optdim_list: List[int],
    receipt,
    mask_impl,
) -> None:
    """Run each requested API's WRITE_BLOBS handler, emitting generated sources."""
    if output_dir is None:
        dest = Path(__file__).parent
    else:
        dest = Path(output_dir) / GEN_DIR
    dest.mkdir(parents=True, exist_ok=True)
    for api, kernel_filter in zip(api_list, filters_list):
        writer = handlers[api][HandlerId.WRITE_BLOBS]
        writer(targets, dest, kernel_filter, receipt, optdim_list, mask_impl)
# list all the files that will be generated
def list_blobs(
    targets: List[str],
    output_file: Optional[str],
    api_list: List[str],
    filters_list: List[str],
    optdim_list: List[int],
    receipt,
    mask_impl,
) -> None:
    """Run each requested API's LIST_BLOBS handler, appending names to output_file."""
    assert output_file is not None
    file_path = Path(output_file)
    # Truncate up front; the per-op handlers open the file in append mode.
    file_path.write_text("")
    for api, kernel_filter in zip(api_list, filters_list):
        lister = handlers[api][HandlerId.LIST_BLOBS]
        lister(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="generate",
        description="Generate SageAttention CK_tile kernel/API blobs.",
    )
    parser.add_argument(
        "--targets",
        default="gfx9,gfx950",
        required=False,
        help="list of GPU targets, separated by comma.",
    )
    parser.add_argument(
        "-a",
        "--api",
        default="fwd",
        required=False,
        help="Codegen API key(s), comma-separated (e.g. fwd -> module codegen.ops.sageattn_fwd).",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="write all the blobs into a directory",
    )
    parser.add_argument(
        "-l", "--list_blobs", required=False, help="list all the kernels to a file"
    )
    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        default="",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module",
    )
    parser.add_argument(
        "-m",
        "--mask",
        default="simplified",
        required=False,
        help="mask implementation, simplified/generic",
    )
    parser.add_argument(
        "-r",
        "--receipt",
        default=0,
        required=False,
        help="Codegen receipt index. SageAttention forward currently uses receipt 0 only; "
        "the value is passed through to ops (see get_product in sageattn_fwd.py).",
    )
    parser.add_argument(
        "--optdim",
        default="-1",
        required=False,
        help="only optimize the hdim in the list. separated by comma. -1 is the default choice. "
        "e.g. --optdim=32,64,128,256",
    )
    args = parser.parse_args()

    targets = args.targets.split(",")
    api_list = args.api.split(",")
    filter_list = args.filter.split(",")
    # Pad the filter list so zip(api_list, filter_list) covers every API.
    filter_list.extend([""] * (len(api_list) - len(filter_list)))
    optdim_list = [int(hdim) for hdim in args.optdim.split(",")]

    # --list_blobs dumps the would-be file names (used by the build system);
    # otherwise the sources are actually generated into --output_dir.
    if args.list_blobs is not None:
        list_blobs(
            targets,
            args.list_blobs,
            api_list,
            filter_list,
            optdim_list,
            int(args.receipt),
            mask_impl=args.mask,
        )
    else:
        write_blobs(
            targets,
            args.output_dir,
            api_list,
            filter_list,
            optdim_list,
            int(args.receipt),
            mask_impl=args.mask,
        )

View File

@@ -0,0 +1,169 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <ostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
// keep this in sync with ck_tile::GenericAttentionMaskEnum
// keep this in sync with ck_tile::GenericAttentionMaskEnum
enum class mask_enum
{
    no_mask = 0,       // full attention, nothing masked
    mask_top_left,     // causal mask anchored at the top-left corner
    mask_bottom_right, // causal mask anchored at the bottom-right corner
    window_generic,    // generic windowed mask described by (y, x) coordinates
};
// Host-side description of the attention mask: parses the CLI string form,
// serializes a compact label for kernel names, and can count the number of
// unmasked elements (used for FLOP accounting).
struct mask_info
{
    mask_enum type;
    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_k;
    // Generic-window coordinates (see make_generic_attention_mask_coordinates_from_lr_window).
    ck_tile::index_t y, x;
    ck_tile::index_t left, right; // FA style SWA left/right

    // Write a short label: "n", "t(l:r)", "b(l:r)" or "g(y:x)".
    void serialize(std::ostream& os) const
    {
        if(type == mask_enum::no_mask)
            os << "n";
        else if(type == mask_enum::mask_top_left)
            os << "t(" << left << ":" << right << ")";
        else if(type == mask_enum::mask_bottom_right)
            os << "b(" << left << ":" << right << ")";
        else
        {
            os << "g(" << y << ":" << x << ")";
        }
    }

    // Parse the "-mask=" CLI string (see the example's help text for the
    // accepted forms) into a mask_info. Throws std::invalid_argument on
    // malformed input.
    static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
    {
        ck_tile::index_t x_total = seqlen_k;
        ck_tile::index_t y_total = seqlen_q;
        mask_info tmp;
        tmp.seqlen_q = seqlen_q;
        tmp.seqlen_k = seqlen_k;
        auto found_0 = str.find(':');
        if(found_0 != std::string::npos)
        {
            // Forms with a ':' carry a prefix tag and a value part.
            std::string t = str.substr(0, found_0);
            std::string v = str.substr(found_0 + 1);
            if(t == "xt" || t == "xb")
            {
                // xformer style sliding window attn from top-left
                // (negative window_size means causal).
                ck_tile::index_t window_size = std::stoi(v);
                ck_tile::index_t left_size   = -1;
                ck_tile::index_t right_size  = 0;
                if(window_size > 0)
                {
                    left_size  = window_size / 2;
                    right_size = window_size - 1 - left_size;
                }
                auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
                    left_size, right_size, 0, y_total, x_total, t == "xt");

                tmp.type  = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
                tmp.y     = r.at(ck_tile::number<0>{});
                tmp.x     = r.at(ck_tile::number<1>{});
                tmp.left  = left_size;
                tmp.right = right_size;
            }
            else if(t == "t" || t == "b" || t == "g")
            {
                // FA style "tag:v0,v1" — left/right window sizes (t/b) or raw
                // generic y/x coordinates (g).
                auto found_1 = v.find(",");
                if(found_1 == std::string::npos)
                {
                    throw std::invalid_argument("invalid mask value: " + str);
                }
                tmp.type            = mask_enum::window_generic;
                ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
                ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
                // TODO: some validation
                if(t == "t")
                {
                    tmp.type = mask_enum::mask_top_left;
                    auto r   = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
                        v0, v1, 0, y_total, x_total, true);
                    tmp.y     = r.at(ck_tile::number<0>{});
                    tmp.x     = r.at(ck_tile::number<1>{});
                    tmp.left  = v0;
                    tmp.right = v1;
                }
                else if(t == "b")
                {
                    tmp.type = mask_enum::mask_bottom_right;
                    auto r   = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
                        v0, v1, 0, y_total, x_total, false);
                    tmp.y     = r.at(ck_tile::number<0>{});
                    tmp.x     = r.at(ck_tile::number<1>{});
                    tmp.left  = v0;
                    tmp.right = v1;
                }
                else if(t == "g")
                {
                    tmp.type  = mask_enum::window_generic;
                    tmp.y     = v0;
                    tmp.x     = v1;
                    tmp.left  = v0; // TODO: don't use this?
                    tmp.right = v1;
                }
            }
            else
            {
                throw std::invalid_argument("invalid mask value: " + str);
            }
        }
        else if(str == "0")
        {
            tmp.type = mask_enum::no_mask;
        }
        else if(str == "1" || str == "t")
        {
            // Plain top-left causal mask.
            tmp.type  = mask_enum::mask_top_left;
            tmp.y     = seqlen_q;
            tmp.x     = 1;
            tmp.left  = -1;
            tmp.right = 0;
        }
        else if(str == "2" || str == "b")
        {
            // Plain bottom-right causal mask.
            tmp.type  = mask_enum::mask_bottom_right;
            tmp.y     = seqlen_q;
            tmp.x     = seqlen_k - seqlen_q + 1;
            tmp.left  = -1;
            tmp.right = 0;
        }
        else
        {
            throw std::invalid_argument("invalid mask value: " + str);
        }
        return tmp;
    }

    // Count the elements that survive masking by scanning each query row and
    // intersecting its [x_start, x_end) window with [0, seqlen_k). O(seqlen_q).
    std::size_t get_unmaskarea() const
    {
        if(type == mask_enum::no_mask)
            return static_cast<std::size_t>(seqlen_q) * seqlen_k;

        std::size_t area = 0;
        for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
        {
            ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
            ck_tile::index_t x_end   = std::min(i_y + x, seqlen_k);
            if(x_end > x_start)
            {
                area += (x_end - x_start);
            }
        }
        return area;
    }

    friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
    {
        mi.serialize(os);
        return os;
    }
};

View File

@@ -0,0 +1,74 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <ostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
// keep sync with BlockSageAttentionQuantScaleEnum
// keep sync with BlockSageAttentionQuantScaleEnum
enum class quant_scale_enum
{
    no_scale   = 0, // no quantization scale applied
    pertensor  = 1, // one scale for the whole tensor
    blockscale = 2, // one scale per token block
    perwarp    = 3, // one scale per warp-sized token group
    perthread  = 4, // one scale per thread-sized token group
};
struct quant_scale_info
{
quant_scale_enum type;
void serialize(std::ostream& os) const
{
if(type == quant_scale_enum::no_scale)
os << "n";
else if(type == quant_scale_enum::pertensor)
os << "pt";
else if(type == quant_scale_enum::blockscale)
os << "bs";
else if(type == quant_scale_enum::perwarp)
os << "pw";
else if(type == quant_scale_enum::perthread)
os << "pth";
}
static quant_scale_info decode(std::string str)
{
quant_scale_info info{quant_scale_enum::no_scale};
if(str == "n" || str == "0")
{
info.type = quant_scale_enum::no_scale;
}
else if(str == "pt" || str == "1")
{
info.type = quant_scale_enum::pertensor;
}
else if(str == "bs" || str == "2")
{
info.type = quant_scale_enum::blockscale;
}
else if(str == "pw" || str == "3")
{
info.type = quant_scale_enum::perwarp;
}
else if(str == "pth" || str == "4")
{
info.type = quant_scale_enum::perthread;
}
else
{
throw std::invalid_argument("invalid quant scale value: " + str);
}
return info;
}
friend std::ostream& operator<<(std::ostream& os, const quant_scale_info& qsi)
{
qsi.serialize(os);
return os;
}
};

View File

@@ -0,0 +1,384 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/sageattn.hpp"
#include "mask.hpp"
#include "quant.hpp"
#include <type_traits>
#include <utility>
#include <variant>
// SageAttention data type configs (must match codegen FWD_DTYPE_MAP + SageAttentionFwdTypeConfig)
// Empty tag types selecting a precision combination; each one maps to a
// SageAttentionFwdTypeConfig specialization below.
struct SageAttentionFwdFp16
{
};

struct SageAttentionFwdBf16
{
};

struct SageAttentionFwdFp8Bf16
{
};

struct SageAttentionFwdI8Fp8Bf16
{
};

struct SageAttentionFwdI4Fp8Bf16
{
};

// Primary template; only the explicit specializations below are defined.
template <typename DataType>
struct SageAttentionFwdTypeConfig;

// fp16/bf16 are not Sage product dtypes; bf16 is intentionally kept in tile_example_sageattn_fwd
// for pipeline validation with qscale=n (no quant).
template <>
struct SageAttentionFwdTypeConfig<SageAttentionFwdFp16>
{
    using QDataType           = ck_tile::half_t;
    using KDataType           = ck_tile::half_t;
    using VDataType           = ck_tile::half_t;
    using ScaleType           = float;           // scale type for quantized inputs
    using SaccDataType        = float;           // data type for first gemm accumulation
    using SMPLComputeDataType = float;           // data type for reduction, softmax
    using PDataType           = ck_tile::half_t; // data type for A matrix of second gemm
    using OaccDataType        = float;           // data type for second gemm accumulation
    using ODataType           = ck_tile::half_t;
};

template <>
struct SageAttentionFwdTypeConfig<SageAttentionFwdBf16>
{
    using QDataType           = ck_tile::bf16_t;
    using KDataType           = ck_tile::bf16_t;
    using VDataType           = ck_tile::bf16_t;
    using ScaleType           = float;           // scale type for quantized inputs
    using SaccDataType        = float;           // data type for first gemm accumulation
    using SMPLComputeDataType = float;           // data type for reduction, softmax
    using PDataType           = ck_tile::bf16_t; // data type for A matrix of second gemm
    using OaccDataType        = float;           // data type for second gemm accumulation
    using ODataType           = ck_tile::bf16_t;
};

// fp8 Q/K/V with bf16 output ("fp8bf16" precision).
template <>
struct SageAttentionFwdTypeConfig<SageAttentionFwdFp8Bf16>
{
    using QDataType           = ck_tile::fp8_t;
    using KDataType           = ck_tile::fp8_t;
    using VDataType           = ck_tile::fp8_t;
    using ScaleType           = float;          // scale type for quantized inputs
    using SaccDataType        = float;          // data type for first gemm accumulation
    using SMPLComputeDataType = float;          // data type for reduction, softmax
    using PDataType           = ck_tile::fp8_t; // data type for A matrix of second gemm
    using OaccDataType        = float;          // data type for second gemm accumulation
    using ODataType           = ck_tile::bf16_t;
};

// int8 Q/K, fp8 V, bf16 output ("i8fp8bf16" precision).
template <>
struct SageAttentionFwdTypeConfig<SageAttentionFwdI8Fp8Bf16>
{
    using QDataType           = ck_tile::int8_t;
    using KDataType           = ck_tile::int8_t;
    using VDataType           = ck_tile::fp8_t;
    using ScaleType           = float;          // scale type for Q and K
    using SaccDataType        = float;          // Keep as float for softmax computation
    using SMPLComputeDataType = float;          // data type for reduction, softmax
    using PDataType           = ck_tile::fp8_t; // P in FP8 for 2nd gemm
    using OaccDataType        = float;          // data type for second gemm accumulation
    using ODataType           = ck_tile::bf16_t;
};

// packed-int4 Q/K, fp8 V, bf16 output ("i4fp8bf16" precision).
template <>
struct SageAttentionFwdTypeConfig<SageAttentionFwdI4Fp8Bf16>
{
    using QDataType           = ck_tile::pk_int4_t;
    using KDataType           = ck_tile::pk_int4_t;
    using VDataType           = ck_tile::fp8_t;
    using ScaleType           = float; // scale type for Q and K
    using SaccDataType        = float; // data type for first gemm accumulation
    using SMPLComputeDataType = float; // data type for reduction, softmax
    using PDataType           = ck_tile::fp8_t; // P in FP8 for 2nd gemm
    using OaccDataType        = float; // data type for second gemm accumulation
    using ODataType           = ck_tile::bf16_t;
};
// Convenience aliases for the mask specializations used by the runner.
struct SageAttnMasks
{
    using NoMask      = ck_tile::GenericAttentionMask<false>;
    using GenericMask = ck_tile::GenericAttentionMask<true, true>;
    using CausalMask  = ck_tile::GenericAttentionMask<true, false>;
};
// runtime args, some will passed to karg, some will used to compute grids/blocks
struct sageattn_fwd_args
{
    const void* q_ptr;
    const void* k_ptr;
    const void* v_ptr;
    const void* q_descale_ptr; // per-granularity de-quantization scales for Q
    const void* k_descale_ptr; // per-granularity de-quantization scales for K
    const void* v_descale_ptr; // per-channel de-quantization scales for V
    void* o_ptr;

    // Usage notes for sequence length pointer parameters:
    //
    // Batch mode uses fixed seqlen_q/seqlen_k per batch; group mode packs
    // variable-length sequences and locates each batch via the seqstart_*
    // cumulative-offset arrays (see MakeKargs below, which consumes
    // seqstart_* only in the group-mode branch).
    //
    // With padding:
    //   Group mode:
    //     - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
    //       lengths. [array size: batch + 1]
    //     - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
    //       sequence. [array size: batch]
    //     - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
    //       sequence lengths. [array size: batch + 1]
    //     - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
    //       exclusive. Use one set, not both.
    //
    //   Batch mode:
    //     - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
    //       sequence lengths. [array size: batch + 1]
    //     - seqstart_* and seqlen_* pointers must be nullptr.
    //
    // Without padding:
    //   (Note: Physical length equals logical length)
    //
    //   Group mode:
    //     - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
    //       size: batch + 1]
    //     - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
    //
    //   Batch mode:
    //     - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
    //
    const void* seqstart_q_ptr =
        nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
    const void* seqstart_k_ptr =
        nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
    const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
                                        // [batch]. (Used in Group mode with padding)
    const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
                                        // [batch]. (Used in Group mode with padding)
    const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
                                           // array [batch + 1]. (Used with padding)
    const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
                                           // array [batch + 1]. (Used with padding)

    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_k;
    ck_tile::index_t batch;
    ck_tile::index_t max_seqlen_q;
    ck_tile::index_t hdim_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k; // nhead_q must be a multiple of nhead_k (GQA/MQA)
    float scale_s;            // softmax scale applied to S = Q*K^T

    // Element strides along the fastest-varying (seqlen) dimension.
    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_o;
    // Strides between consecutive heads.
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_lse;
    ck_tile::index_t nhead_stride_o;
    // Strides between consecutive batches.
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_lse;
    ck_tile::index_t batch_stride_o;

    // Sliding-window / causal mask description (see mask_info).
    ck_tile::index_t window_size_left;
    ck_tile::index_t window_size_right;
    ck_tile::index_t mask_type;
    ck_tile::index_t min_seqlen_q;

    // BLOCKSCALE parameters
    ck_tile::index_t nhead_stride_q_descale = 0;
    ck_tile::index_t nhead_stride_k_descale = 0;
    ck_tile::index_t nhead_stride_v_descale = 0;
    ck_tile::index_t batch_stride_q_descale = 0;
    ck_tile::index_t batch_stride_k_descale = 0;
    ck_tile::index_t batch_stride_v_descale = 0;
    ck_tile::index_t block_scale_size_q     = 0; // tokens covered by one Q scale block
    ck_tile::index_t block_scale_size_k     = 0; // tokens covered by one K scale block
    const void* block_scale_seqstart_q_ptr  = nullptr;
    const void* block_scale_seqstart_k_ptr  = nullptr;
};
// Build the kernel-argument struct (kargs) and the launch grid for one
// sageattn_fwd dispatch. Group mode and batch mode call different MakeKargs
// overloads: group mode passes the per-sequence seqstart/seqlen arrays,
// batch mode passes scalar seqlen_q/seqlen_k plus batch strides.
// Returns ck_tile::make_tuple(kargs, grids).
template <typename SageAttnKernel>
auto sageattn_fwd_create_kargs_and_grids(sageattn_fwd_args args)
{
    // GQA/MQA requirement: query heads must be an exact multiple of kv heads.
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = [&] {
        // create group mode kernel arguments
        if constexpr(SageAttnKernel::kIsGroupMode)
        {
            // NOTE(review): group mode forwards batch_stride_v_descale but no
            // q/k batch descale strides and no lse strides (args.*_stride_lse
            // are never passed in either branch) - confirm this matches the
            // group-mode MakeKargs signature intentionally.
            return SageAttnKernel::MakeKargs(args.q_ptr,
                                             args.k_ptr,
                                             args.v_ptr,
                                             args.q_descale_ptr,
                                             args.k_descale_ptr,
                                             args.v_descale_ptr,
                                             args.o_ptr,
                                             args.seqstart_q_ptr,
                                             args.seqstart_k_ptr,
                                             args.seqlen_q_ptr,
                                             args.seqlen_k_ptr,
                                             args.hdim_q,
                                             args.hdim_v,
                                             args.nhead_q,
                                             args.nhead_q / args.nhead_k, // queries per kv head
                                             args.scale_s,
                                             args.stride_q,
                                             args.stride_k,
                                             args.stride_v,
                                             args.stride_o,
                                             args.nhead_stride_q,
                                             args.nhead_stride_k,
                                             args.nhead_stride_v,
                                             args.nhead_stride_o,
                                             args.nhead_stride_q_descale,
                                             args.nhead_stride_k_descale,
                                             args.nhead_stride_v_descale,
                                             args.batch_stride_v_descale,
                                             args.block_scale_size_q,
                                             args.block_scale_size_k,
                                             args.block_scale_seqstart_q_ptr,
                                             args.block_scale_seqstart_k_ptr,
                                             args.window_size_left,
                                             args.window_size_right,
                                             args.mask_type,
                                             args.min_seqlen_q,
                                             args.cu_seqlen_q_ptr,
                                             args.cu_seqlen_k_ptr);
        }
        else
        { // create batch mode kernel arguments
            return SageAttnKernel::MakeKargs(args.q_ptr,
                                             args.k_ptr,
                                             args.v_ptr,
                                             args.q_descale_ptr,
                                             args.k_descale_ptr,
                                             args.v_descale_ptr,
                                             args.o_ptr,
                                             args.seqlen_q,
                                             args.seqlen_k,
                                             args.hdim_q,
                                             args.hdim_v,
                                             args.nhead_q,
                                             args.nhead_q / args.nhead_k, // queries per kv head
                                             args.scale_s,
                                             args.stride_q,
                                             args.stride_k,
                                             args.stride_v,
                                             args.stride_o,
                                             args.nhead_stride_q,
                                             args.nhead_stride_k,
                                             args.nhead_stride_v,
                                             args.nhead_stride_o,
                                             args.nhead_stride_q_descale,
                                             args.nhead_stride_k_descale,
                                             args.nhead_stride_v_descale,
                                             args.batch_stride_q,
                                             args.batch_stride_k,
                                             args.batch_stride_v,
                                             args.batch_stride_o,
                                             args.batch_stride_q_descale,
                                             args.batch_stride_k_descale,
                                             args.batch_stride_v_descale,
                                             args.block_scale_size_q,
                                             args.block_scale_size_k,
                                             args.window_size_left,
                                             args.window_size_right,
                                             args.mask_type,
                                             args.cu_seqlen_q_ptr,
                                             args.cu_seqlen_k_ptr);
        }
    }();
    if constexpr(SageAttnKernel::kIsGroupMode)
    {
        // Last GridSize flag: whether per-sequence logical k lengths were
        // supplied (seqlen_k_ptr) - presumably selects a varlen grid path;
        // confirm against SageAttnKernel::GridSize.
        dim3 grids = SageAttnKernel::GridSize(
            args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
        return ck_tile::make_tuple(kargs, grids);
    }
    else
    {
        dim3 grids = SageAttnKernel::GridSize(
            args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
        return ck_tile::make_tuple(kargs, grids);
    }
}
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
// Compile-time descriptor of one generated kernel instance; the codegen script
// emits a sageattn_fwd_ specialization per combination of these fields.
template <ck_tile::index_t HDim_,
          typename DataType_,
          bool kIsGroupMode_,
          ck_tile::index_t kM0_,
          ck_tile::index_t kN0_,
          ck_tile::index_t kK0_,
          ck_tile::index_t kN1_,
          ck_tile::index_t kK1_,
          ck_tile::index_t kK0BlockLength_,
          bool kIsVLayoutRowMajor_,
          ck_tile::BlockSageAttnPipelineEnum SageAttnPipelineEnum_,
          typename AttnMask_,
          ck_tile::BlockSageAttentionQuantScaleEnum QScaleEnum_,
          bool kPadS_,
          bool kPadSK_,
          bool kPadD_,
          bool kPadDv_,
          bool kSkipMinSeqlenQ_ = false>
struct sageattn_fwd_traits_
{
    // Head dimension of Q/K.
    static constexpr ck_tile::index_t HDim = HDim_;
    using DataType = ck_tile::remove_cvref_t<DataType_>;
    static constexpr bool kIsGroupMode = kIsGroupMode_;
    // kM0/kN0/kK0/kN1/kK1/kK0BlockLength: per-block tile extents - presumably
    // the M/N/K shapes of the two gemms in the attention pipeline; confirm
    // against the pipeline problem definition.
    static constexpr ck_tile::index_t kM0 = kM0_;
    static constexpr ck_tile::index_t kN0 = kN0_;
    static constexpr ck_tile::index_t kK0 = kK0_;
    static constexpr ck_tile::index_t kN1 = kN1_;
    static constexpr ck_tile::index_t kK1 = kK1_;
    static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
    // V tensor layout: row-major vs column-major.
    static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
    // Pipeline variant (e.g. QRKSVS vs QRKSVS_ASYNC).
    static constexpr auto SageAttnPipelineEnum = SageAttnPipelineEnum_;
    using AttnMask = ck_tile::remove_cvref_t<AttnMask_>;
    // Quantization scale granularity for the Q/K descale tensors.
    static constexpr auto QScaleEnum = QScaleEnum_;
    // Padding flags for seqlen_q, seqlen_k, hdim_q, hdim_v respectively.
    static constexpr bool kPadS = kPadS_;
    static constexpr bool kPadSK = kPadSK_;
    static constexpr bool kPadD = kPadD_;
    static constexpr bool kPadDv = kPadDv_;
    static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};
// Per-instance dispatch function; one explicit specialization per Traits_ tag
// is generated by the codegen script. Returns a float (presumably elapsed
// kernel time measured via the stream_config - confirm in generated bodies).
template <typename Traits_, typename Arch = void>
float sageattn_fwd_(const ck_tile::stream_config&, sageattn_fwd_args);
// This is the public API, will be generated by script
// Runtime trait selector: the dispatcher maps these fields onto one compiled
// sageattn_fwd_traits_ instance above.
struct sageattn_fwd_traits
{
    int hdim_q;                   // head dimension of Q/K
    int hdim_v;                   // head dimension of V / output
    std::string data_type;        // precision combo name, e.g. "fp8bf16", "i8fp8bf16"
    bool is_group_mode;           // group (variable seqlen) vs batch mode
    bool is_v_rowmajor;           // V layout: row-major vs column-major
    mask_enum mask_type;          // attention mask variant
    quant_scale_enum qscale_type; // quantization scale granularity
    bool skip_min_seqlen_q = false;
    // TODO: padding check is inside this api
};
float sageattn_fwd(sageattn_fwd_traits, sageattn_fwd_args, const ck_tile::stream_config&);

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,162 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
#
# SageAttention forward smoke tests - structure mirrors
# example/ck_tile/01_fmha/script/smoke_test_fwd.sh
#
# Run from the ComposableKernel *build* directory (after ninja), same as FMHA:
# cd build && ninja tile_example_sageattn_fwd
# bash ../example/ck_tile/49_sageattention/script/smoke_test_sageattn_fwd.sh
#
# Optional: VERBOSE=1 enables bash -x. CURR_FAILS_FILE / KNOWN_FAILS_FILE override fail logs.
set -euo pipefail
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
EXE_NAME=tile_example_sageattn_fwd
# Locate the test binary anywhere under the current (build) directory.
# '|| true' guards against 'set -o pipefail': 'head -n 1' may close the pipe
# early, making 'find' exit with SIGPIPE (141), which would otherwise abort
# the script here instead of reaching the friendly error message below.
EXE="$(find . -name "$EXE_NAME" -type f 2>/dev/null | head -n 1 || true)"
KNAME=1
# Detect the GPU arch (e.g. gfx942) unless the caller pre-set GPU_arch.
# [[:space:]] instead of \s: \s is a GNU grep extension, not POSIX ERE.
GPU_arch=${GPU_arch:-}
if [ -z "$GPU_arch" ]; then
    GPU_arch=$(rocminfo 2>/dev/null | grep -E 'Name:[[:space:]]+gfx' | head -n1 | awk '{print $2}' || echo "unknown")
fi
export CK_WARMUP=0
export CK_REPEAT=1
# Fresh per-run log of failing command lines; compared to known fails at the end.
CURR_FAILS_FILE=${CURR_FAILS_FILE:-"sageattn_fwd_fails_${GPU_arch}.txt"}
rm -f "$CURR_FAILS_FILE"
touch "$CURR_FAILS_FILE"
KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/sageattn_fwd_known_fails_${GPU_arch}.txt"}
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
if [ -z "${EXE:-}" ] || [ ! -x "$EXE" ]; then
    echo "ERROR: $EXE_NAME not found under cwd ($(pwd)). Build with: ninja $EXE_NAME" >&2
    exit 1
fi
# Run one test case; on non-zero exit record the full command line in the
# current-fails log instead of aborting the whole suite (script runs set -e).
run_exe() {
    if ! $EXE "$@"; then
        echo "$EXE_NAME $*" >>"$CURR_FAILS_FILE"
    fi
}
# Core FP8xBF16 cases aligned with FMHA smoke_test_fwd.sh (lines 80-87): batch/group shapes,
# masks, GQA, short seqlen, k-only pad. Sweeps blockscale (2) vs per-warp (3) and layouts.
run_fp8bf16_smoke() {
    local qscale
    local perm
    # qscale sweep: 2 = blockscale, 3 = per-warp (see note above).
    # perm toggles -iperm/-operm (input/output permuted layout) together.
    for qscale in 2 3; do
        for perm in 0 1; do
            # GQA (h=2, h_k=1), q (55) much shorter than k (256)
            run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
                -kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=2 -h_k=1 -d=128 -s=55 -s_k=256 \
                -mask=1
            # no mask, odd non-tile-aligned seqlens (100 / 51)
            run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
                -kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=3 -d=128 -s=100 -s_k=51 -mask=0
            # non-tile-aligned q (99) against tile-aligned k (256)
            run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
                -kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=1 -d=128 -d_v=128 -s=99 -s_k=256 \
                -mask=1
            # long q (1024) vs short k (256), GQA
            run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
                -kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1024 -s_k=256 \
                -mask=2
            # tiny q (3 tokens)
            run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
                -kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=1 -d=128 -d_v=128 -s=3 -s_k=99 -mask=2
            # windowed mask, top-left anchored (t:128,30)
            run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
                -kname=$KNAME $COMMON_ARGS -mode=0 -b=3 -h=2 -h_k=1 -d=128 -s=200 -s_k=520 \
                -mask=t:128,30
            # windowed mask, bottom-right anchored (b:4,35)
            run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
                -kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=1 -d=128 -s=99 -s_k=32 -mask=b:4,35
            # empty k (s_k=0) edge case
            run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
                -kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=33 -s_k=0 -mask=2
            # single-token q with k physical padding (s_kpad=32 > s_k=10)
            run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
                -kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 \
                -s_kpad=32 -mask=2
        done
    done
}
# Extra FP8: explicit causal string, xformer window, per-tensor / per-thread quant, V col-major.
run_fp8bf16_extras() {
    # causal expressed as an explicit window string (t:-1,0)
    run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
        $COMMON_ARGS -mode=0 -b=4 -h=8 -d=128 -s=1024 -s_k=1024 -mask=t:-1,0
    # V column-major layout (-vlayout=c) with permuted IO
    run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=1 -operm=1 -vlayout=c -kname=$KNAME \
        $COMMON_ARGS -mode=0 -b=2 -h=4 -d=128 -s=256 -s_k=256 -mask=t
    # xformer-style window mask (xt:64)
    run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
        $COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=256 -s_k=256 -mask=xt:64
    # per-tensor quantization scale (-qscale=1)
    run_exe -prec=fp8bf16 -init=3 -qscale=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
        $COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=128 -s_k=128 -mask=0
    # per-thread quantization scale (-qscale=4)
    run_exe -prec=fp8bf16 -init=3 -qscale=4 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
        $COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=64 -s_k=64 -mask=0
}
# Group mode + physical padding (same intent as FMHA run_padding_smoke_tests, Sage-only flags).
run_group_and_padding_smoke() {
    # group mode (-mode=1) with per-sequence varlen lists
    run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
        $COMMON_ARGS -mode=1 -b=3 -h=2 -h_k=1 -d=128 -s=50,60,40 -s_k=128,256,192 -mask=1
    # group + PERTHREAD: block_scale_seqstart_* must be allocated (same as bs/pw)
    run_exe -prec=fp8bf16 -init=3 -qscale=4 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
        $COMMON_ARGS -mode=1 -b=3 -h=2 -h_k=1 -d=128 -s=50,60,40 -s_k=128,256,192 -mask=1
    # group mode with explicit physical padding (s_qpad/s_kpad > logical s/s_k)
    run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
        $COMMON_ARGS -mode=1 -b=4 -h=8 -h_k=8 -d=128 -s=1024,768,512,256 -s_k=1024,768,512,256 \
        -mask=0 -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
    # batch mode with per-batch effective (logical) lengths shorter than buffers
    run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
        $COMMON_ARGS -mode=0 -b=4 -h=8 -d=128 -s=1024 -s_k=1024 -mask=0 \
        -q_eff_lens=960,512,384,256 -kv_eff_lens=960,512,384,256
}
# BF16 (no quant): pipeline sanity only; not a shipped Sage mode (see example --help prec).
run_bf16_pipeline_smoke() {
    run_exe -prec=bf16 -init=1 -qscale=n -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
        $COMMON_ARGS -mode=0 -b=2 -h=2 -d=128 -s=128 -s_k=128 -mask=1
    # GQA + permuted layouts + windowed mask (t:32,32)
    run_exe -prec=bf16 -init=1 -qscale=n -iperm=1 -operm=1 -vlayout=r -kname=$KNAME \
        $COMMON_ARGS -mode=0 -b=1 -h=4 -h_k=1 -d=128 -s=256 -s_k=128 -mask=t:32,32
}
# int8 / int4 x fp8xbf16 (hdim divisible by 8 for int4)
run_int_quant_smoke() {
    # int8 Q/K, fp8 V, bf16 O
    run_exe -prec=i8fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
        $COMMON_ARGS -mode=0 -b=2 -h=2 -d=128 -s=128 -s_k=128 -mask=1
    # int4 Q/K, fp8 V, bf16 O, causal (-mask=t)
    run_exe -prec=i4fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
        $COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=128 -s_k=128 -mask=t
}
# Optionally trace every command while the suites run.
if [ "${VERBOSE:-0}" = 1 ]; then
    set -x
fi
# Execute every suite; individual case failures are collected by run_exe.
run_fp8bf16_smoke
run_fp8bf16_extras
run_group_and_padding_smoke
run_bf16_pipeline_smoke
run_int_quant_smoke
set +x
# Compare collected failures against the per-arch known-fails list (if any);
# the script exits non-zero only when a failure is not already known.
n_new=0
n_known=0
if [ -f "$KNOWN_FAILS_FILE" ]; then
    echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):"
    while IFS= read -r failed_case; do
        if grep -Fxq "$failed_case" "$KNOWN_FAILS_FILE"; then
            echo "Known fail: $failed_case"
            n_known=$((n_known + 1))
        else
            echo "New fail: $failed_case"
            n_new=$((n_new + 1))
        fi
    done <"$CURR_FAILS_FILE"
else
    n_new=$(wc -l <"$CURR_FAILS_FILE")
    echo "No known fails file, all fails ($n_new) are new:"
    if [ "$n_new" -gt 0 ]; then
        cat "$CURR_FAILS_FILE"
    fi
fi
echo "New fails count: $n_new; Known fails count: $n_known"
exit $((n_new != 0))

View File

@@ -0,0 +1,254 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <optional>
#include <ostream>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "ck_tile/core/container/span.hpp"
// Whether sequence lengths are uniform across the batch (batch) or provided
// per sequence (group / variable-length).
enum class mode_enum
{
    batch = 0,
    group
};
// Stream the human-readable mode name ("batch" or "group").
inline std::ostream& operator<<(std::ostream& stream, mode_enum mode)
{
    if(mode == mode_enum::batch)
    {
        return stream << "batch";
    }
    return stream << "group";
}
// Print a vector as "[a, b, c]" using each element's own operator<<.
template <typename T>
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    const char* separator = "";
    for(const T& element : v)
    {
        os << separator << element;
        separator = ", ";
    }
    return os << "]";
}
// Convert per-sequence lengths into cumulative start offsets (prefix sums),
// i.e. {l0, l1, ...} -> {0, l0, l0+l1, ...}; result has seqlens.size()+1 entries.
inline std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
{
    std::vector<int32_t> seqstarts;
    seqstarts.reserve(seqlens.size() + 1);
    int32_t running_total = 0;
    seqstarts.push_back(running_total);
    for(int32_t seqlen : seqlens)
    {
        running_total += seqlen;
        seqstarts.push_back(running_total);
    }
    assert(seqstarts.size() == seqlens.size() + 1);
    return seqstarts;
}
// Generate `count` sequence lengths within [seqlen_min, seqlen_max]
// (non-positive bounds default to 1 and INT32_MAX respectively).
// Batch mode: every sequence gets the clamped average. Group mode (count > 1):
// lengths are randomly perturbed around the average while preserving the total
// sum - each iteration moves one token from one sequence to another, skipping
// moves that would leave the [seqlen_min, seqlen_max] range.
template <typename RandomEngine>
std::vector<int32_t> generate_seqlens(mode_enum mode,
                                      unsigned count,
                                      int32_t seqlen_avg,
                                      int32_t seqlen_min, // if not positive, default to 1
                                      int32_t seqlen_max, // if not positive, default to INT32_MAX
                                      RandomEngine& random_engine)
{
    assert(0 < count);
    seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
    seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
    assert(seqlen_min <= seqlen_max);
    // Clamp the average once and reuse it below. The previous code used the
    // raw seqlen_avg for the repeat count, which - being converted to
    // unsigned - wrapped to a huge value (near-infinite loop) whenever
    // seqlen_avg was negative.
    const int32_t seqlen_avg_clamped = std::clamp(seqlen_avg, seqlen_min, seqlen_max);
    std::vector<int32_t> seqlens(count, seqlen_avg_clamped);
    if(mode == mode_enum::group && 1 < count)
    {
        using size_type = std::vector<int32_t>::size_type;
        std::uniform_int_distribution<size_type> idx_dist(0, count - 1);
        auto next_idx = std::bind(idx_dist, std::ref(random_engine));
        // step in [1, count-1] guarantees to_increase != to_decrease.
        std::uniform_int_distribution<size_type> step_dist(1, count - 1);
        auto next_step = std::bind(step_dist, std::ref(random_engine));
        for(unsigned repeat = static_cast<unsigned>(seqlen_avg_clamped) * (count / 2);
            0 < repeat;
            --repeat)
        {
            const size_type to_decrease = next_idx();
            // make sure each element of seqlens stays in [seqlen_min, seqlen_max]
            if(seqlens[to_decrease] == seqlen_min)
            {
                continue;
            }
            const size_type to_increase = (to_decrease + next_step()) % count;
            if(seqlens[to_increase] >= seqlen_max)
            {
                continue;
            }
            --seqlens[to_decrease];
            ++seqlens[to_increase];
        }
    }
    return seqlens;
}
// Draw one integer uniformly from the closed interval [low, high].
template <typename Int = int, typename RandomEngine>
auto randint(Int low,
             Int high,
             RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>, Int>
{
    return std::uniform_int_distribution<Int>{low, high}(random_engine);
}
// Fill [first, last) with integers drawn uniformly from the closed interval
// [low, high], assigned in iteration order.
template <typename Int, typename ForwardIterator, typename RandomEngine>
auto randints(ForwardIterator first,
              ForwardIterator last,
              Int low,
              Int high,
              RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>>
{
    std::uniform_int_distribution<Int> dist(low, high);
    for(; first != last; ++first)
    {
        *first = dist(random_engine);
    }
}
/*
* generate missing values in *_val randomly when the number of values is smaller than batch
* example (assume batch=3)
* q_val=1,2,3 k_val=4,5,6 -> OK
* q_val=1,2,3 -> OK, k same as q
* q_val=1,2 -> OK, q will rand remaining 1 element, k same as q
* q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element
* q_val=1,2,3,4 -> OK, but the values beyond batch are ignored
*
* q_val=1,2 k_val=4,5,6 -> not OK, k must have the same number of splits as q
* q_val=1,2 k_val=4 -> not OK, k must have the same number of splits as q
*/
/*
 * Expand the (possibly shorter-than-batch) user-provided seqlen lists into
 * full per-batch vectors. Returns (s_q, s_k, s_qpad, s_kpad); a -1 entry
 * means "no physical padding for this sequence". See the usage examples in
 * the comment block directly above.
 */
template <typename RandomEngine>
std::tuple<std::vector<ck_tile::index_t>,
           std::vector<ck_tile::index_t>,
           std::vector<ck_tile::index_t>,
           std::vector<ck_tile::index_t>>
generate_missing_seqlens(mode_enum mode,
                         ck_tile::index_t batch,
                         const std::vector<ck_tile::index_t>& q_val,
                         const std::vector<ck_tile::index_t>& k_val,
                         const std::vector<ck_tile::index_t>& q_pad_val,
                         const std::vector<ck_tile::index_t>& k_pad_val,
                         ck_tile::index_t seqlen_k_min,
                         bool need_append_kvcache,
                         RandomEngine& random_engine)
{
    if(mode == mode_enum::batch)
    {
        // Batch mode: every sequence shares the first q/k value.
        // NOTE(review): assumes q_val and k_val are non-empty - confirm the
        // callers always provide at least one value each.
        ck_tile::index_t q = q_val[0];
        ck_tile::index_t k = k_val[0];
        auto s_q = std::vector<ck_tile::index_t>(batch, q);
        auto s_k = [&] {
            const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k); // k < 0 means "same as q"
            std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
            if(1 < batch && need_append_kvcache)
            {
                // to keep the original s_k value, we always use seqlen_k_max in first batch
                randints(std::next(seqlen_ks.begin()),
                         seqlen_ks.end(),
                         seqlen_k_min,
                         seqlen_k_max,
                         random_engine);
                return seqlen_ks;
            }
            return seqlen_ks;
        }();
        auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
        auto s_qpad = std::vector<ck_tile::index_t>(batch, -1);
        // s_k should be greater than or equal to seqlen_k_min if provided
        if(s_k.back() < seqlen_k_min)
        {
            std::ostringstream msg;
            msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
                << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
            throw std::runtime_error(msg.str());
        }
        return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
    }
    else
    {
        // Group mode: consume the provided per-sequence values first, then
        // randomize whatever remains up to `batch`.
        std::vector<ck_tile::index_t> s_q;
        std::vector<ck_tile::index_t> s_k;
        std::vector<ck_tile::index_t> s_kpad;
        std::vector<ck_tile::index_t> s_qpad;
        ck_tile::index_t idx = 0;
        for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
        {
            ck_tile::index_t q = q_val[idx];
            // k/pad lists may be shorter than q_val: reuse their last entry.
            ck_tile::index_t k =
                k_val[std::min(idx, static_cast<ck_tile::index_t>(k_val.size()) - 1)];
            ck_tile::index_t kp =
                k_pad_val.empty()
                    ? -1
                    : k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
            ck_tile::index_t qp =
                q_pad_val.empty()
                    ? -1
                    : q_pad_val[std::min(idx, static_cast<ck_tile::index_t>(q_pad_val.size()) - 1)];
            s_q.push_back(q);
            s_k.push_back(k < 0 ? q : k);
            s_kpad.push_back(kp);
            s_qpad.push_back(qp);
            // s_k should be greater than or equal to seqlen_k_min
            if(s_k.back() < seqlen_k_min)
            {
                std::ostringstream msg;
                msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
                    << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
                throw std::runtime_error(msg.str());
            }
        }
        if(idx < batch)
        {
            // Fewer provided values than batch: fill the remainder with random
            // lengths seeded from the last provided value (used as both average
            // and max; assumes at least one value was consumed above).
            auto rem_q =
                generate_seqlens(mode, batch - idx, s_q.back(), 1, s_q.back(), random_engine);
            auto rem_k = generate_seqlens(
                mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), random_engine);
            s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
            s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
            s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
            s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back());
        }
        return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
    }
}
// Fill [first, last) with consecutive integers starting at `value`, then
// randomly permute them with the supplied engine.
template <typename RandomAccessIterator, typename Int, typename RandomEngine>
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
                                                       RandomAccessIterator last,
                                                       Int value,
                                                       RandomEngine& random_engine)
{
    Int next_value = value;
    for(auto it = first; it != last; ++it)
    {
        *it = next_value++;
    }
    std::shuffle(first, last, random_engine);
}

View File

@@ -31,6 +31,7 @@ add_subdirectory(38_block_scale_gemm)
add_subdirectory(40_streamk_gemm)
add_subdirectory(41_batched_contraction)
add_subdirectory(42_mx_gemm)
add_subdirectory(49_sageattention)
add_subdirectory(50_sparse_attn)
add_subdirectory(51_tile_distr_enc_reg_map)
if(BUILD_CK_TILE_CSHUFFLE_LDS_BENCHMARKS)