[CK_TILE] Add SageAttention v2 forward kernel with multi-granularity quantization (#6574)

## Summary

Add a CK_TILE forward kernel implementing [SageAttention
v2](https://arxiv.org/abs/2411.10958) — an attention algorithm that
applies multi-granularity quantization to Q/K/V before computing
attention, trading a small accuracy loss for higher throughput on
low-precision hardware.

### Quantization design

| Tensor | Supported data types | Scale granularity options |
|--------|----------------------|----------------------------|
| Q | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (32 tokens), per-thread (4 tokens) |
| K | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (64 tokens), per-thread (16 tokens) |
| V | fp8 | per-channel (always) |
| O | bf16 | — |

Three precision combinations are supported: `fp8/bf16` (QKV fp8, O bf16),
`i8/fp8/bf16` (QK int8, V fp8, O bf16), and `i4/fp8/bf16` (QK int4, V fp8,
O bf16). A plain `bf16` path (no quantization, `-qscale=n`) is also kept,
but only for pipeline validation; it is not part of the quantized
SageAttention product path.
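
To make the scale granularities concrete, the sketch below shows symmetric absmax
int8 quantization with one scale per group of tokens, which is what the per-warp
setting means on the `i8/fp8/bf16` path (Q groups of 32 tokens, K groups of 64
tokens, per the table above). This is a minimal host-side NumPy illustration, not
the quantization code shipped in this PR:

```python
import numpy as np

def quantize_per_group_int8(x: np.ndarray, group: int):
    """Symmetric absmax int8 quantization with one fp32 scale per `group` tokens.

    x: [seqlen, hdim] slice for a single head. Returns (int8 values, fp32 scales).
    """
    seqlen, hdim = x.shape
    pad = (-seqlen) % group                       # pad seqlen up to a multiple of the group size
    xg = np.pad(x, ((0, pad), (0, 0))).reshape(-1, group, hdim)
    scales = np.abs(xg).max(axis=(1, 2)) / 127.0  # one absmax scale per token group
    scales = np.maximum(scales, 1e-8)             # avoid division by zero for all-zero groups
    q = np.clip(np.rint(xg / scales[:, None, None]), -127, 127).astype(np.int8)
    return q.reshape(-1, hdim)[:seqlen], scales.astype(np.float32)

rng = np.random.default_rng(0)
q_i8, q_scales = quantize_per_group_int8(rng.standard_normal((1024, 128)), group=32)  # Q: per-warp
k_i8, k_scales = quantize_per_group_int8(rng.standard_normal((1024, 128)), group=64)  # K: per-warp
# q_scales holds one entry per 32 query tokens, k_scales one per 64 key tokens;
# scales like these are what the kernel applies when dequantizing the Q*K^T accumulator.
```

The per-tensor, per-block, and per-thread settings differ only in the group size;
V always stays in fp8 with one scale per channel.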

### Architecture support

- **gfx9** (CDNA2/3, e.g. gfx90a, gfx942) — fp8/int8/int4 tiles use
N-per-block = 64
- **gfx950** (CDNA4) — fp8/int8/int4 tiles use the wider N-per-block = 128
variant

### Implementation

- Two pipeline variants: `QRKSVS` (synchronous) and `QRKSVS_ASYNC`
(async copy)
- Masking support: no mask, causal (top-left / bottom-right), and
generic windowed
- Batch and group (variable-length) modes
- Head dimension: d=128, d_v=128
- Python codegen under `example/ck_tile/49_sageattention/codegen/`
generates one kernel instance per target/dtype/tile/pipeline combination
(see the invocation sketch after this list)
- Smoke tests included via `tile_example_sageattn_fwd`
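
The codegen is normally driven by CMake (see the `CMakeLists.txt` diff below),
which invokes `generate.py --api fwd --optdim 128` to list and then emit the
kernel blobs. As a purely hypothetical exploration aid (not part of the build),
the same entry point can be called from Python, assuming the working directory
is `example/ck_tile/49_sageattention/` so the `codegen` package resolves:

```python
from codegen.ops.sageattn_fwd import get_fwd_blobs

api_pool, kernels = get_fwd_blobs(
    targets=["gfx942"],      # resolved to the gfx9 kernel factory
    kernel_filter="",        # empty string = no fnmatch filter
    receipt=0,               # only receipt currently used by the fwd codegen
    optdim_list=[128],       # restrict to head dim 128 (the only supported hdim)
    mask_impl="simplified",  # default mask implementation
)
for k in kernels:
    print(k.filename)        # one generated .cpp per dtype/mode/tile/pipeline combination
```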

### Test commands

```bash
# fp8 QKV
./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 \
    -kname=1 -prec=fp8bf16 -qscale=3 -init=3

# int8 QK, fp8 V
./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 \
    -kname=1 -prec=i8fp8bf16 -qscale=3 -init=3
```

`-qscale` values: 1 = per-tensor, 2 = per-block, 3 = per-warp, 4 = per-thread (0 = no scale).

Author: ltqin
Date: 2026-05-01 02:32:23 +08:00 (committed by GitHub)
Commit: c1bf3f6972 (parent: 2a479f7411)
30 changed files with 7809 additions and 0 deletions

View File

@@ -0,0 +1,95 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# Currently only gfx9 arch is supported
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
if(NOT INST_TARGETS)
message(WARNING "Skipping SageAttention compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
return()
endif()
# ====================================================================
# SageAttention codegen - only FWD API, minimal instances
# ====================================================================
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
${CMAKE_CURRENT_LIST_DIR}/generate.py
${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
)
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")
list(JOIN INST_TARGETS , SAGEATTN_TARGETS_ARG)
# Only generate FWD API, only supported head dimension (128)
# Note: Only d=128, d_v=128 has kernel tile definitions in sageattn_fwd.py
set(SAGEATTN_FWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--targets ${SAGEATTN_TARGETS_ARG}
--api fwd
--optdim 128
)
# Generate list of kernels to build
execute_process(
COMMAND ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS}
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "SageAttention FAILED to generate kernel list via Python.")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt SAGEATTN_FWD_GEN_BLOBS)
# Generate the kernel instance files
add_custom_command(
OUTPUT ${SAGEATTN_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
COMMENT "Generate SageAttention FWD kernels"
VERBATIM
)
# Build the kernel instances library
add_library(tile_sageattn_fwd_instances OBJECT EXCLUDE_FROM_ALL ${SAGEATTN_FWD_GEN_BLOBS})
target_include_directories(tile_sageattn_fwd_instances PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# Compile options for kernel instances
set(SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS)
list(APPEND SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS -Wno-undefined-func-template)
list(APPEND SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero)
if(CK_USE_OCP_FP8)
list(APPEND SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
target_compile_options(tile_sageattn_fwd_instances PRIVATE ${SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS})
set_property(TARGET tile_sageattn_fwd_instances PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
set_property(TARGET tile_sageattn_fwd_instances PROPERTY POSITION_INDEPENDENT_CODE ON)
# ====================================================================
# SageAttention FWD Example
# ====================================================================
set(EXAMPLE_SAGEATTN_FWD "tile_example_sageattn_fwd")
message(DEBUG "adding example ${EXAMPLE_SAGEATTN_FWD}")
add_executable(${EXAMPLE_SAGEATTN_FWD} EXCLUDE_FROM_ALL example_sageattn_fwd.cpp)
target_include_directories(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# Link with our own minimal instances library (INDEPENDENT from FMHA!)
target_link_libraries(${EXAMPLE_SAGEATTN_FWD} tile_sageattn_fwd_instances)
set(SAGEATTN_FWD_COMPILE_OPTIONS)
list(APPEND SAGEATTN_FWD_COMPILE_OPTIONS -Wno-undefined-func-template)
list(APPEND SAGEATTN_FWD_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND SAGEATTN_FWD_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero)
if(CK_USE_OCP_FP8)
list(APPEND SAGEATTN_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
target_compile_options(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${SAGEATTN_FWD_COMPILE_OPTIONS})
set_property(TARGET ${EXAMPLE_SAGEATTN_FWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

View File

@@ -0,0 +1,2 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,42 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
from dataclasses import dataclass, field
from typing import Any, List, Callable
@dataclass(frozen=True)
class ArchTrait:
name: str
preprocessor_check: str = field(default=None)
device_name_check: str = field(default=None)
tag: str = field(default=None)
filename_suffix: str = field(default=None)
def __post_init__(self):
if self.preprocessor_check is None:
object.__setattr__(self, "preprocessor_check", f"defined(__{self.name}__)")
if self.device_name_check is None:
object.__setattr__(
self,
"device_name_check",
f'device_name.compare(0, {len(self.name)}, "{self.name}") == 0',
)
if self.tag is None:
object.__setattr__(self, "tag", f"ck_tile::{self.name}_t")
if self.filename_suffix is None:
object.__setattr__(self, "filename_suffix", f"_{self.name}")
def get_factories_for_targets(
targets: List[str], get_factory: Callable[[str], Any]
) -> List[Any]:
factories = dict()
for target in targets:
factory = get_factory(target)
factories[factory.arch.name] = factory
# Place more specific architectures first
factories = sorted(
list(factories.values()), key=lambda f: len(f.arch.name), reverse=True
)
return factories

View File

@@ -0,0 +1,4 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
GEN_DIR = "" # in Cmake, have to generate files in same folder

View File

@@ -0,0 +1,103 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
FWD_DTYPE_MAP = {
"fp16": "SageAttentionFwdFp16",
"bf16": "SageAttentionFwdBf16",
"fp8bf16": "SageAttentionFwdFp8Bf16",
"i8fp8bf16": "SageAttentionFwdI8Fp8Bf16",
"i4fp8bf16": "SageAttentionFwdI4Fp8Bf16",
}
_MASK_SIMPLIFIED_MAP = {
"s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
}
_MASK_MAP = {
"no": "SageAttnMasks::NoMask",
"causal": "SageAttnMasks::CausalMask",
"generic": "SageAttnMasks::GenericMask",
}
def get_mask_map(mask_impl: str):
if mask_impl == "generic":
return _MASK_MAP
elif mask_impl == "simplified":
return _MASK_SIMPLIFIED_MAP
else:
assert False
return None
def get_mask_impl(mask: str) -> str:
return "simplified" if mask.startswith("s_") else "generic"
def get_mask_cpp_type(mask: str) -> str:
return get_mask_map(get_mask_impl(mask))[mask]
_MASK_CHECK_MAP = {
"no": "t.mask_type == mask_enum::no_mask",
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic": "t.mask_type == mask_enum::window_generic",
}
_MASK_SIMPLIFIED_CHECK_MAP = {
"s_no": "t.mask_type == mask_enum::no_mask",
"s_mask": "t.mask_type != mask_enum::no_mask",
}
def get_mask_check_map(mask: str):
if mask == "generic":
return _MASK_CHECK_MAP
elif mask == "simplified":
return _MASK_SIMPLIFIED_CHECK_MAP
else:
assert False
return None
def get_mask_cpp_check_expr(mask: str) -> str:
return get_mask_check_map(get_mask_impl(mask))[mask]
QSCALE_MAP = {
"no": "ck_tile::BlockSageAttentionQuantScaleEnum::NO_SCALE",
"pertensor": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTENSOR",
"blockscale": "ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE",
"perwarp": "ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP",
"perthread": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD",
}
QSCALE_CHECK_MAP = {
"no": "quant_scale_enum::no_scale",
"pertensor": "quant_scale_enum::pertensor",
"blockscale": "quant_scale_enum::blockscale",
"perwarp": "quant_scale_enum::perwarp",
"perthread": "quant_scale_enum::perthread",
}
MODE_MAP = {"batch": "false", "group": "true"}
LAYOUT_MAP = {"row": "true", "col": "false"}
PIPELINE_MAP = {
"qr": "ck_tile::BlockSageAttentionPipelineQRKSVS",
"qr_async": "ck_tile::BlockSageAttentionPipelineQRKSVSAsync",
}
PIPELINE_ENUM_MAP = {
"qr": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS",
"qr_async": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS_ASYNC",
}
BOOL_MAP = {
"t": "true",
"f": "false",
True: "true",
False: "false",
}

View File

@@ -0,0 +1,2 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,992 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import copy
import fnmatch
import itertools
import os
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, ClassVar, Iterable, List, Optional, Tuple
from codegen.arch import ArchTrait, get_factories_for_targets
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
LAYOUT_MAP,
BOOL_MAP,
PIPELINE_MAP,
PIPELINE_ENUM_MAP,
MODE_MAP,
FWD_DTYPE_MAP,
get_mask_map,
get_mask_cpp_type,
get_mask_cpp_check_expr,
QSCALE_CHECK_MAP,
QSCALE_MAP,
)
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
DTYPE_BITS = {
"fp16": 16,
"bf16": 16,
"fp8bf16": 8,
"i8fp8bf16": 8,
"i4fp8bf16": 4,
}
K0_MAX_SUBMAX_MAP = {
32: 32,
48: 48,
64: 64,
80: 96,
96: 128,
128: 128,
192: 192,
256: 256,
}
SAGEATTN_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp"
#include "sageattn_fwd.hpp"
"""
SAGEATTN_FWD_KERNEL_BODY_TEMPLATE = """
#include <iostream>
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
using sageattn_dtype = {F_dtype};
using sageattn_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using sageattn_shape = ck_tile::TileSageAttnShape<sageattn_block_tile,
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using sageattn_traits = ck_tile::TileSageAttnTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_qscale},
{F_occupancy},
{F_skip}>;
using sageattn_variant = ck_tile::ComposedAttention<false * ck_tile::LOGITS_SOFT_CAP, true>;
using sageattn_mask_type = {F_mask};
using sageattn_pipeline_problem = ck_tile::BlockSageAttnPipelineProblem<
typename SageAttentionFwdTypeConfig<sageattn_dtype>::QDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::KDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::VDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::SaccDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::SMPLComputeDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::PDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::OaccDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::ODataType,
sageattn_shape,
{F_mode},
sageattn_variant,
sageattn_mask_type,
sageattn_traits>;
using sageattn_pipeline = {F_pipeline}<
sageattn_pipeline_problem>;
using sageattn_epilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename SageAttentionFwdTypeConfig<sageattn_dtype>::OaccDataType,
typename SageAttentionFwdTypeConfig<sageattn_dtype>::ODataType,
{F_spad}, {F_dvpad}>>;
using sageattn_kernel = {F_kernel}<sageattn_pipeline, sageattn_epilogue>;
using trait = sageattn_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, sageattn_mask_type, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
template<>
float sageattn_fwd_<trait, {F_arch.tag}>(const ck_tile::stream_config& s, sageattn_fwd_args a)
{{
using k_ = sageattn_kernel;
if(s.log_level_ > 0)
std::cout << ", {F_kname}" << std::flush;
auto [kargs, grids] = {F_kargs_creator}<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
}}
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""
SAGEATTN_FWD_API_FILENAME = "sageattn_fwd_api.cpp"
SAGEATTN_FWD_API_HEADER = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include <cstdio>
#include <hip/hip_runtime.h>
#include "sageattn_fwd.hpp"
namespace {
bool get_num_cus(unsigned& num_cus) {
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {
fprintf(stderr, "failed to get device");
return false;
}
hipDeviceProp_t props{};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {
fprintf(stderr, "failed to get device properties");
return false;
}
num_cus = props.multiProcessorCount;
return true;
}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}
} // namespace
"""
SAGEATTN_FWD_API_FUNC_TEMPLATE = """
namespace {{
float {F_func_name}([[maybe_unused]] sageattn_fwd_traits t, [[maybe_unused]] sageattn_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if(!get_num_cus(num_cus)) {{
return r;
}}
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
{F_dispatch}
return r;
}}
}} // namespace
"""
SAGEATTN_FWD_API_FOOTER_TEMPLATE = """
// Public API entry point - unified for SageAttention
float sageattn_fwd(sageattn_fwd_traits traits, sageattn_fwd_args args, const ck_tile::stream_config& config) {
return sageattn_fwd_impl(traits, args, config);
}
"""
SAGEATTN_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{
{F_dtype_case}
}}
"""
SAGEATTN_FWD_API_PER_DTYPE = """{F_if}(t.data_type.compare(\"{F_dtype}\") == 0) {{
{F_hdim_case}
}}
"""
SAGEATTN_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
SAGEATTN_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = sageattn_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
return sageattn_fwd_<trait_, {F_arch.tag}>(s, a);
}}
"""
@dataclass
class CppConstraint:
bool_expr: str = None
def __str__(self):
if self.bool_expr is None:
return "true"
else:
return f"{self.bool_expr}"
def __and__(self, other):
return CppConstraint(f"({str(self)}) && ({str(other)})")
@dataclass
class SageAttnFwdApiTrait:
arch: ArchTrait
pipeline_tag: str
# sync with sageattn_fwd_traits<>, to generate fallback calls
hdim: str
dtype: str # data type
mode: str # value from MODE_MAP
bm0: int # tile size along q seqlen (block size)
bn0: int # tile size along qk seqlen
bk0: int # tile size along qk gemm unroll
bn1: int # tile size along v head_dim
bk1: int # tile size along kv gemm unroll
bk0max: int
vlayout: str
mask: str
qscale: str #
spad: str
skpad: str
dpad: str
dvpad: str
skip: str
constraint: CppConstraint
@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.mask}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
)
@property
def scheck(self) -> str:
if self.mode == "group":
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
if self.pipeline_tag == "qr_async":
return "true"
elif self.pipeline_tag in ["qr", "qs"]:
if self.spad == "t":
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"a.seqlen_q % {self.bm0} == 0"
else:
assert False
def seqtune(self, max_bm0: int) -> str:
if self.bm0 == max_bm0 or self.bm0 == 64:
return "true/*fall back to largest tile*/"
else:
return f"a.seqlen_q <= {self.bm0}"
@property
def skcheck(self) -> str:
if self.mode == "group":
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
if self.pipeline_tag == "qr_async":
if self.skpad == "t":
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
else:
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
elif self.pipeline_tag in ["qr", "qs"]:
if self.skpad == "t":
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
else:
assert False
@property
def dcheck(self) -> str:
if self.pipeline_tag == "qr_async":
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == "t":
return f"a.hdim_q % {vec} == 0"
else:
assert False
elif self.pipeline_tag in ["qr", "qs"]:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == "t":
return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"a.hdim_q % {bk0submax} == 0"
else:
assert False
@property
def dvcheck(self) -> str:
if self.pipeline_tag == "qr_async":
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == "t":
return f"a.hdim_v % {vec} == 0"
else:
assert False
elif self.pipeline_tag in ["qr", "qs"]:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == "t":
return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
# F_dvpad="f": Causal mask requires hdim_v <= kN1 (num_tile_n1 == 1 for tile reversal)
# Non-causal requires hdim_v % kN1 == 0 (epilogue writes full tiles)
if self.mask == "causal":
return f"(a.hdim_v % {bk0submax} == 0) && (a.hdim_v <= {self.bn1})"
else:
return (
f"(a.hdim_v % {bk0submax} == 0) && (a.hdim_v % {self.bn1} == 0)"
)
else:
assert False
@dataclass
class SageAttnFwdPipeline:
tag: str
F_vlayout: str # row/col
F_spad: str # true/false
F_skpad: str #
F_dpad: str #
F_dvpad: str #
F_qscale: str # no/pertensor/blockscale/perwarp/perthread
F_mask: str # value from MASK_MAP
F_skip: str # true/false
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
@property
def name(self) -> str:
def pad_name() -> str:
n = ""
if self.F_spad == "t":
n += "s"
if self.F_skpad == "t":
n += "sk"
if self.F_dpad == "t":
n += "d"
if self.F_dvpad == "t":
n += "dv"
if n != "":
n = "p" + n
return n
pn = pad_name()
n = f"{self.tag}_v{self.F_vlayout[0]}"
if pn != "":
n += f"_{pn}"
else:
n += "_npad"
if self.F_mask[0:2] == "s_":
if self.F_mask == "s_mask":
n += "_mask"
else:
n += "_nmask"
else:
if self.F_mask != "no":
n += f"_m{self.F_mask[0]}"
else:
n += "_nmask"
if self.F_skip == "t":
n += "_skip"
else:
n += "_nskip"
if self.F_qscale != "no":
n += f"_{self.F_qscale}"
else:
n += "_nqscale"
return n
class SageAttnFwdApiPool:
def __init__(self):
self.pool = OrderedDict()
def register_traits(self, trait: SageAttnFwdApiTrait) -> None:
hdim = trait.hdim, trait.bn1
ts = (
self.pool.setdefault(trait.arch, OrderedDict())
.setdefault(trait.dtype, OrderedDict())
.setdefault(hdim, [])
)
check_duplicates_and_paddings(ts, trait)
ts.append(copy.copy(trait))
def get_num_traits(
self, filter_fn: Optional[Callable[[SageAttnFwdApiTrait], bool]] = None
) -> int:
if filter_fn is None:
def accept_all(trait: SageAttnFwdApiTrait) -> bool:
return True
filter_fn = accept_all
return sum(
sum(1 for trait in pool_by_hdim if filter_fn(trait))
for pool_by_arch in self.pool.values()
for pool_by_dtype in pool_by_arch.values()
for pool_by_hdim in pool_by_dtype.values()
)
def render(
self,
func_name,
filter_fn: Optional[Callable[[SageAttnFwdApiTrait], bool]] = None,
) -> str:
if filter_fn is None:
def accept_all(trait: SageAttnFwdApiTrait) -> bool:
return True
filter_fn = accept_all
def has_traits(node) -> bool:
"""Recursively traverse nested OrderedDicts and lists to determine if any SageAttnFwdApiTrait satisfies filter_fn()."""
if isinstance(node, list):
return any(filter_fn(elem) for elem in node)
elif isinstance(node, OrderedDict):
return any(has_traits(val) for val in node.values())
return False
per_arch = str()
for i_arch, (arch, pool_by_arch) in enumerate(
item for item in self.pool.items() if has_traits(item[1])
):
per_dtypes = str()
for i_dtype, (dtype, pool_by_dtype) in enumerate(
item for item in pool_by_arch.items() if has_traits(item[1])
):
per_hdim_case = str()
for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate(
item for item in pool_by_dtype.items() if has_traits(item[1])
):
max_bm0 = max(
(t.bm0 for t in pool_by_hdim if filter_fn(t)), default=0
)
inners = str()
for i_trait, trait in enumerate(
[trait for trait in pool_by_hdim if filter_fn(trait)]
):
inners += SAGEATTN_FWD_API_INNER_DISPATCH.format(
F_if=if_(i_trait),
F_arch=arch,
F_mode=MODE_MAP[trait.mode],
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
F_mask=get_mask_cpp_type(trait.mask),
F_mask_check=get_mask_cpp_check_expr(trait.mask),
F_skip=BOOL_MAP[trait.skip],
F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
F_qscale=QSCALE_MAP[trait.qscale],
F_scheck=trait.scheck,
F_seqtune=trait.seqtune(max_bm0),
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0,
F_bn0=trait.bn0,
F_bk0=trait.bk0,
F_bn1=trait.bn1,
F_bk1=trait.bk1,
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
)
per_hdim_case += SAGEATTN_FWD_API_PER_HDIM_CASE.format(
F_if=if_(i_hdim),
F_hdim=hdim,
F_hdim_v=hdim_v,
F_inner_dispatch=indent(inners),
)
per_dtypes += SAGEATTN_FWD_API_PER_DTYPE.format(
F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
)
per_arch += SAGEATTN_FWD_API_PER_ARCH.format(
F_if=if_(i_arch),
F_arch=arch,
F_dtype_case=indent(per_dtypes),
)
return SAGEATTN_FWD_API_FUNC_TEMPLATE.format(
F_func_name=func_name, F_dispatch=indent(per_arch)
)
@dataclass
class SageAttnFwdTileSize:
F_bm0: int # tile size along q seqlen (block size)
F_bn0: int # tile size along k seqlen
F_bk0: int # tile size along qk gemm unroll
F_bn1: int # tile size along v head_dim
F_bk1: int # tile size along kv gemm unroll
    F_bk0max: int # total length of K0, used for pipelines that need to load Q at once (or repeatedly load Q as a whole tile)
F_rm0: int # number of warps for gemm0 along q seqlen
F_rn0: int # number of warps for gemm0 along k seqlen
F_rk0: int # number of warps for gemm0 along head dim q (not used)
F_rm1: int # number of warps for gemm1 along q seqlen
F_rn1: int # number of warps for gemm1 along head dim v
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
F_wm0: int # gemm0 warp size along m
F_wn0: int # gemm0 warp size along n
F_wk0: int # gemm0 warp size along k
F_wm1: int # gemm1 warp size along m
F_wn1: int # gemm1 warp size along n
F_wk1: int # gemm1 warp size along k
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
@property
def name(self) -> str:
return (
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
)
@dataclass
class SageAttnFwdKernel:
F_arch: ArchTrait
F_hdim: int # hdim
F_dtype: str # data type
F_mode: str # value from MODE_MAP
F_tile: SageAttnFwdTileSize
F_pipeline: SageAttnFwdPipeline
_KERNEL_HEADER: ClassVar[str] = SAGEATTN_FWD_KERNEL_HEADER
_KERNEL_BODY_TEMPLATE: ClassVar[str] = SAGEATTN_FWD_KERNEL_BODY_TEMPLATE
@classmethod
def _get_cpp_kernel_class_name(cls, pipeline_tag):
return "ck_tile::SageAttnFwdKernel"
@classmethod
def _get_cpp_kargs_creator_func_name(cls, pipeline_tag):
return "sageattn_fwd_create_kargs_and_grids"
def render(self) -> str:
return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format(
F_kname=self.name,
F_arch=self.F_arch,
F_hdim=self.F_hdim,
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
F_bm0=self.F_tile.F_bm0,
F_bn0=self.F_tile.F_bn0,
F_bk0=self.F_tile.F_bk0,
F_bn1=self.F_tile.F_bn1,
F_bk1=self.F_tile.F_bk1,
F_bk0max=self.F_tile.F_bk0max,
F_rm0=self.F_tile.F_rm0,
F_rn0=self.F_tile.F_rn0,
F_rk0=self.F_tile.F_rk0,
F_rm1=self.F_tile.F_rm1,
F_rn1=self.F_tile.F_rn1,
F_rk1=self.F_tile.F_rk1,
F_wm0=self.F_tile.F_wm0,
F_wn0=self.F_tile.F_wn0,
F_wk0=self.F_tile.F_wk0,
F_wm1=self.F_tile.F_wm1,
F_wn1=self.F_tile.F_wn1,
F_wk1=self.F_tile.F_wk1,
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
F_occupancy=self.F_tile.F_occupancy,
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_cpp_type(self.F_pipeline.F_mask),
F_mode=MODE_MAP[self.F_mode],
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag),
F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag),
)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return (
f"sageattn_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
)
@property
def filename(self) -> str:
return f"{self.name}{self.F_arch.filename_suffix}.cpp"
def api_trait(self) -> SageAttnFwdApiTrait:
return SageAttnFwdApiTrait(
arch=self.F_arch,
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
qscale=self.F_pipeline.F_qscale,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
)
@dataclass
class ProblemContext:
dtype: str
mode: str
hdim: int
hdim_v: int
@dataclass
class KernelContext:
tile: SageAttnFwdTileSize
pipeline: SageAttnFwdPipeline
mask_impl: str
CompatibilityRule = Callable[[ProblemContext, KernelContext], bool]
def is_compatible(
problem_ctx: ProblemContext,
kernel_ctx: KernelContext,
rules: Iterable[CompatibilityRule],
) -> bool:
return all(rule(problem_ctx, kernel_ctx) for rule in rules)
def create_kernel(
arch: ArchTrait, problem_ctx: ProblemContext, kernel_ctx: KernelContext
) -> SageAttnFwdKernel:
return SageAttnFwdKernel(
F_arch=arch,
F_dtype=problem_ctx.dtype,
F_mode=problem_ctx.mode,
F_hdim=problem_ctx.hdim,
F_tile=kernel_ctx.tile,
F_pipeline=kernel_ctx.pipeline,
)
class CompatibilityRuleFactory:
@staticmethod
def get_rules() -> List[CompatibilityRule]:
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
if problem_ctx.mode == "group":
if (
kernel_ctx.pipeline.F_spad != "t"
or kernel_ctx.pipeline.F_skpad != "t"
):
return False
return True
return [check_mode]
class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
_AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"})
@classmethod
def get_rules(cls) -> List[CompatibilityRule]:
rules = CompatibilityRuleFactory.get_rules()
return rules
class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9):
pass
class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
arch = ArchTrait(
"gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)"
)
# Note: fp16 is not supported by SageAttention (only bf16 + fp8/int quantization)
_DT_BF16 = ("bf16",)
_DT_FP8BF16 = ("fp8bf16",)
_DT_I8FP8BF16 = ("i8fp8bf16",)
_DT_I4FP8BF16 = ("i4fp8bf16",)
@classmethod
def supported_dtypes(cls) -> Tuple[str]:
return cls._DT_BF16 + cls._DT_FP8BF16 + cls._DT_I8FP8BF16 + cls._DT_I4FP8BF16
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
@classmethod
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
if dtype in cls._DT_BF16:
return {
(128, 128) : [SageAttnFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
} # fmt: skip
elif (
dtype in cls._DT_FP8BF16
or dtype in cls._DT_I8FP8BF16
or dtype in cls._DT_I4FP8BF16
):
# gfx9 (non-gfx950): only F_bn0=64; F_bn0=128 variant is gfx950-only (see Gfx950).
return {
(128, 128): [
SageAttnFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), # fmt: skip
],
}
else:
raise ValueError(f"unsupported dtype={dtype}")
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
@classmethod
def get_pipelines(
cls, dtype, hdim, hdim_v, receipt, mask_impl
) -> List[SageAttnFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let "t" padding to appear later!!
# TODO: how to design this more generic?
pipelines = []
if dtype in cls._DT_BF16:
qscale = "no"
skip = "f" # skip: only false
for mask, vlayout in itertools.product(
get_mask_map(mask_impl).keys(),
["row", "col"],
):
if hdim == 256 and hdim_v == 256:
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "f", "f", "f", "f", qscale, mask, skip)) # fmt: skip
# the below two is used for hdim vectorize load
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip
else:
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "f", "t", "t", qscale, mask, skip)) # fmt: skip
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip
elif (
dtype in cls._DT_FP8BF16
or dtype in cls._DT_I8FP8BF16
or dtype in cls._DT_I4FP8BF16
):
# no need lse kernels
skip = "f" # skip: only false
for mask, qscale, vlayout in itertools.product(
get_mask_map(mask_impl).keys(),
["no", "pertensor", "blockscale", "perwarp", "perthread"],
["row", "col"], # Support both row and col major layouts
):
if dtype in cls._DT_I4FP8BF16:
# int4 only uses sync pipeline (qr), pad_d="f" because packed types
# require alignment >= PackedSize which conflicts with kPadHeadDimQ=true
# forcing alignment to 1. Safe since hdim always matches tile size.
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "f", "f", "f", qscale, mask, skip)) # fmt: skip
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
elif hdim == 64:
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "f", "f", "f", qscale, mask, skip)) # fmt: skip
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
else:
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "f", "t", "t", qscale, mask, skip)) # fmt: skip
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip
# Packed types (int4) cannot use head-dim padding: the tile_window infrastructure
# forces alignment=1 when padding is enabled, but packed types need alignment >= PackedSize.
if dtype in cls._DT_I4FP8BF16:
for p in pipelines:
assert p.F_dpad == "f", (
f"int4 dtype '{dtype}' requires pad_d=false, got '{p.F_dpad}'"
)
assert p.F_dvpad == "f", (
f"int4 dtype '{dtype}' requires pad_dv=false, got '{p.F_dvpad}'"
)
return pipelines
class KernelComponentFactoryGfx950(
KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx950
):
arch = ArchTrait("gfx950")
@classmethod
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
if (
dtype in cls._DT_FP8BF16
or dtype in cls._DT_I8FP8BF16
or dtype in cls._DT_I4FP8BF16
):
# gfx950 fp8-family: F_bn0=128 tile only (gfx9 uses F_bn0=64 only).
return {
(128, 128): [
SageAttnFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), # fmt: skip
],
}
return super().get_hdim_tile_size_dict(dtype)
class CustomFactory(KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9):
@classmethod
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
if dtype in cls._DT_BF16:
if (128, 128) in result.keys():
result[(128, 128)].insert(0, SageAttnFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
return result
def get_factory(target: str):
if os.environ.get("CK_TILE_SAGEATTN_FWD_CUSTOM_FACTORY", "0") == "1":
return CustomFactory
# Place more specific architectures first
if target.startswith("gfx950"):
return KernelComponentFactoryGfx950
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
raise Exception(f"Unsupported device target {target}")
@dataclass(frozen=True)
class Product:
name: str
rule: CompatibilityRule
def __call__(self, problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
return self.rule(problem_ctx, kernel_ctx)
def get_product(receipt: int) -> Product:
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
# bf16 (no quantization) should not have qscale
if problem_ctx.dtype == "bf16":
if kernel_ctx.pipeline.F_qscale != "no":
return False
return True
return Product(name="All tiles", rule=fit)
def get_fwd_blobs(
targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[SageAttnFwdApiPool, List[SageAttnFwdKernel]]:
gen = list()
api_pool = SageAttnFwdApiPool()
factories = get_factories_for_targets(targets, get_factory)
for factory, dtype in ((f, t) for f in factories for t in f.supported_dtypes()):
d = factory.get_hdim_tile_size_dict(dtype)
# for hdim_str, mode, mask, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for ((hdim, hdim_v), tiles), mode in itertools.product(
d.items(), MODE_MAP.keys()
):
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
for tile, next_tile in zip(tiles, tiles[1:]):
assert next_tile.F_bm0 >= tile.F_bm0, (
"Tiles must be ordered by increasing bm0"
)
for tile, pipeline in itertools.product(
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
):
problem_ctx = ProblemContext(
dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v
)
kernel_ctx = KernelContext(
tile=tile, pipeline=pipeline, mask_impl=mask_impl
)
rules = factory.get_rules()
product = get_product(receipt)
if not is_compatible(problem_ctx, kernel_ctx, [*rules, product]):
continue
k = create_kernel(factory.arch, problem_ctx, kernel_ctx)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
def write_single_fwd_kernel(kernel: SageAttnFwdKernel, autogen_dir: Path) -> None:
update_file(autogen_dir / kernel.filename, kernel.render())
def write_fwd_api(
api_pool: SageAttnFwdApiPool,
autogen_dir: Path,
) -> None:
content = "".join(
[
SAGEATTN_FWD_API_HEADER,
api_pool.render("sageattn_fwd_impl"),
SAGEATTN_FWD_API_FOOTER_TEMPLATE,
]
)
update_file(autogen_dir / SAGEATTN_FWD_API_FILENAME, content)
def write_blobs(
targets: List[str],
output_dir: Path,
kernel_filter: str,
receipt,
optdim_list,
mask_impl,
) -> None:
api_pool, kernels = get_fwd_blobs(
targets, kernel_filter, receipt, optdim_list, mask_impl
)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
def list_blobs(
targets: List[str],
file_path: Path,
kernel_filter: str,
receipt,
optdim_list,
mask_impl,
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_blobs(
targets, kernel_filter, receipt, optdim_list, mask_impl
)
for kernel in kernels:
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
f.write(
(file_path.parent / GEN_DIR / SAGEATTN_FWD_API_FILENAME).as_posix() + "\n"
)

View File

@@ -0,0 +1,70 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import dataclasses
import os.path as path
import textwrap
def update_file(file_path, content):
"""Update the file at file_path with the given content if it differs from the existing content.
It avoids unnecessary touching of the file which triggers rebuilds
"""
existing_content = ""
if path.exists(file_path):
with open(file_path, "r") as file:
existing_content = file.read()
if existing_content == content:
return
with open(file_path, "w") as file:
file.write(content)
def indent(code: str, indent: str = " ") -> str:
return textwrap.indent(code, indent)
def if_(i: int) -> str:
return "if" if i == 0 else "else if"
def check_duplicates_and_paddings(traits, trait):
"""Check
* if the traits list does not contain a trait with the same parameters;
    * if paddings are consistent: the previous kernel can be incorrectly called before the new one,
for example, f, _t_, f, t cannot be before f, _f_, f, t.
"""
fields = [f.name for f in dataclasses.fields(trait)]
pad_fields = [f for f in fields if "pad" in f]
non_pad_fields = [f for f in fields if "pad" not in f]
for prev_trait in traits:
if any(getattr(trait, f) != getattr(prev_trait, f) for f in non_pad_fields):
continue
if all(getattr(trait, f) == getattr(prev_trait, f) for f in pad_fields):
raise Exception(f"Duplicate found {trait}")
# Check if the previous kernel can be incorrectly used before the current one
# for example, f, _t_, f, t cannot be before f, _f_, f, t
is_prev_more_restrictive = False
is_curr_more_restrictive = False
for f in pad_fields:
prev_pad = getattr(prev_trait, f)
pad = getattr(trait, f)
if isinstance(prev_pad, str):
prev_pad = 1000000 if prev_pad == "f" else 1
pad = 1000000 if pad == "f" else 1
elif isinstance(prev_pad, int):
prev_pad = 1000000 if prev_pad == 0 else prev_pad
pad = 1000000 if pad == 0 else pad
else:
assert False
if prev_pad < pad:
is_prev_more_restrictive = True
elif prev_pad > pad:
is_curr_more_restrictive = True
if is_prev_more_restrictive and not is_curr_more_restrictive:
raise Exception(
f"Kernel will never be used because paddings are not ordered correctly:\n{prev_trait} supersedes\n{trait}"
)

View File

@@ -0,0 +1,202 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "sageattn_fwd.hpp"
#include "sageattn_fwd_runner.hpp"
#include <string>
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
.insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q")
.insert("h_k",
"-1",
"num of head, for k/v, -1 means equal to h\n"
"if not equal to h, then this is GQA/MQA case")
.insert("s",
"3328",
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_k",
"-1",
"seqlen_k (including new key/value), -1 means equal to s\n"
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_qpad",
"-1",
"seqlen_q stride between 2 batches (group-mode optional).\n"
"Provide positive strides per-batch to simulate physical padding on Q.")
.insert("s_kpad",
"-1",
"seqlen_k stride between 2 batches, currently used in group-mode only\n"
"for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
"along seqlen, instead of packed, same as xformer kv_padding,\n"
"must be greater than or equal to s_k")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
.insert("qscale",
"n",
"n or 0, no scale\n"
"pt or 1, per-tensor scale\n"
"bs or 2, block scale (Q:128, KV:128)\n"
"pw or 3, per-warp scale (Q:32, KV:64)\n"
"pth or 4, per-thread scale (Q:4, KV:16)\n")
.insert("iperm",
"1",
"permute input\n"
"if true, will be b*h*s*d, else b*s*h*d")
.insert("operm", "1", "permute output")
.insert("prec",
"fp8bf16",
"Primary: fp8bf16, i8fp8bf16, i4fp8bf16. Also bf16 (keep): pipeline validation "
"with qscale=n (no quant); not the quantized Sage product path.")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
"'xt:window_size', xformer style masking from top-left, window_size negative is "
"causal, positive is swa\n"
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
"causal, positive is swa\n"
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
"now)")
.insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
.insert("kname", "0", "if set to 1 will print kernel name")
.insert("init",
"uf",
"init method:\n ui or 0 - uniform random int\n ni - normalized random int"
"\n uf or 1 - uniform random float\n nf - normalized random float"
"\n tf or 2 - trig float"
"\n tf or 3 - uniform random float, min max is the max of the type\n")
.insert("seed",
"11939",
"random seed used for initializing input tensors. 0 for "
"non-deterministic seed")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "sageattn_fwd.json", "json file name to dump results")
.insert("q_eff_lens",
"",
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.")
.insert("kv_eff_lens",
"",
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataTypeConfig>
auto run(const ck_tile::ArgParser& arg_parser)
{
int do_validation = arg_parser.get_int("v");
mode_enum mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
auto seqlen_qs = arg_parser.get_int_vec("s");
auto seqlen_ks = arg_parser.get_int_vec("s_k");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
auto q_eff_lens_per_batch = arg_parser.get_int_vec("q_eff_lens");
auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens");
bool i_perm = arg_parser.get_bool("iperm");
bool o_perm = arg_parser.get_bool("operm");
float scale_s = arg_parser.get_float("scale_s");
bool is_v_rowmajor = arg_parser.get_str("vlayout") == "r";
std::string qscale_str = arg_parser.get_str("qscale");
std::string mask_str = arg_parser.get_str("mask");
std::string init_method = arg_parser.get_str("init");
uint32_t seed = arg_parser.get_uint32("seed");
ck_tile::stream_config stream_config{nullptr,
true,
/* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0),
arg_parser.get_int("warmup"),
arg_parser.get_int("repeat"),
arg_parser.get_str("timer") == std::string("gpu")};
auto json = arg_parser.get_int("json") == 1
? std::optional<std::string>{arg_parser.get_str("jsonfile")}
: std::nullopt;
return sageattn_fwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
seqlen_qs,
seqlen_ks,
hdim_q,
hdim_v,
seqlen_qpads,
seqlen_kpads,
q_eff_lens_per_batch,
kv_eff_lens_per_batch,
i_perm,
o_perm,
scale_s,
is_v_rowmajor,
mask_str,
qscale_str,
init_method,
seed,
do_validation,
stream_config,
json);
}
int main(int argc, char* argv[])
{
try
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "bf16")
{
return run<SageAttentionFwdBf16>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "fp8bf16")
{
return run<SageAttentionFwdFp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "i8fp8bf16")
{
return run<SageAttentionFwdI8Fp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "i4fp8bf16")
{
return run<SageAttentionFwdI4Fp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
}
std::cerr << "Unsupported precision: " << data_type << std::endl;
return -1;
}
catch(const std::invalid_argument& e)
{
std::cerr << "Invalid argument: " << e.what() << std::endl;
return -1;
}
catch(const std::exception& e)
{
std::cerr << "Error: " << e.what() << std::endl;
return -2;
}
}

View File

@@ -0,0 +1,173 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import argparse
from enum import IntEnum
from pathlib import Path
import pkgutil
from typing import List, Optional
import codegen.ops
from codegen.cmake_config import GEN_DIR
class HandlerId(IntEnum):
LIST_BLOBS = 0
WRITE_BLOBS = 1
# inspect all modules under 'codegen.ops' and register API handlers
ops = []
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
full_module_name = "%s.%s" % (codegen.ops.__name__, module_name)
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
# Strip "sageattn_" so module sageattn_fwd registers as CLI key "fwd".
unwanted_prefix = "sageattn_"
handlers = dict(
[
(
(
op.__name__[len(unwanted_prefix) :]
if op.__name__.startswith(unwanted_prefix)
else op.__name__
),
(op.list_blobs, op.write_blobs),
)
for op in ops
]
)
assert 0 < len(handlers)
def write_blobs(
targets: List[str],
output_dir: Optional[str],
api_list: List[str],
filters_list: List[str],
optdim_list: List[int],
receipt,
mask_impl,
) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
output_dir = Path(output_dir) / GEN_DIR
output_dir.mkdir(parents=True, exist_ok=True)
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.WRITE_BLOBS]
handler(targets, output_dir, kernel_filter, receipt, optdim_list, mask_impl)
# list all the files that will be generated
def list_blobs(
targets: List[str],
output_file: Optional[str],
api_list: List[str],
filters_list: List[str],
optdim_list: List[int],
receipt,
mask_impl,
) -> None:
assert output_file is not None
file_path = Path(output_file)
# create an empty file / drop its contents if it exists
open(file_path, "w").close()
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.LIST_BLOBS]
handler(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="generate",
description="Generate SageAttention CK_tile kernel/API blobs.",
)
parser.add_argument(
"--targets",
default="gfx9,gfx950",
required=False,
help="list of GPU targets, separated by comma.",
)
parser.add_argument(
"-a",
"--api",
default="fwd",
required=False,
help="Codegen API key(s), comma-separated (e.g. fwd -> module codegen.ops.sageattn_fwd).",
)
parser.add_argument(
"-o",
"--output_dir",
required=False,
help="write all the blobs into a directory",
)
parser.add_argument(
"-l", "--list_blobs", required=False, help="list all the kernels to a file"
)
# TODO: if using filter, must apply same value to output_dir and list_blobs
parser.add_argument(
"-f",
"--filter",
default="",
required=False,
help="filter out kernels that need to generate, using fnmatch module",
)
parser.add_argument(
"-m",
"--mask",
default="simplified",
required=False,
help="mask implementation, simplified/generic",
)
parser.add_argument(
"-r",
"--receipt",
default=0,
required=False,
help="Codegen receipt index. SageAttention forward currently uses receipt 0 only; "
"the value is passed through to ops (see get_product in sageattn_fwd.py).",
)
parser.add_argument(
"--optdim",
default="-1",
required=False,
help="only optimize the hdim in the list. separated by comma. -1 is the default choice. "
"e.g. --optdim=32,64,128,256",
)
args = parser.parse_args()
targets = args.targets.split(",")
api_list = args.api.split(",")
filter_list = args.filter.split(",")
filter_list.extend([""] * (len(api_list) - len(filter_list)))
optdim_list = [int(hdim) for hdim in args.optdim.split(",")]
if args.list_blobs is not None:
list_blobs(
targets,
args.list_blobs,
api_list,
filter_list,
optdim_list,
int(args.receipt),
mask_impl=args.mask,
)
else:
write_blobs(
targets,
args.output_dir,
api_list,
filter_list,
optdim_list,
int(args.receipt),
mask_impl=args.mask,
)

View File

@@ -0,0 +1,169 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <ostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
// keep this in sync with ck_tile::GenericAttentionMaskEnum
enum class mask_enum
{
no_mask = 0,
mask_top_left,
mask_bottom_right,
window_generic,
};
struct mask_info
{
mask_enum type;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t y, x;
ck_tile::index_t left, right; // FA style SWA left/right
void serialize(std::ostream& os) const
{
if(type == mask_enum::no_mask)
os << "n";
else if(type == mask_enum::mask_top_left)
os << "t(" << left << ":" << right << ")";
else if(type == mask_enum::mask_bottom_right)
os << "b(" << left << ":" << right << ")";
else
{
os << "g(" << y << ":" << x << ")";
}
}
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
{
ck_tile::index_t x_total = seqlen_k;
ck_tile::index_t y_total = seqlen_q;
mask_info tmp;
tmp.seqlen_q = seqlen_q;
tmp.seqlen_k = seqlen_k;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string t = str.substr(0, found_0);
std::string v = str.substr(found_0 + 1);
if(t == "xt" || t == "xb")
{
// xformer style sliding window attn from top-left
ck_tile::index_t window_size = std::stoi(v);
ck_tile::index_t left_size = -1;
ck_tile::index_t right_size = 0;
if(window_size > 0)
{
left_size = window_size / 2;
right_size = window_size - 1 - left_size;
}
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, 0, y_total, x_total, t == "xt");
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = left_size;
tmp.right = right_size;
}
else if(t == "t" || t == "b" || t == "g")
{
auto found_1 = v.find(",");
if(found_1 == std::string::npos)
{
throw std::invalid_argument("invalid mask value: " + str);
}
tmp.type = mask_enum::window_generic;
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
// TODO: some validation
if(t == "t")
{
tmp.type = mask_enum::mask_top_left;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, 0, y_total, x_total, true);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "b")
{
tmp.type = mask_enum::mask_bottom_right;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, 0, y_total, x_total, false);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "g")
{
tmp.type = mask_enum::window_generic;
tmp.y = v0;
tmp.x = v1;
tmp.left = v0; // TODO: don't use this?
tmp.right = v1;
}
}
else
{
throw std::invalid_argument("invalid mask value: " + str);
}
}
else if(str == "0")
{
tmp.type = mask_enum::no_mask;
}
else if(str == "1" || str == "t")
{
tmp.type = mask_enum::mask_top_left;
tmp.y = seqlen_q;
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
}
else if(str == "2" || str == "b")
{
tmp.type = mask_enum::mask_bottom_right;
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
}
else
{
throw std::invalid_argument("invalid mask value: " + str);
}
return tmp;
}
std::size_t get_unmaskarea() const
{
if(type == mask_enum::no_mask)
return static_cast<std::size_t>(seqlen_q) * seqlen_k;
std::size_t area = 0;
for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
{
ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
ck_tile::index_t x_end = std::min(i_y + x, seqlen_k);
if(x_end > x_start)
{
area += (x_end - x_start);
}
}
return area;
}
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
{
mi.serialize(os);
return os;
}
};

View File

@@ -0,0 +1,74 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <ostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
// keep sync with BlockSageAttentionQuantScaleEnum
enum class quant_scale_enum
{
no_scale = 0,
pertensor = 1,
blockscale = 2,
perwarp = 3,
perthread = 4,
};
struct quant_scale_info
{
quant_scale_enum type;
void serialize(std::ostream& os) const
{
if(type == quant_scale_enum::no_scale)
os << "n";
else if(type == quant_scale_enum::pertensor)
os << "pt";
else if(type == quant_scale_enum::blockscale)
os << "bs";
else if(type == quant_scale_enum::perwarp)
os << "pw";
else if(type == quant_scale_enum::perthread)
os << "pth";
}
static quant_scale_info decode(std::string str)
{
quant_scale_info info{quant_scale_enum::no_scale};
if(str == "n" || str == "0")
{
info.type = quant_scale_enum::no_scale;
}
else if(str == "pt" || str == "1")
{
info.type = quant_scale_enum::pertensor;
}
else if(str == "bs" || str == "2")
{
info.type = quant_scale_enum::blockscale;
}
else if(str == "pw" || str == "3")
{
info.type = quant_scale_enum::perwarp;
}
else if(str == "pth" || str == "4")
{
info.type = quant_scale_enum::perthread;
}
else
{
throw std::invalid_argument("invalid quant scale value: " + str);
}
return info;
}
friend std::ostream& operator<<(std::ostream& os, const quant_scale_info& qsi)
{
qsi.serialize(os);
return os;
}
};
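
A tiny usage sketch (editorial, not in the diff) showing that both the short string and the numeric form of `-qscale` decode to the same enum value:

```cpp
// "pw" and "3" both select per-warp scale granularity; serialize() prints "pw".
quant_scale_info a = quant_scale_info::decode("pw");
quant_scale_info b = quant_scale_info::decode("3");
// a.type == b.type == quant_scale_enum::perwarp
std::cout << a << " " << b << std::endl; // prints "pw pw"
```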

View File

@@ -0,0 +1,384 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/sageattn.hpp"
#include "mask.hpp"
#include "quant.hpp"
#include <type_traits>
#include <utility>
#include <variant>
// SageAttention data type configs (must match codegen FWD_DTYPE_MAP + SageAttentionFwdTypeConfig)
struct SageAttentionFwdFp16
{
};
struct SageAttentionFwdBf16
{
};
struct SageAttentionFwdFp8Bf16
{
};
struct SageAttentionFwdI8Fp8Bf16
{
};
struct SageAttentionFwdI4Fp8Bf16
{
};
template <typename DataType>
struct SageAttentionFwdTypeConfig;
// fp16/bf16 are not Sage product dtypes; bf16 is intentionally kept in tile_example_sageattn_fwd
// for pipeline validation with qscale=n (no quant).
template <>
struct SageAttentionFwdTypeConfig<SageAttentionFwdFp16>
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using ScaleType = float; // scale type for quantized inputs
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::half_t;
};
template <>
struct SageAttentionFwdTypeConfig<SageAttentionFwdBf16>
{
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using ScaleType = float; // scale type for quantized inputs
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf16_t;
};
template <>
struct SageAttentionFwdTypeConfig<SageAttentionFwdFp8Bf16>
{
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t;
using ScaleType = float; // scale type for quantized inputs
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf16_t;
};
template <>
struct SageAttentionFwdTypeConfig<SageAttentionFwdI8Fp8Bf16>
{
using QDataType = ck_tile::int8_t;
using KDataType = ck_tile::int8_t;
using VDataType = ck_tile::fp8_t;
using ScaleType = float; // scale type for Q and K
using SaccDataType = float; // Keep as float for softmax computation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::fp8_t; // P in FP8 for 2nd gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf16_t;
};
template <>
struct SageAttentionFwdTypeConfig<SageAttentionFwdI4Fp8Bf16>
{
using QDataType = ck_tile::pk_int4_t;
using KDataType = ck_tile::pk_int4_t;
using VDataType = ck_tile::fp8_t;
using ScaleType = float;
using SaccDataType = float;
using SMPLComputeDataType = float;
using PDataType = ck_tile::fp8_t;
using OaccDataType = float;
using ODataType = ck_tile::bf16_t;
};
struct SageAttnMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
// runtime args; some are passed to the kernel args, others are used to compute grids/blocks
struct sageattn_fwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* q_descale_ptr;
const void* k_descale_ptr;
const void* v_descale_ptr;
void* o_ptr;
// Usage notes for sequence length pointer parameters:
//
// Group mode: variable-length sequences packed along the sequence dimension, described by the
// per-batch seqstart/seqlen arrays below. Batch mode: every batch entry shares the same
// seqlen_q/seqlen_k and is addressed through the batch_stride_* fields.
// (A filled-in sketch follows this struct definition.)
//
// With padding:
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
// lengths. [array size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
// sequence. [array size: batch]
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
// exclusive. Use one set, not both.
//
// Batch mode:
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqstart_* and seqlen_* pointers must be nullptr.
//
// Without padding:
// (Note: Physical length equals logical length)
//
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
// size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
//
// Batch mode:
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
//
const void* seqstart_q_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqstart_k_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale_s;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
ck_tile::index_t min_seqlen_q;
// BLOCKSCALE parameters
ck_tile::index_t nhead_stride_q_descale = 0;
ck_tile::index_t nhead_stride_k_descale = 0;
ck_tile::index_t nhead_stride_v_descale = 0;
ck_tile::index_t batch_stride_q_descale = 0;
ck_tile::index_t batch_stride_k_descale = 0;
ck_tile::index_t batch_stride_v_descale = 0;
ck_tile::index_t block_scale_size_q = 0;
ck_tile::index_t block_scale_size_k = 0;
const void* block_scale_seqstart_q_ptr = nullptr;
const void* block_scale_seqstart_k_ptr = nullptr;
};
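
To make the sequence-length conventions above concrete, a hedged sketch of filling the group-mode-with-padding fields (the `d_*` buffer names are hypothetical placeholders, not part of this PR):

```cpp
// Group mode with padding; d_* stand in for real device buffers.
const void* d_seqstart_q = nullptr; // cumulative physical starts, [batch + 1]
const void* d_seqstart_k = nullptr; // cumulative physical starts, [batch + 1]
const void* d_seqlen_q   = nullptr; // logical (unpadded) lengths, [batch]
const void* d_seqlen_k   = nullptr; // logical (unpadded) lengths, [batch]

sageattn_fwd_args args{};
args.seqstart_q_ptr = d_seqstart_q;
args.seqstart_k_ptr = d_seqstart_k;
args.seqlen_q_ptr   = d_seqlen_q;
args.seqlen_k_ptr   = d_seqlen_k;
// cu_seqlen_q_ptr / cu_seqlen_k_ptr may be used instead of seqlen_*_ptr, never both.
// Batch mode without padding leaves every sequence-length pointer nullptr and relies on
// seqlen_q / seqlen_k plus the batch_stride_* fields.
```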
template <typename SageAttnKernel>
auto sageattn_fwd_create_kargs_and_grids(sageattn_fwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(SageAttnKernel::kIsGroupMode)
{
return SageAttnKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.q_descale_ptr,
args.k_descale_ptr,
args.v_descale_ptr,
args.o_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_o,
args.nhead_stride_q_descale,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale,
args.batch_stride_v_descale,
args.block_scale_size_q,
args.block_scale_size_k,
args.block_scale_seqstart_q_ptr,
args.block_scale_seqstart_k_ptr,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.min_seqlen_q,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
}
else
{ // create batch mode kernel arguments
return SageAttnKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.q_descale_ptr,
args.k_descale_ptr,
args.v_descale_ptr,
args.o_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_o,
args.nhead_stride_q_descale,
args.nhead_stride_k_descale,
args.nhead_stride_v_descale,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_o,
args.batch_stride_q_descale,
args.batch_stride_k_descale,
args.batch_stride_v_descale,
args.block_scale_size_q,
args.block_scale_size_k,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
}
}();
if constexpr(SageAttnKernel::kIsGroupMode)
{
dim3 grids = SageAttnKernel::GridSize(
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
return ck_tile::make_tuple(kargs, grids);
}
else
{
dim3 grids = SageAttnKernel::GridSize(
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
return ck_tile::make_tuple(kargs, grids);
}
}
// this traits struct is used to pattern-match the internal kernel implementation, not to
// instantiate kernels
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockSageAttnPipelineEnum SageAttnPipelineEnum_,
typename AttnMask_,
ck_tile::BlockSageAttentionQuantScaleEnum QScaleEnum_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_,
bool kSkipMinSeqlenQ_ = false>
struct sageattn_fwd_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto SageAttnPipelineEnum = SageAttnPipelineEnum_;
using AttnMask = ck_tile::remove_cvref_t<AttnMask_>;
static constexpr auto QScaleEnum = QScaleEnum_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};
template <typename Traits_, typename Arch = void>
float sageattn_fwd_(const ck_tile::stream_config&, sageattn_fwd_args);
// This is the public API; its dispatch implementation is generated by the codegen script
struct sageattn_fwd_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
mask_enum mask_type;
quant_scale_enum qscale_type;
bool skip_min_seqlen_q = false;
// TODO: padding check is inside this api
};
float sageattn_fwd(sageattn_fwd_traits, sageattn_fwd_args, const ck_tile::stream_config&);
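
A hedged end-to-end call sketch; the field values are illustrative and the `data_type` string is assumed to mirror the `-prec` names used by the example driver:

```cpp
sageattn_fwd_traits traits{};
traits.hdim_q        = 128;
traits.hdim_v        = 128;
traits.data_type     = "fp8bf16";                 // assumed to match -prec
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.mask_type     = mask_enum::mask_top_left;  // causal
traits.qscale_type   = quant_scale_enum::perwarp; // -qscale=3

ck_tile::stream_config stream{};                   // default stream, no timing
float avg_ms = sageattn_fwd(traits, args, stream); // args filled as sketched after
                                                   // the sageattn_fwd_args struct above
```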

File diff suppressed because it is too large

View File

@@ -0,0 +1,162 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
#
# SageAttention forward smoke tests - structure mirrors
# example/ck_tile/01_fmha/script/smoke_test_fwd.sh
#
# Run from the ComposableKernel *build* directory (after ninja), same as FMHA:
# cd build && ninja tile_example_sageattn_fwd
# bash ../example/ck_tile/49_sageattention/script/smoke_test_sageattn_fwd.sh
#
# Optional: VERBOSE=1 enables bash -x. CURR_FAILS_FILE / KNOWN_FAILS_FILE override fail logs.
set -euo pipefail
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
EXE_NAME=tile_example_sageattn_fwd
EXE="$(find . -name "$EXE_NAME" -type f 2>/dev/null | head -n 1)"
KNAME=1
GPU_arch=${GPU_arch:-}
if [ -z "$GPU_arch" ]; then
GPU_arch=$(rocminfo 2>/dev/null | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}' || echo "unknown")
fi
export CK_WARMUP=0
export CK_REPEAT=1
CURR_FAILS_FILE=${CURR_FAILS_FILE:-"sageattn_fwd_fails_${GPU_arch}.txt"}
rm -f "$CURR_FAILS_FILE"
touch "$CURR_FAILS_FILE"
KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/sageattn_fwd_known_fails_${GPU_arch}.txt"}
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
if [ -z "${EXE:-}" ] || [ ! -x "$EXE" ]; then
echo "ERROR: $EXE_NAME not found under cwd ($(pwd)). Build with: ninja $EXE_NAME" >&2
exit 1
fi
run_exe() {
set +e
$EXE "$@"
local ret=$?
if [ $ret -ne 0 ]; then
echo "$EXE_NAME $*" >>"$CURR_FAILS_FILE"
fi
set -e
}
# Core FP8xBF16 cases aligned with FMHA smoke_test_fwd.sh (lines 80-87): batch/group shapes,
# masks, GQA, short seqlen, k-only pad. Sweeps blockscale (2) vs per-warp (3) and layouts.
run_fp8bf16_smoke() {
local qscale
local perm
for qscale in 2 3; do
for perm in 0 1; do
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
-kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=2 -h_k=1 -d=128 -d_v=128 -s=55 -s_k=256 \
-mask=1
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
-kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=3 -d=128 -s=100 -s_k=51 -mask=0
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
-kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=1 -d=128 -d_v=128 -s=99 -s_k=256 \
-mask=1
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
-kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1024 -s_k=256 \
-mask=2
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
-kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=1 -d=128 -d_v=128 -s=3 -s_k=99 -mask=2
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
-kname=$KNAME $COMMON_ARGS -mode=0 -b=3 -h=2 -h_k=1 -d=128 -s=200 -s_k=520 \
-mask=t:128,30
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
-kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=1 -d=128 -s=99 -s_k=32 -mask=b:4,35
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
-kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=33 -s_k=0 -mask=2
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
-kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 \
-s_kpad=32 -mask=2
done
done
}
# Extra FP8: explicit causal string, xformer window, per-tensor / per-thread quant, V col-major.
run_fp8bf16_extras() {
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
$COMMON_ARGS -mode=0 -b=4 -h=8 -d=128 -s=1024 -s_k=1024 -mask=t:-1,0
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=1 -operm=1 -vlayout=c -kname=$KNAME \
$COMMON_ARGS -mode=0 -b=2 -h=4 -d=128 -s=256 -s_k=256 -mask=t
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
$COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=256 -s_k=256 -mask=xt:64
run_exe -prec=fp8bf16 -init=3 -qscale=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
$COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=128 -s_k=128 -mask=0
run_exe -prec=fp8bf16 -init=3 -qscale=4 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
$COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=64 -s_k=64 -mask=0
}
# Group mode + physical padding (same intent as FMHA run_padding_smoke_tests, Sage-only flags).
run_group_and_padding_smoke() {
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
$COMMON_ARGS -mode=1 -b=3 -h=2 -h_k=1 -d=128 -s=50,60,40 -s_k=128,256,192 -mask=1
# group + PERTHREAD: block_scale_seqstart_* must be allocated (same as bs/pw)
run_exe -prec=fp8bf16 -init=3 -qscale=4 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
$COMMON_ARGS -mode=1 -b=3 -h=2 -h_k=1 -d=128 -s=50,60,40 -s_k=128,256,192 -mask=1
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
$COMMON_ARGS -mode=1 -b=4 -h=8 -h_k=8 -d=128 -s=1024,768,512,256 -s_k=1024,768,512,256 \
-mask=0 -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
$COMMON_ARGS -mode=0 -b=4 -h=8 -d=128 -s=1024 -s_k=1024 -mask=0 \
-q_eff_lens=960,512,384,256 -kv_eff_lens=960,512,384,256
}
# BF16 (no quant): pipeline sanity only; not a shipped Sage mode (see example --help prec).
run_bf16_pipeline_smoke() {
run_exe -prec=bf16 -init=1 -qscale=n -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
$COMMON_ARGS -mode=0 -b=2 -h=2 -d=128 -s=128 -s_k=128 -mask=1
run_exe -prec=bf16 -init=1 -qscale=n -iperm=1 -operm=1 -vlayout=r -kname=$KNAME \
$COMMON_ARGS -mode=0 -b=1 -h=4 -h_k=1 -d=128 -s=256 -s_k=128 -mask=t:32,32
}
# int8 / int4 x fp8xbf16 (hdim divisible by 8 for int4)
run_int_quant_smoke() {
run_exe -prec=i8fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
$COMMON_ARGS -mode=0 -b=2 -h=2 -d=128 -s=128 -s_k=128 -mask=1
run_exe -prec=i4fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
$COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=128 -s_k=128 -mask=t
}
if [ "${VERBOSE:-0}" = 1 ]; then
set -x
fi
run_fp8bf16_smoke
run_fp8bf16_extras
run_group_and_padding_smoke
run_bf16_pipeline_smoke
run_int_quant_smoke
set +x
new_fails_count=0
known_fails_count=0
if [ -f "$KNOWN_FAILS_FILE" ]; then
echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):"
while IFS= read -r line; do
if grep -Fxq "$line" "$KNOWN_FAILS_FILE"; then
echo "Known fail: $line"
known_fails_count=$((known_fails_count + 1))
else
echo "New fail: $line"
new_fails_count=$((new_fails_count + 1))
fi
done <"$CURR_FAILS_FILE"
else
new_fails_count=$(wc -l <"$CURR_FAILS_FILE")
echo "No known fails file, all fails ($new_fails_count) are new:"
if [ "$new_fails_count" -gt 0 ]; then
cat "$CURR_FAILS_FILE"
fi
fi
echo "New fails count: $new_fails_count; Known fails count: $known_fails_count"
exit $((new_fails_count != 0))

View File

@@ -0,0 +1,254 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <algorithm>
#include <cstdint>
#include <functional>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ck_tile/core/container/span.hpp"
enum class mode_enum
{
batch = 0,
group
};
inline std::ostream& operator<<(std::ostream& stream, mode_enum mode)
{
return stream << (mode == mode_enum::batch ? "batch" : "group");
}
template <typename T>
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
using size_type = typename std::vector<T>::size_type;
os << "[";
for(size_type idx = 0; idx < v.size(); ++idx)
{
if(0 < idx)
{
os << ", ";
}
os << v[idx];
}
return os << "]";
}
inline std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
{
std::vector<int32_t> seqstarts = {0};
for(int32_t seqlen : seqlens)
{
seqstarts.push_back(seqstarts.back() + seqlen);
}
assert(seqstarts.size() == seqlens.size() + 1);
return seqstarts;
}
template <typename RandomEngine>
std::vector<int32_t> generate_seqlens(mode_enum mode,
unsigned count,
int32_t seqlen_avg,
int32_t seqlen_min, // if not negative, clamp min
int32_t seqlen_max, // if not negative, clamp max
RandomEngine& random_engine)
{
assert(0 < count);
seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
assert(seqlen_min <= seqlen_max);
std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));
if(mode == mode_enum::group && 1 < count)
{
using size_type = std::vector<int32_t>::size_type;
std::uniform_int_distribution<size_type> idx_dist(0, count - 1);
auto next_idx = std::bind(idx_dist, std::ref(random_engine));
std::uniform_int_distribution<size_type> step_dist(1, count - 1);
auto next_step = std::bind(step_dist, std::ref(random_engine));
for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
{
const size_type to_decrease = next_idx();
// make sure each element of seqlens is in range [seqlen_min, seqlen_max]
if(seqlens[to_decrease] == seqlen_min)
{
continue;
}
const size_type to_increase = (to_decrease + next_step()) % count;
if(seqlens[to_increase] >= seqlen_max)
{
continue;
}
--seqlens[to_decrease];
++seqlens[to_increase];
}
}
return seqlens;
}
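
A brief usage note (editorial): the redistribution loop above only moves length between sequences, so the total is always `count * clamp(seqlen_avg, seqlen_min, seqlen_max)`:

```cpp
std::mt19937 rng(11939); // requires <random>
// Three group-mode sequences averaging 128 tokens, clamped to [1, 256].
auto lens = generate_seqlens(mode_enum::group, 3, 128, 1, 256, rng);
// e.g. {97, 160, 127}; whatever the draw, the sum is 3 * 128 = 384.
```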
// return random integer generated uniformly in range [low, high]
template <typename Int = int, typename RandomEngine>
auto randint(Int low,
Int high,
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>, Int>
{
std::uniform_int_distribution<Int> dist(low, high);
return dist(random_engine);
}
// return random integers generated uniformly in range [low, high]
template <typename Int, typename ForwardIterator, typename RandomEngine>
auto randints(ForwardIterator first,
ForwardIterator last,
Int low,
Int high,
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>>
{
std::uniform_int_distribution<Int> dist(low, high);
std::generate(first, last, [&] { return dist(random_engine); });
}
/*
* generate missing values in *_val randomly when fewer values than batch are provided
* example (assume batch=3)
* q_val=1,2,3 k_val=4,5,6 -> OK
* q_val=1,2,3 -> OK, k same as q
* q_val=1,2 -> OK, the remaining 1 q value is randomized, k same as q
* q_val=1,2 k_val=4,5 -> OK, the remaining 1 q/k value is randomized
* q_val=1,2,3,4 -> OK, but the extra value is ignored
*
* q_val=1,2 k_val=4,5,6 -> not OK, k must have the same number of splits as q
* q_val=1,2 k_val=4 -> not OK, k must have the same number of splits as q
*/
template <typename RandomEngine>
std::tuple<std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>>
generate_missing_seqlens(mode_enum mode,
ck_tile::index_t batch,
const std::vector<ck_tile::index_t>& q_val,
const std::vector<ck_tile::index_t>& k_val,
const std::vector<ck_tile::index_t>& q_pad_val,
const std::vector<ck_tile::index_t>& k_pad_val,
ck_tile::index_t seqlen_k_min,
bool need_append_kvcache,
RandomEngine& random_engine)
{
if(mode == mode_enum::batch)
{
ck_tile::index_t q = q_val[0];
ck_tile::index_t k = k_val[0];
auto s_q = std::vector<ck_tile::index_t>(batch, q);
auto s_k = [&] {
const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
if(1 < batch && need_append_kvcache)
{
// to keep the original s_k value, we always use seqlen_k_max in the first batch
randints(std::next(seqlen_ks.begin()),
seqlen_ks.end(),
seqlen_k_min,
seqlen_k_max,
random_engine);
return seqlen_ks;
}
return seqlen_ks;
}();
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch mode does not support k_padding
auto s_qpad = std::vector<ck_tile::index_t>(batch, -1);
// s_k should be greater than or equal to seqlen_k_min if provided
if(s_k.back() < seqlen_k_min)
{
std::ostringstream msg;
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
}
else
{
std::vector<ck_tile::index_t> s_q;
std::vector<ck_tile::index_t> s_k;
std::vector<ck_tile::index_t> s_kpad;
std::vector<ck_tile::index_t> s_qpad;
ck_tile::index_t idx = 0;
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
{
ck_tile::index_t q = q_val[idx];
ck_tile::index_t k =
k_val[std::min(idx, static_cast<ck_tile::index_t>(k_val.size()) - 1)];
ck_tile::index_t kp =
k_pad_val.empty()
? -1
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
ck_tile::index_t qp =
q_pad_val.empty()
? -1
: q_pad_val[std::min(idx, static_cast<ck_tile::index_t>(q_pad_val.size()) - 1)];
s_q.push_back(q);
s_k.push_back(k < 0 ? q : k);
s_kpad.push_back(kp);
s_qpad.push_back(qp);
// s_k should be greater than or equal to seqlen_k_min
if(s_k.back() < seqlen_k_min)
{
std::ostringstream msg;
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
}
if(idx < batch)
{
auto rem_q =
generate_seqlens(mode, batch - idx, s_q.back(), 1, s_q.back(), random_engine);
auto rem_k = generate_seqlens(
mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), random_engine);
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back());
}
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
}
}
template <typename RandomAccessIterator, typename Int, typename RandomEngine>
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
RandomAccessIterator last,
Int value,
RandomEngine& random_engine)
{
std::iota(first, last, value);
std::shuffle(first, last, random_engine);
}

View File

@@ -31,6 +31,7 @@ add_subdirectory(38_block_scale_gemm)
add_subdirectory(40_streamk_gemm)
add_subdirectory(41_batched_contraction)
add_subdirectory(42_mx_gemm)
add_subdirectory(49_sageattention)
add_subdirectory(50_sparse_attn)
add_subdirectory(51_tile_distr_enc_reg_map)
if(BUILD_CK_TILE_CSHUFFLE_LDS_BENCHMARKS)

View File

@@ -530,4 +530,10 @@ using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
template <index_t swizzle_factor = 2>
using WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>,
2,
swizzle_factor>>;
} // namespace ck_tile

View File

@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
namespace ck_tile {
// This enum is used for codegen pattern matching
enum class BlockSageAttentionQuantScaleEnum
{
NO_SCALE = 0,
PERTENSOR = 1,
BLOCKSCALE = 2,
PERWARP = 3,
PERTHREAD = 4,
};
template <BlockSageAttentionQuantScaleEnum>
struct BlockSageAttentionQuantScaleEnumToStr;
template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::NO_SCALE>
{
static constexpr const char* name = "";
};
template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::PERTENSOR>
{
static constexpr const char* name = "pertensor";
};
template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::BLOCKSCALE>
{
static constexpr const char* name = "blockscale";
};
template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::PERWARP>
{
static constexpr const char* name = "perwarp";
};
template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::PERTHREAD>
{
static constexpr const char* name = "perthread";
};
} // namespace ck_tile
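
Editorial note: the host-side `quant_scale_enum` in the example's quant.hpp mirrors these values one-to-one (hence the "keep in sync" comment there); a compile-time check along these lines would catch drift:

```cpp
static_assert(static_cast<int>(ck_tile::BlockSageAttentionQuantScaleEnum::PERTENSOR) == 1 &&
              static_cast<int>(ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE) == 2 &&
              static_cast<int>(ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP) == 3 &&
              static_cast<int>(ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD) == 4,
              "quant scale enum values drifted from quant_scale_enum in the example");
```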

File diff suppressed because it is too large

View File

@@ -0,0 +1,29 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile {
// This enum is used for codegen pattern matching
enum class BlockSageAttnPipelineEnum
{
QRKSVS = 0,
QRKSVS_ASYNC,
};
template <BlockSageAttnPipelineEnum>
struct BlockSageAttnPipelineEnumToStr;
template <>
struct BlockSageAttnPipelineEnumToStr<BlockSageAttnPipelineEnum::QRKSVS>
{
static constexpr const char* name = "qr";
};
template <>
struct BlockSageAttnPipelineEnumToStr<BlockSageAttnPipelineEnum::QRKSVS_ASYNC>
{
static constexpr const char* name = "qr_async";
};
} // namespace ck_tile

View File

@@ -0,0 +1,60 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
namespace ck_tile {
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
typename SaccDataType_,
typename SMPLComputeDataType_,
typename PDataType_,
typename OaccDataType_,
typename ODataType_,
typename BlockSageAttnShape_,
bool kIsGroupMode_,
typename AttentionVariant_,
typename AttnMask_,
typename Traits_>
struct BlockSageAttnPipelineProblem
{
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
using SaccDataType = remove_cvref_t<SaccDataType_>;
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
using PDataType = remove_cvref_t<PDataType_>;
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using BlockSageAttnShape = remove_cvref_t<BlockSageAttnShape_>;
using AttentionVariant = remove_cvref_t<AttentionVariant_>;
using AttnMask = remove_cvref_t<AttnMask_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kNumGemm0Warps = BlockSageAttnShape::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = BlockSageAttnShape::NumGemm1Warps;
static constexpr index_t kBlockSize = BlockSageAttnShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
static constexpr auto QScaleEnum = Traits::QScaleEnum;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
/// Must match host scale tensor layout (same values as TileSageAttnTraits for Sage kernels).
static constexpr index_t kBlockScaleSizeQ = Traits::kBlockScaleSizeQ;
static constexpr index_t kBlockScaleSizeK = Traits::kBlockScaleSizeK;
};
} // namespace ck_tile

View File

@@ -0,0 +1,861 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// QRKSVS pipeline: Q is kept in registers, while K and V tiles are staged through LDS
template <typename Problem_, typename Policy_ = BlockSageAttentionPipelineQRKSVSDefaultPolicy>
struct BlockSageAttentionPipelineQRKSVS
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using QGemmDataType = SageAttnQKGemmQDataType<Problem>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using KLdsDataType = SageAttnQKGemmKDataType<Problem>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
// fp16/bf16 example configs use P=V=fp16/bf16 (qscale=no). Quantized Sage paths use fp8 P/V;
// FP8 softmax shift, v_descale, and PV-gemm LDS layout assume fp8_t for those cases.
static_assert(std::is_same_v<PDataType, VDataType>,
"SageAttention pipeline requires PDataType == VDataType for the PV gemm");
static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
std::is_same_v<PDataType, fp8_t>,
"SageAttention pipeline requires PDataType == fp8_t for quantized (non-fp16/bf16) inputs");
static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
std::is_same_v<VDataType, fp8_t>,
"SageAttention pipeline requires VDataType == fp8_t for quantized (non-fp16/bf16) inputs");
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using AttnMask = remove_cvref_t<typename Problem::AttnMask>;
using BlockSageAttnShape = remove_cvref_t<typename Problem::BlockSageAttnShape>;
using VLayout = remove_cvref_t<typename BlockSageAttnShape::VLayout>;
static constexpr bool kQLoadOnce = true; // whether the q tile loads the whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockSageAttnShape::kM0;
static constexpr index_t kN0 = BlockSageAttnShape::kN0;
static constexpr index_t kK0 = BlockSageAttnShape::kK0;
static constexpr index_t kN1 = BlockSageAttnShape::kN1;
static constexpr index_t kK1 = BlockSageAttnShape::kK1;
static constexpr index_t kQKHeaddim = BlockSageAttnShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockSageAttnShape::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr uint32_t DS_READ = 0x100; // sched_group_barrier mask: DS (LDS) read instructions
static constexpr uint32_t MFMA = 0x008; // sched_group_barrier mask: MFMA instructions
// FP8 softmax shift constants to keep the exp2 output within the representable FP8 range
// OCP E4M3: max exponent = 8, max finite value 448 (1.75 * 2^8);
// use shift=8.0 so exp2(scale_s*(s - m) + 8) stays within (0, 2^8]
// FNUZ E4M3: max exponent = 7, max finite value 240 (1.875 * 2^7);
// use shift=7.0 so exp2(scale_s*(s - m) + 7) stays within (0, 2^7]
static constexpr float OCP_FP8_SHIFT = 8.0f;
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
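// Editorial worked example: the quantized paths below subtract the shift from the running max,
// so p = exp2(scale_s * (s - m) + shift) <= 2^8 = 256 for OCP (2^7 = 128 for FNUZ), i.e. inside
// the FP8 range while using most of it. The constant 2^shift factor enters both o_acc and the
// row sum l, so it cancels in the final o_acc / l normalization.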
static constexpr auto QScaleEnum = Problem::QScaleEnum;
// last dimension vector length used to create the tensor view (and to decide the buffer_load
// vector length) together with the tensor distribution. the tensor distribution may override this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kQKHeaddim <= 32)
{
return 2;
}
else if constexpr(kQKHeaddim <= 64)
{
return 3;
}
else if constexpr(kQKHeaddim <= 128)
{
return 2;
}
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
else
{
return 1;
}
}
}();
static constexpr const char* name = "qr";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
AttnMask mask,
PositionEncoding /*position_encoding*/,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
[[maybe_unused]] const float* q_descale_ptr = nullptr,
const float* k_descale_ptr = nullptr,
const float* v_descale_ptr = nullptr,
[[maybe_unused]] float q_descale_value = 1.0f) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// K tile in LDS
KLdsDataType* k_lds_ptr = static_cast<KLdsDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window_reg =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
auto q = load_tile(q_dram_window_reg);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType =
std::conditional_t<std::is_same_v<typename SaccBlockTileType::DataType, SaccDataType>,
SaccBlockTileType,
decltype(cast_tile<SaccDataType>(SaccBlockTileType{}))>;
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(start, end);
}();
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<0>{});
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<1>{});
const auto kv_load_start = seqlen_k_start > 0 ? seqlen_k_start : 0;
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if no work to do
if constexpr(AttnMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
{
// Note: o_acc has already been cleared above, so just return it.
// Note: q was loaded but without a fence; it is safe to ignore here.
return o_acc;
}
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{kv_load_start, 0});
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, kv_load_start},
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = [&]() {
if constexpr(std::is_same_v<QDataType, QGemmDataType>)
return tile_elementwise_in(q_element_func, q);
else
{
auto q_tile_tmp = make_static_distributed_tensor<QGemmDataType>(
Policy::template MakeQRegTileDistribution<Problem>());
constexpr index_t kPackedSize = numeric_traits<QDataType>::PackedSize;
constexpr index_t kUnaryOpSize = 8;
static_assert(std::is_same_v<QDataType, ck_tile::pk_int4_t>);
static_assert(kPackedSize == 2);
static_assert(decltype(q_tile_tmp)::get_thread_buffer_size() ==
decltype(q)::get_thread_buffer_size() * kPackedSize);
static_assert(decltype(q_tile_tmp)::get_thread_buffer_size() % kUnaryOpSize == 0);
using RawQType = typename QDataType::type;
using SrcVectorType = ext_vector_t<RawQType, kUnaryOpSize / kPackedSize>;
using DstVectorType = ext_vector_t<QGemmDataType, kUnaryOpSize>;
constexpr index_t kVecSize =
decltype(q_tile_tmp)::get_thread_buffer_size() / kUnaryOpSize;
static_assert(decltype(q)::get_thread_buffer_size() ==
kVecSize * (kUnaryOpSize / kPackedSize));
const element_wise::PassThroughPack8 pass_through_pack8{};
static_for<0, kVecSize, 1>{}([&](auto i) {
pass_through_pack8(
q_tile_tmp.get_thread_buffer().template get_as<DstVectorType>()(i),
q.get_thread_buffer().template get_as<SrcVectorType>()[i]);
});
return q_tile_tmp;
}
}();
// prefetch K tile
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
// Use compile-time conditional for group barrier sequence
// (No runtime lambda selection)
auto schedule_gemm0 = [] {
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
constexpr auto WarpGemmConfig =
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
constexpr index_t NumMfmaInsts = (kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) *
(kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
if constexpr(get_warp_size() == 64 && kQKHeaddim == 256)
{
static_assert(NumMfmaInsts % 8 == 0);
static_for<0, NumMfmaInsts / 8, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(DS_READ, 2, 0); // DS read
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(MFMA, 4, 0); // MFMA
});
}
};
static_assert(2 <= k0_loops);
static_assert(1 <= k1_loops);
constexpr index_t kGemm0MPerWarp = BlockSageAttnShape::Gemm0WarpTile::at(number<0>{});
static_assert(get_warp_size() % kGemm0MPerWarp == 0);
constexpr index_t kWarpSz = get_warp_size();
// sub_warp_idx is 0 or 1, indicating which half of the warp (used for PERTHREAD K-scale
// indexing)
index_t sub_warp_idx = (threadIdx.x % kWarpSz) / kGemm0MPerWarp;
// main loop
do
{
float k_descale = 1.0f;
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
{
const index_t kv_idx =
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
k_descale = k_descale_ptr[kv_idx];
}
constexpr index_t kNumKScalesPW =
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP
? kN0 / Problem::kBlockScaleSizeK
: 1;
constexpr index_t kNumKScalesPT =
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD
? kN0 / Problem::kBlockScaleSizeK / 2
: 1;
float k_scales_perwarp[kNumKScalesPW > 0 ? kNumKScalesPW : 1] = {};
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
{
const index_t kv_idx =
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
#pragma unroll
for(index_t i = 0; i < kNumKScalesPW; i++)
k_scales_perwarp[i] = k_descale_ptr[kv_idx + i];
}
float k_scales_reg[kNumKScalesPT > 0 ? kNumKScalesPT : 1] = {};
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
const index_t k_global_start = seqlen_k_start + i_total_loops * kN0;
const index_t k_scale_start_idx = k_global_start / Problem::kBlockScaleSizeK;
#pragma unroll
for(index_t i = 0; i < kNumKScalesPT; i++)
k_scales_reg[i] = k_descale_ptr[k_scale_start_idx + 2 * i + sub_warp_idx];
}
// STAGE 1, QK gemm
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
auto s_acc_gemm = SaccBlockTileType{};
const auto store_k_block_tile_to_lds = [&](const auto& k_block_tile_) {
if constexpr(std::is_same_v<KDataType, KLdsDataType>)
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile_));
else
{
auto k_block_tile_tmp = make_static_distributed_tensor<KLdsDataType>(
k_dram_window.get_tile_distribution());
using KBlockTileType = remove_cvref_t<decltype(k_block_tile_)>;
constexpr index_t kPackedSize = numeric_traits<KDataType>::PackedSize;
constexpr index_t kUnaryOpSize = 8;
static_assert(std::is_same_v<KDataType, ck_tile::pk_int4_t>);
static_assert(kPackedSize == 2);
static_assert(decltype(k_block_tile_tmp)::get_thread_buffer_size() ==
KBlockTileType::get_thread_buffer_size() * kPackedSize);
static_assert(
decltype(k_block_tile_tmp)::get_thread_buffer_size() % kUnaryOpSize == 0);
using RawKType = typename KDataType::type;
using SrcVectorType = ext_vector_t<RawKType, kUnaryOpSize / kPackedSize>;
using DstVectorType = ext_vector_t<KLdsDataType, kUnaryOpSize>;
constexpr index_t kVecSize =
decltype(k_block_tile_tmp)::get_thread_buffer_size() / kUnaryOpSize;
static_assert(KBlockTileType::get_thread_buffer_size() ==
kVecSize * (kUnaryOpSize / kPackedSize));
const element_wise::PassThroughPack8 pass_through_pack8{};
static_for<0, kVecSize, 1>{}([&](auto i) {
pass_through_pack8(
k_block_tile_tmp.get_thread_buffer().template get_as<DstVectorType>()(
i),
k_block_tile_.get_thread_buffer().template get_as<SrcVectorType>()[i]);
});
store_tile(k_lds_window, k_block_tile_tmp);
}
};
auto k_block_tile = load_tile(k_dram_window);
{
move_tile_window(k_dram_window, {0, kK0});
clear_tile(s_acc_gemm); // initialize C
store_k_block_tile_to_lds(k_block_tile);
k_block_tile = load_tile(k_dram_window);
}
if constexpr(k0_loops > 2)
{
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
gemm_0(s_acc_gemm,
get_slice_tile(q_tile,
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window);
schedule_gemm0();
block_sync_lds();
move_tile_window(k_dram_window, {0, kK0});
store_k_block_tile_to_lds(k_block_tile); // LDS write i + 1
k_block_tile = load_tile(k_dram_window); // global read i + 2
});
}
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
block_sync_lds();
gemm_0(s_acc_gemm,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kM0, (k0_loops - 1) * kK0>{}),
k_lds_window);
schedule_gemm0();
block_sync_lds();
store_k_block_tile_to_lds(k_block_tile);
block_sync_lds();
gemm_0(s_acc_gemm,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
k_lds_window);
schedule_gemm0();
}
// Convert GEMM output to SaccDataType for softmax (if needed)
auto s_acc = [&]() {
using GemmDataType = typename decltype(s_acc_gemm)::DataType;
if constexpr(std::is_same_v<GemmDataType, SaccDataType>)
{
return s_acc_gemm; // No conversion needed (e.g., float -> float)
}
else
{
return cast_tile<SaccDataType>(s_acc_gemm); // Convert (e.g., int32 -> float)
}
}();
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
// PERTHREAD: kBlockScaleSizeK=16
// The s_acc tile distribution is determined by
// WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution, which guarantees
// each thread processes exactly 16 consecutive elements in the K dimension. This
// distribution is inherent to the MFMA 32x32x16 instruction with kKIter=2 and
// TransposedC layout. Therefore, col_offset >> 4 correctly maps thread-local
// elements to K scale indices.
static_assert(Problem::kBlockScaleSizeK == 16,
"PERTHREAD: kBlockScaleSizeK must be 16");
// Validate that the WarpGemm type matches the expected MFMA instruction with SwizzleB +
// TransposedC. This ensures the distribution has 16 consecutive K elements per thread.
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
constexpr auto WarpGemmCfg =
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
using ExpectedWarpGemmI8 =
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
using ExpectedWarpGemmFp8 =
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
static_assert(
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
"PERTHREAD requires "
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
"16 consecutive K elements");
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
float combined_scales_reg[kNumKScalesPT] = {};
#pragma unroll
for(index_t i = 0; i < kNumKScalesPT; i++)
combined_scales_reg[i] = q_descale_value * k_scales_reg[i];
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
index_t col_offset = 0;
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// col_offset counts columns in distributed view
// Divide by 16 (>>4) to map to K scale groups (kBlockScaleSizeK=16)
const index_t scale_idx = col_offset >> 4;
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
col_offset++;
});
});
}
else if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
{
// PERWARP: kBlockScaleSizeK=64, i.e., 64 global K elements share one scale
// Distribution: thread_i and thread_(i+32) interleave to cover K dimension
// In each thread's view, every 32 idx1 steps correspond to 64 global K elements
// Validate that the WarpGemm type matches the expected MFMA instruction with SwizzleB +
// TransposedC. This ensures each thread has 16 consecutive elements and that the warp-level
// grouping is correct.
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
constexpr auto WarpGemmCfg =
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
using ExpectedWarpGemmI8 =
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
using ExpectedWarpGemmFp8 =
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
static_assert(
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
"PERWARP requires "
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
"correct K element grouping");
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
float combined_scales_reg[kNumKScalesPW] = {};
#pragma unroll
for(index_t i = 0; i < kNumKScalesPW; i++)
combined_scales_reg[i] = q_descale_value * k_scales_perwarp[i];
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
index_t col_offset = 0;
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// col_offset counts columns in the distributed view.
// When N0=64 each thread has 32 elements; when N0=128 each thread has 64 elements.
// Divide by 32 (>>5) to map to K scale groups (kBlockScaleSizeK=64).
const index_t scale_idx = col_offset >> 5;
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
col_offset++;
});
});
}
else
{
// dequant: combine q_descale (in s_acc_element_func) with k_descale
auto s_acc_element_func_ = [&]() {
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
{
return s_acc_element_func * k_descale;
}
else
return s_acc_element_func;
}();
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
}
// STAGE 2, scale_s, mask, softmax
if constexpr(kPadSeqLenK || AttnMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask_func(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
};
apply_mask([&](auto&&... args) {
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
if constexpr(AttnMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
// For BLOCKSCALE: precompute (m - shift) once per row
// exp2(s - m + shift) = exp2(s - (m - shift)); pertensor path uses scale_s on s,m
auto validated_m = get_validated_m(m[i_idx]);
auto row_max = scale_s * validated_m;
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP ||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
#if CK_TILE_USE_OCP_FP8
validated_m -= OCP_FP8_SHIFT; // OCP FP8 softmax shift
row_max -= OCP_FP8_SHIFT; // for else branch
#else
validated_m -= FNUZ_FP8_SHIFT;
row_max -= FNUZ_FP8_SHIFT;
#endif
}
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto m_new = get_validated_m(m[i_idx]);
auto row_max = scale_s * m_new;
const auto tmp = exp2(scale_s * m_old[i_idx] - row_max);
// Update l and rescale o_acc
l(i_idx) = tmp * l(i_idx) + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_prefetch);
store_tile(
v_lds_window,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
}
move_tile_window(v_dram_window, {0, kK1});
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
// STAGE 3, KV gemm
// For BLOCKSCALE, PERWARP, and PERTHREAD modes, accumulate directly to o_acc
// Apply per-channel v_descale after the loop (before normalization)
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto v = load_tile(v_dram_window); // load next v
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v);
store_tile(v_lds_window,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v
}
move_tile_window(v_dram_window, {0, kK1});
});
}
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
// tail
{
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
v_lds_window);
block_sync_lds();
}
} while(++i_total_loops < num_total_loop);
// Apply per-channel v_descale for BLOCKSCALE, PERWARP, and PERTHREAD modes (after loop,
// before normalization)
if constexpr(Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP ||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
// Ensure all V LDS reads from the last gemm_1 complete before reusing K/V LDS space
block_sync_lds();
// V is col-major, each column (channel) has its own scale
// o_acc shape: [M0, N1] where N1 is hdim_v
// v_descale_ptr points to per-channel scales [hdim_v]
// Load v_descale to LDS for better memory access pattern
// Reuse K/V LDS space (they're no longer needed)
auto v_descale_lds = reinterpret_cast<float*>(smem_ptr);
// Cooperatively load v_descale to LDS
const index_t num_threads = kBlockSize;
for(index_t i = threadIdx.x; i < kN1; i += num_threads)
{
v_descale_lds[i] = v_descale_ptr[i];
}
block_sync_lds();
constexpr auto o_tmp_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_tmp_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(o_tmp_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// Get the global tile index for the N1 (channel) dimension
const auto tile_idx = get_x_indices_from_distributed_indices(
o_acc.get_tile_distribution(), i_j_idx);
const index_t channel_idx = tile_idx.at(number<1>{});
const float v_scale = v_descale_lds[channel_idx];
o_acc(i_j_idx) *= v_scale;
});
});
}
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
// When masking, the denominator can be zero; guard the normalization
// so we do not divide by zero after a fully masked row.
if constexpr(AttnMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
AttnMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
[[maybe_unused]] const float* q_descale_ptr = nullptr,
const float* k_descale_ptr = nullptr,
const float* v_descale_ptr = nullptr,
[[maybe_unused]] float q_descale_value = 1.0f) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
q_descale_ptr,
k_descale_ptr,
v_descale_ptr,
q_descale_value);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,873 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_, typename Policy_ = BlockSageAttentionPipelineQRKSVSAsyncDefaultPolicy>
struct BlockSageAttentionPipelineQRKSVSAsync
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
// fp16/bf16 example configs use P=V=fp16/bf16 (qscale=no). Quantized Sage paths use fp8 P/V;
// FP8 softmax shift, v_descale, and PV-gemm LDS layout assume fp8_t for those cases.
static_assert(std::is_same_v<PDataType, VDataType>,
"SageAttention pipeline requires PDataType == VDataType for the PV gemm");
    static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
                      std::is_same_v<PDataType, fp8_t>,
                  "SageAttention pipeline requires PDataType == fp8_t on the quantized (non fp16/bf16) paths");
    static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
                      std::is_same_v<VDataType, fp8_t>,
                  "SageAttention pipeline requires VDataType == fp8_t on the quantized (non fp16/bf16) paths");
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using AttnMask = remove_cvref_t<typename Problem::AttnMask>;
using BlockSageAttnShape = remove_cvref_t<typename Problem::BlockSageAttnShape>;
using VLayout = remove_cvref_t<typename BlockSageAttnShape::VLayout>;
    static constexpr bool kQLoadOnce = true; // the q tile loads the whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockSageAttnShape::kM0;
static constexpr index_t kN0 = BlockSageAttnShape::kN0;
static constexpr index_t kK0 = BlockSageAttnShape::kK0;
static constexpr index_t kN1 = BlockSageAttnShape::kN1;
static constexpr index_t kK1 = BlockSageAttnShape::kK1;
static constexpr index_t kQKHeaddim = BlockSageAttnShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockSageAttnShape::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    // TODO: seq_q always supports padding; hdim_q/v support multiples of the vector size (e.g. 8x)
    // only seq_k padding needs special care (OOB elements of p must be set to -INF instead of zero)
static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
Problem::kPadHeadDimV == true);
static constexpr bool kPadSeqLenQ = true;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr auto QScaleEnum = Problem::QScaleEnum;
    // last dimension vector length used to create the tensor view (and to decide the buffer_load vector length)
    // ... together with the tensor distribution; the tensor distribution should be able to override this
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
    // FP8 softmax shift constants to map softmax output into the representable FP8 range
    // OCP E4M3 FP8: max exponent = 8, max value 448 (2^8 * 1.75)
    // Use shift=8.0 so exp2(s - m + 8) maps softmax values into [0, 2^8]
    // FNUZ E4M3 FP8: max exponent = 7, max value 240 (2^7 * 1.875)
    // Use shift=7.0 so exp2(s - m + 7) maps softmax values into [0, 2^7]
static constexpr float OCP_FP8_SHIFT = 8.0f;
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
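    // Default occupancy hint (blocks per CU) when Problem::kBlockPerCu is left at -1;
    // the table roughly tracks per-block register/LDS usage, which grows with the head dim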
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kQKHeaddim <= 32)
{
return 2;
}
else if constexpr(kQKHeaddim <= 64)
{
return 3;
}
else if constexpr(kQKHeaddim <= 128)
{
return 2;
}
else if constexpr(kQKHeaddim <= 192)
{
return 2;
}
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
else
{
return 1;
};
}
}();
static constexpr const char* name = "qr_async";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KElementFunction& /*k_element_func*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
AttnMask mask,
PositionEncoding /*position_encoding*/,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
[[maybe_unused]] const float* q_descale_ptr = nullptr,
const float* k_descale_ptr = nullptr,
const float* v_descale_ptr = nullptr,
[[maybe_unused]] float q_descale_value = 1.0f) const
{
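        // Note: q_descale_ptr is not consumed by this pipeline variant. On the per-tensor /
        // BLOCKSCALE paths the Q de-scale is expected to arrive folded into s_acc_element_func,
        // while the PERWARP/PERTHREAD paths multiply the scalar q_descale_value into the
        // per-thread K scales inside the main loop.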
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
// K tile in LDS
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto k_lds_store = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumKVLdsBuffers>{});
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
auto k_lds_load =
make_tile_window(k_lds_Load_view,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
{0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
q_dram_window.init_raw();
        // TODO: we use async copy (inline asm) to load K;
        // a side effect is that Q also has to go through the raw (inline-asm) load path
auto q = decltype(load_tile(q_dram_window)){};
        // TODO: starting with rocm-6.2 the compiler has problems if q is cleared manually here;
        // q is cleared in the constructor of the static distributed tensor anyway
        // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType =
std::conditional_t<std::is_same_v<typename SaccBlockTileType::DataType, SaccDataType>,
SaccBlockTileType,
decltype(cast_tile<SaccDataType>(SaccBlockTileType{}))>;
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(start, end);
}();
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<0>{});
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<1>{});
const auto kv_load_start = seqlen_k_start > 0 ? seqlen_k_start : 0;
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
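        // num_total_loop = number of kN0-wide K tiles inside [seqlen_k_start, seqlen_k_end),
        // i.e. the K range this Q tile can attend to under the mask / seqlen_k padding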
// check early exit if no work to do
if constexpr(AttnMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
{
                buffer_load_fence(0); // rocm-6.1: if the whole tile is masked out we need fence(0),
                                      // otherwise compute errors appear (possibly a compiler bug)
                // Note: o_acc is already all zeros here, just return it
return o_acc;
}
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{kv_load_start, 0});
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
k_dram_window.init_raw();
constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = bool_constant<false>{};
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, kv_load_start},
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
        (void)q_element_func; // rocm-6.x: applying the q element func here causes scratch spills on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops);
constexpr index_t kGemm0MPerWarp = BlockSageAttnShape::Gemm0WarpTile::at(number<0>{});
static_assert(kGemm0MPerWarp == 32);
constexpr index_t kWarpSz = get_warp_size();
// sub_warp_idx is 0 or 1, indicating which half of the warp (used for PERTHREAD K-scale
// indexing)
index_t sub_warp_idx = (threadIdx.x % kWarpSz) / kGemm0MPerWarp;
// main loop
do
{
float k_descale = 1.0f;
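            // BLOCKSCALE: one K de-scale covers kBlockScaleSizeK consecutive K rows; this mode
            // assumes each kN0 tile lies inside a single scale block, so one scalar per tile suffices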
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
{
const index_t kv_idx =
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
k_descale = k_descale_ptr[kv_idx];
}
constexpr index_t kNumKScalesPW =
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP
? kN0 / Problem::kBlockScaleSizeK
: 1;
constexpr index_t kNumKScalesPT =
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD
? kN0 / Problem::kBlockScaleSizeK / 2
: 1;
float k_scales_perwarp[kNumKScalesPW > 0 ? kNumKScalesPW : 1] = {};
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
{
const index_t kv_idx =
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
#pragma unroll
for(index_t i = 0; i < kNumKScalesPW; i++)
k_scales_perwarp[i] = k_descale_ptr[kv_idx + i];
}
float k_scales_reg[kNumKScalesPT > 0 ? kNumKScalesPT : 1] = {};
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
const index_t k_global_start = seqlen_k_start + i_total_loops * kN0;
const index_t k_scale_start_idx = k_global_start / Problem::kBlockScaleSizeK;
#pragma unroll
for(index_t i = 0; i < kNumKScalesPT; i++)
k_scales_reg[i] = k_descale_ptr[k_scale_start_idx + 2 * i + sub_warp_idx];
}
// STAGE 1, QK gemm
auto s_acc_gemm = SaccBlockTileType{};
clear_tile(s_acc_gemm); // initialize C
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_of_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc_gemm,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
});
}
            // TODO: this works around a bug when k0_loops is smaller than 2:
            // without it, the following fence/barrier gets scheduled inside the first loop iteration
if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0);
async_load_fence();
__builtin_amdgcn_s_barrier();
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(
s_acc_gemm,
get_slice_tile(
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
}
__builtin_amdgcn_sched_barrier(1);
// Convert GEMM output to SaccDataType for softmax (if needed)
auto s_acc = [&]() {
using GemmDataType = typename decltype(s_acc_gemm)::DataType;
if constexpr(std::is_same_v<GemmDataType, SaccDataType>)
{
return s_acc_gemm; // No conversion needed (e.g., float -> float)
}
else
{
return cast_tile<SaccDataType>(s_acc_gemm); // Convert (e.g., int32 -> float)
}
}();
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
// PERTHREAD: kBlockScaleSizeK=16
// The s_acc tile distribution is determined by
// WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution, which guarantees
// each thread processes exactly 16 consecutive elements in the K dimension. This
// distribution is inherent to the MFMA 32x32x16 instruction with kKIter=2 and
// TransposedC layout. Therefore, col_offset >> 4 correctly maps thread-local
// elements to K scale indices.
static_assert(Problem::kBlockScaleSizeK == 16,
"PERTHREAD: kBlockScaleSizeK must be 16");
                // Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB +
                // TransposedC. This ensures the distribution has 16 consecutive K elements per
                // thread.
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
constexpr auto WarpGemmCfg =
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
using ExpectedWarpGemmI8 =
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
using ExpectedWarpGemmFp8 =
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
static_assert(
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
"PERTHREAD requires "
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
"16 consecutive K elements");
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
float combined_scales_reg[kNumKScalesPT] = {};
#pragma unroll
for(index_t i = 0; i < kNumKScalesPT; i++)
combined_scales_reg[i] = q_descale_value * k_scales_reg[i];
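                // e.g. with kN0=128 and kBlockScaleSizeK=16 there are 8 K-scale groups per tile;
                // kNumKScalesPT = 4 and each thread covers every other group, starting at
                // sub_warp_idx (k_descale_ptr[start + 2*i + sub_warp_idx])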
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
index_t col_offset = 0;
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// col_offset counts columns in distributed view
// Divide by 16 (>>4) to map to K scale groups (kBlockScaleSizeK=16)
const index_t scale_idx = col_offset >> 4;
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
col_offset++;
});
});
}
else if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
{
// PERWARP: kBlockScaleSizeK=64, i.e., 64 global K elements share one scale
// Distribution: thread_i and thread_(i+32) interleave to cover K dimension
// In each thread's view, every 32 idx1 steps correspond to 64 global K elements
                // Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB +
                // TransposedC. This ensures each thread has 16 consecutive elements and the
                // warp-level grouping is correct.
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
constexpr auto WarpGemmCfg =
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
using ExpectedWarpGemmI8 =
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
using ExpectedWarpGemmFp8 =
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
static_assert(
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
"PERWARP requires "
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
"correct K element grouping");
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
float combined_scales_reg[kNumKScalesPW] = {};
#pragma unroll
for(index_t i = 0; i < kNumKScalesPW; i++)
combined_scales_reg[i] = q_descale_value * k_scales_perwarp[i];
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
index_t col_offset = 0;
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
                        // col_offset counts columns in the distributed view.
                        // When N0=64 each thread holds 32 elements; when N0=128 it holds 64.
                        // Divide by 32 (>>5) to map to K scale groups (kBlockScaleSizeK=64).
const index_t scale_idx = col_offset >> 5;
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
col_offset++;
});
});
}
else
{
// dequant: combine q_descale (in s_acc_element_func) with k_descale
auto s_acc_element_func_ = [&]() {
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
{
return s_acc_element_func * k_descale;
}
else
return s_acc_element_func;
}();
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
}
// STAGE 2, scale_s, mask, softmax
// logits_soft_cap is always disabled
if constexpr(kPadSeqLenK || AttnMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask_func(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
};
apply_mask([&](auto&&... args) {
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F);
// Ensure gemm_0's LDS reads (K tile) from all threads are completed before V store
// Only needed when K tail and V use the same LDS buffer
if constexpr(LdsSeq.at(number<k0_loops - 1>{}) == LdsSeq.at(number<k0_loops>{}))
{
__builtin_amdgcn_s_barrier();
}
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(
v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
}
if constexpr(k1_loops > 1)
{
move_tile_window(
v_dram_window,
                    {0, kK1}); // moving this right after load_tile(v_dram) causes scratch spills...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
__builtin_amdgcn_sched_barrier(0);
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
if constexpr(AttnMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
                // For the quantized (BLOCKSCALE/PERWARP/PERTHREAD) paths: precompute (scale_s*m - shift)
                // once per row, since exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift));
                // the per-tensor path uses row_max = scale_s*m without the shift
auto validated_m = get_validated_m(m[i_idx]);
auto row_max = scale_s * validated_m;
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP ||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
#if CK_TILE_USE_OCP_FP8
validated_m -= OCP_FP8_SHIFT; // OCP FP8 softmax shift
row_max -= OCP_FP8_SHIFT; // for else branch
#else
validated_m -= FNUZ_FP8_SHIFT;
row_max -= FNUZ_FP8_SHIFT;
#endif
}
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// logits_soft_cap is always disabled
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto m_new = get_validated_m(m[i_idx]);
auto row_max = scale_s * m_new;
const auto tmp = exp2(scale_s * m_old[i_idx] - row_max);
// Update l and rescale o_acc
l(i_idx) = tmp * l(i_idx) + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
const auto p = [&]() {
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
// For fp32 to fp16,
// impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue,
// since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero.
return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
#else
if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pkrtz_fp16_fp32<PDataType>(
tile_elementwise_in(p_compute_element_func, p_compute));
else
return cast_tile<PDataType>(
tile_elementwise_in(p_compute_element_func, p_compute));
#endif
}();
// STAGE 3, KV gemm
// For BLOCKSCALE, PERWARP, and PERTHREAD modes, accumulate directly to o_acc
// Apply per-channel v_descale after the loop (before normalization)
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
}
if constexpr(i_k1 < k1_loops - 1)
move_tile_window(v_dram_window, {0, kK1});
});
}
i_total_loops++;
if(i_total_loops < num_total_loop)
{
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
}
// tail
{
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
}
} while(i_total_loops < num_total_loop);
// Apply per-channel v_descale for BLOCKSCALE, PERWARP, and PERTHREAD modes (after loop,
// before normalization)
if constexpr(Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP ||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
// Ensure all V LDS reads from the last gemm_1 complete before reusing K/V LDS space
block_sync_lds();
// V is col-major, each column (channel) has its own scale
// o_acc shape: [M0, N1] where N1 is hdim_v
// v_descale_ptr points to per-channel scales [hdim_v]
// Load v_descale to LDS for better memory access pattern
// Reuse K/V LDS space (they're no longer needed)
auto v_descale_lds = reinterpret_cast<float*>(smem_ptr);
// Cooperatively load v_descale to LDS
const index_t num_threads = kBlockSize;
for(index_t i = threadIdx.x; i < kN1; i += num_threads)
{
v_descale_lds[i] = v_descale_ptr[i];
}
block_sync_lds();
constexpr auto o_tmp_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_tmp_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(o_tmp_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// Get the global tile index for the N1 (channel) dimension
const auto tile_idx = get_x_indices_from_distributed_indices(
o_acc.get_tile_distribution(), i_j_idx);
const index_t channel_idx = tile_idx.at(number<1>{});
const float v_scale = v_descale_lds[channel_idx];
o_acc(i_j_idx) *= v_scale;
});
});
}
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
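            // l can be zero for a fully masked row; guard the reciprocal so the normalization
            // writes 0 instead of inf/nan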
const auto tmp = [&]() {
if constexpr(AttnMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
AttnMask mask,
PositionEncoding position_encoding,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
const float* q_descale_ptr = nullptr,
const float* k_descale_ptr = nullptr,
const float* v_descale_ptr = nullptr,
[[maybe_unused]] float q_descale_value = 1.0f) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
mask,
position_encoding,
scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
q_descale_ptr,
k_descale_ptr,
v_descale_ptr,
q_descale_value);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,18 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_custom_policy.hpp"
namespace ck_tile {
// This pipeline is qkv all located in LDS
using BlockSageAttentionPipelineQRKSVSAsyncDefaultPolicy =
BlockSageAttnPipelineQRKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ true,
/* NumPrefetchK = */ 3,
/* NumPrefetchV = */ 3>;
} // namespace ck_tile

View File

@@ -0,0 +1,857 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
namespace ck_tile {
template <typename T>
CK_TILE_HOST_DEVICE static constexpr index_t GetPackedSize()
{
return numeric_traits<remove_cvref_t<T>>::PackedSize;
}
template <typename T>
CK_TILE_HOST_DEVICE static constexpr index_t GetLogicalVectorSize(index_t bytes)
{
return (bytes / sizeof(remove_cvref_t<T>)) * GetPackedSize<T>();
}
template <typename Problem>
using SageAttnQKGemmQDataType =
std::conditional_t<is_packed_type_v<remove_cvref_t<typename Problem::QDataType>>,
fp8_t,
remove_cvref_t<typename Problem::QDataType>>;
template <typename Problem>
using SageAttnQKGemmKDataType =
std::conditional_t<is_packed_type_v<remove_cvref_t<typename Problem::KDataType>>,
fp8_t,
remove_cvref_t<typename Problem::KDataType>>;
template <bool QLoadOnce_>
struct BlockSageAttnPipelineQRCustomPolicy;
template <>
struct BlockSageAttnPipelineQRCustomPolicy</* QLoadOnce = */ true>
{
static constexpr bool QLoadOnce = true;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
{
return 0;
}
    // TODO: GetAlignment*() currently does not take padding into account,
    // so the pipeline still needs to check the padding requirement itself
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
constexpr index_t MaxVectorSize = GetLogicalVectorSize<typename Problem::QDataType>(16);
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
return BlockGemm::template MakeABlockTileDistribution<
Problem::BlockSageAttnShape::kM0,
Problem::BlockSageAttnShape::kSubQKHeaddim>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using QKGemmQDataType = SageAttnQKGemmQDataType<Problem>;
using QKGemmKDataType = SageAttnQKGemmKDataType<Problem>;
// int8 MFMA accumulates to int32, but SaccDataType is float for softmax
using GemmAccDataType =
std::conditional_t<(std::is_same_v<QKGemmQDataType, int8_t> ||
std::is_same_v<QKGemmQDataType, signed char>) &&
(std::is_same_v<QKGemmKDataType, int8_t> ||
std::is_same_v<QKGemmKDataType, signed char>),
int32_t,
typename Problem::SaccDataType>;
using GemmProblem =
BlockGemmProblem<QKGemmQDataType,
QKGemmKDataType,
GemmAccDataType,
Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape<sequence<Problem::BlockSageAttnShape::kM0,
Problem::BlockSageAttnShape::kN0,
Problem::BlockSageAttnShape::kK0>,
typename Problem::BlockSageAttnShape::Gemm0BlockWarps,
typename Problem::BlockSageAttnShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
if constexpr(get_warp_size() == 64 && std::is_same_v<QKGemmQDataType, fp8_t> &&
std::is_same_v<QKGemmKDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32);
static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}) == 32);
static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}) == 32);
// TODO: hard coded here. Otherwise, it produces incorrect results
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
swizzle_factor>{};
}
else if constexpr(get_warp_size() == 64 &&
(std::is_same_v<QKGemmQDataType, int8_t> ||
std::is_same_v<QKGemmQDataType, signed char>) &&
(std::is_same_v<QKGemmKDataType, int8_t> ||
std::is_same_v<QKGemmKDataType, signed char>))
{
static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32);
static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}) == 32);
static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}) == 32);
// Use special int8 MFMA with K iteration (similar to FP8)
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<
swizzle_factor>{};
}
else
{
constexpr bool SwizzleA =
Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32;
return WarpGemmDispatcher<
QKGemmQDataType,
QKGemmKDataType,
GemmAccDataType,
Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}),
true, // TransposeC
SwizzleA>{};
}
}();
using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
QKGemmQDataType,
QKGemmKDataType,
GemmAccDataType,
typename Problem::BlockSageAttnShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
if constexpr(1 < Problem::kNumGemm0Warps)
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
else
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
};
// This pipeline is qkv all located in LDS
template <bool QLoadOnce_, bool AsyncCopy_, index_t NumPrefetchK_, index_t NumPrefetchV_>
struct BlockSageAttnPipelineQRKSVSCustomPolicy : BlockSageAttnPipelineQRCustomPolicy<QLoadOnce_>
{
static constexpr bool AsyncCopy = AsyncCopy_;
static constexpr index_t NumPrefetchK = NumPrefetchK_;
static constexpr index_t NumPrefetchV = NumPrefetchV_;
static constexpr index_t NumKVLdsBuffers = max(NumPrefetchK, NumPrefetchV);
using QXPolicy = BlockSageAttnPipelineQRCustomPolicy<QLoadOnce_>;
template <index_t k_prefetches_, index_t v_prefetches_, index_t k_loops_, index_t v_loops_>
struct LdsBufferSequence
{
static constexpr index_t num_lds_buffers_ = max(k_prefetches_, v_prefetches_);
static constexpr index_t ceil_ = ((v_loops_ - 1) / num_lds_buffers_) * num_lds_buffers_;
// for qr_ks_vs_async, the Lds buffer assigned to last gemm_1 iteration of V should not
// overlap with the Lds buffers used by first two gemm_0 iterations of K
static constexpr auto Make()
{
            // ensure iteration v_loops_ - 1 is assigned to buffer num_lds_buffers_ - 1
return transform_sequences(
[&](auto i) {
if(i < k_loops_)
return i % num_lds_buffers_;
else
return ((num_lds_buffers_ - 1) + (i - k_loops_ + ceil_ - (v_loops_ - 1))) %
num_lds_buffers_;
},
typename arithmetic_sequence_gen<0, k_loops_ + v_loops_, 1>::type{});
};
using type = remove_cvref_t<decltype(Make())>;
};
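    // Hand-tuned overrides for the NumPrefetchK=NumPrefetchV=3 shapes below: e.g. <3,3,4,4>
    // assigns K loops to buffers {1,2,0,1} and V loops to {0,1,2,0}, so the buffer holding the
    // last V slice (0) differs from the buffers (1,2) used by the first two K loads of the
    // next main-loop iteration.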
// clang-format off
template<> struct
LdsBufferSequence<3, 3, 4, 4> { using type = sequence<1, 2, 0, 1, 0, 1, 2, 0>; };
template<> struct
LdsBufferSequence<3, 3, 4, 2> { using type = sequence<1, 2, 0, 1, 2, 0>; };
template<> struct
LdsBufferSequence<3, 3, 2, 4> { using type = sequence<1, 2, 0, 1, 2, 0>; };
template<> struct
LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; };
template<> struct
LdsBufferSequence<3, 3, 3, 4> { using type = sequence<1, 2, 0, 0, 1, 2, 0>; };
template<> struct
LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>;};
// clang-format on
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetLdsBufferSequence()
{
using BlockSageAttnShape = remove_cvref_t<typename Problem::BlockSageAttnShape>;
constexpr index_t kN0 = BlockSageAttnShape::kN0;
constexpr index_t kK0 = BlockSageAttnShape::kK0;
constexpr index_t kK1 = BlockSageAttnShape::kK1;
constexpr index_t kQKHeaddim = BlockSageAttnShape::kQKHeaddim;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
return typename LdsBufferSequence<NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>::type{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
{
// TODO: this is for 3d layout
using KDataType = SageAttnQKGemmKDataType<Problem>;
return GetLogicalVectorSize<KDataType>(16);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
if constexpr(AsyncCopy)
{
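            // async (direct-to-LDS) copies move a single dword per lane on most gfx9 parts;
            // gfx950 can issue dwordx4 per lane, hence the wider per-lane load size there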
#if defined(__gfx950__)
constexpr index_t MaxLoadSizeInBytes = 4 * 4; // dwordx4
#else
constexpr index_t MaxLoadSizeInBytes = 4; // dword
#endif
return GetLogicalVectorSize<KDataType>(MaxLoadSizeInBytes);
}
else
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
constexpr index_t MaxVectorSize = GetLogicalVectorSize<KDataType>(16);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
return min(MaxVectorSize, ElemPerThread);
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{
// TODO: this is for 3d layout
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kMaxVecLoad =
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
return kMaxVecLoad;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
using VLayout = remove_cvref_t<typename Problem::BlockSageAttnShape::VLayout>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kMaxVecLoad =
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (total_pixels / kMinVecLoad);
return kVecLoad;
}
else
{
return kMaxVecLoad;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
{
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::ODataType);
return min(MaxVectorSize, WG::WarpGemmAttribute::Impl::kCM1PerLane);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
{
// this function assume K/V can share smem
constexpr index_t SingleKSize = [&]() {
if constexpr(!AsyncCopy)
{
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size();
}
else
{
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack;
static_assert(WarpSize * KVector >= kKPerBlock &&
WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector;
constexpr index_t LaneGroups = WarpSize / LanesPerK;
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
return NumIssues * NumWarps * (WarpSize * KVector + kPad);
}
}();
constexpr index_t SingleVSize = [&]() {
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = get_n_lds_banks();
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
}();
return max(SingleKSize, SingleVSize);
}
    // TODO: this is only used by the non-async-copy descriptor; unify in the future
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{})
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK =
kKPerBlock / KVector; // how many lane (within a wave) to load K
constexpr index_t LaneGroups =
WarpSize /
LanesPerK; // how many groups (within a wave), they may load different N, but same K
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
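        // Per LDS buffer the store layout is NumIssues x NumWarps strips of
        // (WarpSize * KVector + kPad) elements, so each warp's async stores for a given issue
        // land in their own padded strip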
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
        // TODO: this layout is hard-coded and is consumed by the async-copy buffer-view load;
        // in LDS the actual layout is (bufs, N0, N2, N1*K0*K1)
constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return k_lds_block_desc_issues_warps_lanes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor()
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
// constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad);
// constexpr index_t SingleVSize =
// MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
constexpr index_t BufferSize =
GetSingleSmemElementSpaceSize<Problem>(); // max(SingleKSize, SingleVSize);
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumKVLdsBuffers>{}, // num_buffers
number<NumIssues>{}, // n0
number<NumWarps>{}, // n2
number<LaneGroups>{}, // n1
number<kKPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<BufferSize>{},
number<NumWarps*(WarpSize * KVector + kPad)>{},
number<WarpSize * KVector + kPad>{},
number<kKPerBlock>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<NumKVLdsBuffers>{},
number<NumIssues>{},
number<LaneGroups>{},
number<NumWarps>{})),
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0, 1, 3, 2>{}, sequence<4, 5>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
// 3d + padding
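    // (the kKPack pad appended to each PixelsPerRow row is the usual trick to stagger rows
    // across LDS banks and reduce bank conflicts on the PV-gemm read pattern)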
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = get_n_lds_banks();
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumKVLdsBuffers>{},
number<kKPerBlock / kKPack>{},
number<kNPerBlock / NPerRow>{},
number<NPerRow>{},
number<kKPack>{}),
make_tuple(number<GetSingleSmemElementSpaceSize<Problem>()>{},
number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
number<PixelsPerRow + kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(
number<NumKVLdsBuffers>{}, number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
{
// TODO: assume Q is in register
// TODO: assume K and V share smem buffers
using KLdsDataType = SageAttnQKGemmKDataType<Problem>;
constexpr index_t single_smem_size =
GetSingleSmemElementSpaceSize<Problem>() * sizeof(KLdsDataType);
return QXPolicy::template GetSmemSizeQ<Problem>() + single_smem_size * NumKVLdsBuffers;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return GetSmemSizeKV<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{
if constexpr(!AsyncCopy)
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
constexpr index_t MaxVectorSize = GetLogicalVectorSize<KDataType>(16);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
constexpr index_t K1 = min(MaxVectorSize, ElemPerThread);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr index_t N0 = NumIssues;
constexpr index_t N1 = LaneGroups;
constexpr index_t N2 = NumWarps;
constexpr index_t K0 = LanesPerK;
constexpr index_t K1 = KVector;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
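// V loads depend on VLayout: row-major (seqlen * hdim) V is read hdim-vectorized and later
// rearranged in registers (see MakeShuffledVRegBlockDescriptor), while column-major V is read
// with K-contiguous vectors. Each branch carries a fallback encoding for tile shapes where the
// preferred split does not divide evenly.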
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
{
using VLayout = remove_cvref_t<typename Problem::BlockSageAttnShape::VLayout>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t K3 = total_pixels / N1;
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could fall outside a single wave
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
{
static_assert(kNPerBlock % 16 == 0);
constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16;
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t N2 = 2;
constexpr index_t N1_m = kNPack / N2;
constexpr index_t N0_m = kNPerBlock / kNPack;
constexpr index_t K1 = get_warp_size() / N1_m;
constexpr index_t K2_m = kKPerBlock / K1 / K0;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
tuple<sequence<0>, sequence<1, 1>>,
sequence<1, 2, 1>, // N0 K2 N2
sequence<0, 2, 2>>{});
}
else if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = GetAlignmentV<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = kNPerBlock / (N2 * N1);
static_assert(N0 != 0);
constexpr auto dstr = make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K0
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>, // N0 K1
sequence<0, 1>>{});
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kNPerBlock * kKPerBlock)
{
return dstr;
}
else
{
static_assert(kKPerBlock % 16 == 0);
constexpr index_t kKPerIter = kKPerBlock % 32 == 0 ? 32 : 16;
constexpr index_t K0_m = kKPerBlock / kKPerIter;
constexpr index_t K2 = 2;
constexpr index_t K1_m = kKPerIter / K2;
constexpr index_t N2_m = get_warp_size() / K1_m;
constexpr index_t N0_m = kNPerBlock / (N2_m * N1);
constexpr auto dstr_m = make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<N0_m, N1, N2_m>, sequence<K0_m, K1_m, K2>>,
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K1
tuple<sequence<1>, sequence<2, 1>>,
sequence<2, 1, 2>, // K0 N0 K2
sequence<0, 0, 2>>{});
static_assert(container_reduce(dstr_m.get_lengths(),
std::multiplies<index_t>{},
1) == kNPerBlock * kKPerBlock);
return dstr_m;
}
}
}
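// Companion to the row-major branch of MakeVDramTileDistribution: same thread-to-data split, but
// with the in-register (Y) dimensions reordered so the loaded V fragment can be shuffled into the
// kKPack-contiguous form the V LDS descriptor expects.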
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegBlockDescriptor()
{
// This descriptor is only used when the V layout is seqlen * hdim (row-major)
using VLayout = remove_cvref_t<typename Problem::BlockSageAttnShape::VLayout>;
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could fall outside a single wave
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
{
static_assert(kNPerBlock % 16 == 0);
constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16;
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t N2 = 2;
constexpr index_t N1_m = kNPack / N2;
constexpr index_t N0_m = kNPerBlock / kNPack;
constexpr index_t K1 = get_warp_size() / N1_m;
constexpr index_t K2_m = kKPerBlock / K1 / K0;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
tuple<sequence<0>, sequence<1, 1>>,
sequence<1, 1, 2>, // N0 K2 <-> N2
sequence<0, 2, 2>>{});
}
else if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
}
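// Second GEMM (P x V): P stays in registers and V is read from LDS (BlockGemmARegBSmemCRegV2).
// A swizzled-B fp8 x fp8 -> f32 32x32x32 warp GEMM is selected directly when the types and warp
// tile match; otherwise WarpGemmDispatcher picks a warp GEMM from the Gemm1WarpTile sizes.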
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{
using GemmProblem =
BlockGemmProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kNumGemm1Warps * get_warp_size(),
TileGemmShape<sequence<Problem::BlockSageAttnShape::kM0,
Problem::BlockSageAttnShape::kN1,
Problem::BlockSageAttnShape::kK1>,
typename Problem::BlockSageAttnShape::Gemm1BlockWarps,
typename Problem::BlockSageAttnShape::Gemm1WarpTile>>;
auto warp_gemm = [&]() {
if constexpr(get_warp_size() == 64 &&
std::is_same_v<typename Problem::PDataType, fp8_t> &&
std::is_same_v<typename Problem::VDataType, fp8_t> &&
std::is_same_v<typename Problem::OaccDataType, float>)
{
static_assert(Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<0>{}) == 32);
static_assert(Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<1>{}) == 32);
static_assert(Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<2>{}) == 32);
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{};
}
else
{
return WarpGemmDispatcher<
typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<2>{}),
true>{};
}
}();
using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;
using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
typename Problem::BlockSageAttnShape::Gemm1BlockWarps,
WarpGemm>;
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile


@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_custom_policy.hpp"
namespace ck_tile {
using BlockSageAttentionPipelineQRKSVSDefaultPolicy =
BlockSageAttnPipelineQRKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>;
} // namespace ck_tile


@@ -0,0 +1,71 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
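// Rounds a head dimension up to a tile length the kernels can handle, e.g. 80 -> 96, 96 -> 128,
// 160 -> 256; 48 and 192 map to themselves, as does any power-of-two head dimension.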
template <index_t Headdim>
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length()
{
if constexpr(Headdim == 48)
return 48;
else if constexpr(Headdim == 80)
return 96;
else if constexpr(Headdim == 96)
return 128;
else if constexpr(Headdim == 160)
return 256;
else if constexpr(Headdim == 192)
return 192;
else if constexpr(is_power_of_two_integer(Headdim))
return Headdim;
else
static_assert(Headdim == 0,
"only Headdim of 48, 80, 96, 160, 192 and power-of-two values are supported");
}
template <typename BlockTile_, // sequence<...
typename Gemm0BlockWarps_,
typename Gemm0WarpTile_,
typename Gemm1BlockWarps_,
typename Gemm1WarpTile_,
bool IsVLayoutRowMajor_>
struct TileSageAttnShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
using Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>;
using Gemm0WarpTile = remove_cvref_t<Gemm0WarpTile_>;
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
static constexpr index_t NumGemm0Warps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
static constexpr index_t NumGemm1Warps =
reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{});
static_assert(NumGemm1Warps % NumGemm0Warps == 0);
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
static constexpr index_t kQKHeaddim =
BlockTile::at(number<5>{}); // total length along K0, used by pipelines that need to load Q all
// at once (or repeatedly load Q as a whole tile)
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length<kQKHeaddim>();
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
using VLayout = std::conditional_t<IsVLayoutRowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
};
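// Illustrative instantiation only, with assumed (not generated) values for a 128 / 128 head-dim
// shape; the tile definitions actually used by the kernels are produced elsewhere in this PR:
//   using ExampleShape = TileSageAttnShape<sequence<128, 128, 32, 128, 32, 128>, // M0,N0,K0,N1,K1,QKHeaddim
//                                          sequence<4, 1, 1>, sequence<32, 32, 32>,  // gemm0 warps / warp tile
//                                          sequence<4, 1, 1>, sequence<32, 32, 32>,  // gemm1 warps / warp tile
//                                          /* IsVLayoutRowMajor = */ true>;
//   static_assert(ExampleShape::NumWarps == 4 && ExampleShape::kSubQKHeaddim == 128);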
} // namespace ck_tile


@@ -0,0 +1,42 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
namespace ck_tile {
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* padding for hdim_q */,
bool kPadHeadDimV_ /* padding for hdim_v */,
BlockSageAttentionQuantScaleEnum QScaleEnum_,
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
struct TileSageAttnTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr auto QScaleEnum = QScaleEnum_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
/// Tokens covered by one Q/K descale factor along seqlen. Q: PERTHREAD = 4, PERWARP = 32,
/// otherwise 128 (per-block / per-tensor / no scale). K: PERTHREAD = 16, PERWARP = 64, otherwise 128.
static constexpr index_t kBlockScaleSizeQ =
QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERTHREAD ? 4
: QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERWARP ? 32
: 128;
static constexpr index_t kBlockScaleSizeK =
QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERTHREAD ? 16
: QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERWARP ? 64
: 128;
};
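// Illustrative usage only (assumed values): per-warp Q/K scales with all padding flags enabled.
//   using ExampleTraits = TileSageAttnTraits<true, true, true, true,
//                                            BlockSageAttentionQuantScaleEnum::PERWARP>;
//   static_assert(ExampleTraits::kBlockScaleSizeQ == 32 && ExampleTraits::kBlockScaleSizeK == 64);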
} // namespace ck_tile


@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/sageattention/kernel/sageattn_fwd_kernel.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_enum.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_problem.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/sageattention/pipeline/tile_sageattn_shape.hpp"
#include "ck_tile/ops/sageattention/pipeline/tile_sageattn_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"