Files
composable_kernel/example/ck_tile/49_sageattention/CMakeLists.txt
ltqin de0a61e5c2 [rocm-libraries] ROCm/rocm-libraries#6574 (commit b3db057)
[CK_TILE] Add SageAttention v2 forward kernel with
 multi-granularity quantization (#6574)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

Add a CK_TILE forward kernel implementing [SageAttention
v2](https://arxiv.org/abs/2411.10958) — an attention algorithm that
applies multi-granularity quantization to Q/K/V before computing
attention, trading minimal accuracy loss for higher throughput on
low-precision hardware.

### Quantization design

| Tensor | Supported data types | Scale granularity options |
|--------|---------------------|--------------------------|
| Q | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (32 tokens), per-thread (4 tokens) |
| K | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (64 tokens), per-thread (16 tokens) |
| V | fp8 | per-channel (always) |
| O | bf16 | — |

Three precision combinations are supported: `fp8/bf16` (QKV fp8, O
bf16), `i8/fp8/bf16` (QK int8, V fp8, O bf16), and `i4/fp8/bf16` (QK
int4, V fp8, O bf16).

### Architecture support

- **gfx9** (CDNA2/3, e.g. gfx90a, gfx942) — full tile set
- **gfx950** (CDNA4) — restricted tile set (N-per-block capped at 64 for
fp8-family dtypes)

### Implementation

- Two pipeline variants: `QRKSVS` (synchronous) and `QRKSVS_ASYNC`
(async copy)
- Masking support: no mask, causal (top-left / bottom-right), and
generic windowed
- Batch and group (variable-length) modes
- Head dimension: d=128, d_v=128
- Python codegen under `example/ck_tile/49_sageattention/codegen/`
generates kernel instances per target/dtype/tile combination
- Smoke tests included via `tile_example_sageattn_fwd`

### Test commands

```bash
# fp8 QKV
./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 \
  -kname=1 -prec=fp8bf16 -qscale=3 -init=3

# int8 QK, fp8 V
./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 \
  -kname=1 -prec=i8fp8bf16 -qscale=3 -init=3
```

`-qscale` values: 1=per-tensor, 2=per-block, 3=per-warp, 4=per-thread
2026-04-30 18:33:36 +00:00

96 lines
3.8 KiB
CMake

# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Start from every GPU target the build supports and keep only the entries
# whose name matches "gfx9"; this example currently supports no other arch.
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")

# If the filter removed everything there is nothing to build: warn and stop
# processing this file.
list(LENGTH INST_TARGETS SAGEATTN_INST_TARGET_COUNT)
if(SAGEATTN_INST_TARGET_COUNT EQUAL 0)
  message(WARNING "Skipping SageAttention compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
  return()
endif()
# ====================================================================
# SageAttention codegen - only FWD API, minimal instances
# ====================================================================
# Collect the generator entry point and every Python module under codegen/.
# CONFIGURE_DEPENDS makes the build re-check the glob so added or removed
# scripts trigger a reconfigure.
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
  ${CMAKE_CURRENT_LIST_DIR}/generate.py
  ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
)

# Re-run CMake whenever one of the generator scripts changes. APPEND is used so
# that configure-time dependencies already registered on this directory are
# preserved rather than replaced.
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${CODE_GEN_SCRIPTS})

# The target list is handed to generate.py as one comma-separated argument.
list(JOIN INST_TARGETS "," SAGEATTN_TARGETS_ARG)

# Only generate FWD API, only supported head dimension (128)
# Note: Only d=128, d_v=128 has kernel tile definitions in sageattn_fwd.py
set(SAGEATTN_FWD_CODE_GEN_COMMON_ARGS
  ${CMAKE_CURRENT_LIST_DIR}/generate.py
  --targets ${SAGEATTN_TARGETS_ARG}
  --api fwd
  --optdim 128
)
# Generate list of kernels to build.
# This runs the generator at configure time with --list_blobs, which is given a
# text file path in the build directory; that file is read back below.
execute_process(
  COMMAND ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS}
          --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt
  RESULT_VARIABLE ret
)

# RESULT_VARIABLE receives the exit code of the process, or a string describing
# the error if it could not be run. Surface that value so a failure can be
# diagnosed from the configure log.
if(ret AND NOT ret EQUAL 0)
  message(FATAL_ERROR "SageAttention FAILED to generate kernel list via Python. (result: ${ret})")
endif()

# Each line of the generated file becomes one entry of SAGEATTN_FWD_GEN_BLOBS.
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt SAGEATTN_FWD_GEN_BLOBS)
# Build rule that produces the kernel instance files listed in
# SAGEATTN_FWD_GEN_BLOBS. It re-runs when any generator script changes and
# passes the current binary directory to generate.py as --output_dir.
add_custom_command(
  OUTPUT
    ${SAGEATTN_FWD_GEN_BLOBS}
  DEPENDS
    ${CODE_GEN_SCRIPTS}
  COMMAND
    ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS}
    --output_dir ${CMAKE_CURRENT_BINARY_DIR}
  COMMENT "Generate SageAttention FWD kernels"
  VERBATIM
)
# Object library made of the generated kernel instance sources. It is excluded
# from the default "all" target, so it is only built on demand.
add_library(tile_sageattn_fwd_instances OBJECT EXCLUDE_FROM_ALL
  ${SAGEATTN_FWD_GEN_BLOBS}
)
target_include_directories(tile_sageattn_fwd_instances
  PRIVATE
    ${CMAKE_CURRENT_LIST_DIR}
)

# Flags applied to every kernel instance translation unit.
set(SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS
  -Wno-undefined-func-template
  -Wno-float-equal
  -fgpu-flush-denormals-to-zero
)
# Add the OCP FP8 define only when the build enables CK_USE_OCP_FP8.
if(CK_USE_OCP_FP8)
  list(APPEND SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
target_compile_options(tile_sageattn_fwd_instances
  PRIVATE
    ${SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS}
)

# Compile for the filtered GPU targets only, as position-independent code.
set_target_properties(tile_sageattn_fwd_instances PROPERTIES
  HIP_ARCHITECTURES "${INST_TARGETS}"
  POSITION_INDEPENDENT_CODE ON
)
# ====================================================================
# SageAttention FWD Example
# ====================================================================
set(EXAMPLE_SAGEATTN_FWD "tile_example_sageattn_fwd")
message(DEBUG "adding example ${EXAMPLE_SAGEATTN_FWD}")

# Example driver executable; excluded from the default "all" target.
add_executable(${EXAMPLE_SAGEATTN_FWD} EXCLUDE_FROM_ALL example_sageattn_fwd.cpp)
target_include_directories(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

# Link with our own minimal instances library (INDEPENDENT from FMHA!).
# PRIVATE: an executable has no consumers, and the explicit keyword avoids the
# legacy keyword-less signature of target_link_libraries.
target_link_libraries(${EXAMPLE_SAGEATTN_FWD} PRIVATE tile_sageattn_fwd_instances)

# Flags for the example driver itself.
set(SAGEATTN_FWD_COMPILE_OPTIONS
  -Wno-undefined-func-template
  -Wno-float-equal
  -fgpu-flush-denormals-to-zero
)
# Add the OCP FP8 define only when the build enables CK_USE_OCP_FP8.
if(CK_USE_OCP_FP8)
  list(APPEND SAGEATTN_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
target_compile_options(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${SAGEATTN_FWD_COMPILE_OPTIONS})

# Build the example for the filtered GPU targets only.
set_property(TARGET ${EXAMPLE_SAGEATTN_FWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})