mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[rocm-libraries] ROCm/rocm-libraries#6574 (commit b3db057)
[CK_TILE] Add SageAttention v2 forward kernel with multi-granularity quantization (#6574) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Add a CK_TILE forward kernel implementing [SageAttention v2](https://arxiv.org/abs/2411.10958) — an attention algorithm that applies multi-granularity quantization to Q/K/V before computing attention, trading minimal accuracy loss for higher throughput on low-precision hardware. ### Quantization design | Tensor | Supported data types | Scale granularity options | |--------|---------------------|--------------------------| | Q | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (32 tokens), per-thread (4 tokens) | | K | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (64 tokens), per-thread (16 tokens) | | V | fp8 | per-channel (always) | | O | bf16 | — | Three precision combinations are supported: `fp8/bf16` (QKV fp8, O bf16), `i8/fp8/bf16` (QK int8, V fp8, O bf16), and `i4/fp8/bf16` (QK int4, V fp8, O bf16). ### Architecture support - **gfx9** (CDNA2/3, e.g. 
gfx90a, gfx942) — full tile set - **gfx950** (CDNA4) — restricted tile set (N-per-block capped at 64 for fp8-family dtypes) ### Implementation - Two pipeline variants: `QRKSVS` (synchronous) and `QRKSVS_ASYNC` (async copy) - Masking support: no mask, causal (top-left / bottom-right), and generic windowed - Batch and group (variable-length) modes - Head dimension: d=128, d_v=128 - Python codegen under `example/ck_tile/49_sageattention/codegen/` generates kernel instances per target/dtype/tile combination - Smoke tests included via `tile_example_sageattn_fwd` ### Test commands \`\`\`bash # fp8 QKV ./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=fp8bf16 -qscale=3 -init=3 # int8 QK, fp8 V ./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=i8fp8bf16 -qscale=3 -init=3 \`\`\` \`-qscale\` values: 1=per-tensor, 2=per-block, 3=per-warp, 4=per-thread
This commit is contained in:
committed by
assistant-librarian[bot]
parent
e8d64ad5c6
commit
de0a61e5c2
70
example/ck_tile/49_sageattention/codegen/utils.py
Normal file
70
example/ck_tile/49_sageattention/codegen/utils.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
import dataclasses
|
||||
import os.path as path
|
||||
import textwrap
|
||||
|
||||
|
||||
def update_file(file_path, content):
    """Write *content* to *file_path* only when it differs from what is on disk.

    Skipping identical writes leaves the file's mtime untouched, which avoids
    triggering unnecessary rebuilds of the generated kernel instances.
    """
    previous = ""
    if path.exists(file_path):
        with open(file_path, "r") as fp:
            previous = fp.read()
    if previous != content:
        with open(file_path, "w") as fp:
            fp.write(content)
|
||||
|
||||
|
||||
def indent(code: str, indent: str = "    ") -> str:
    """Return *code* with every line that contains non-whitespace prefixed by *indent*."""
    return textwrap.indent(code, prefix=indent)
|
||||
|
||||
|
||||
def if_(i: int) -> str:
    """Return the keyword opening the *i*-th branch of a generated if/else-if chain."""
    if i == 0:
        return "if"
    return "else if"
|
||||
|
||||
|
||||
def check_duplicates_and_paddings(traits, trait):
    """Validate a candidate *trait* against the already-accepted *traits*.

    Checks:
    * the traits list does not already contain a trait with the same parameters;
    * paddings are consistent: the previous kernel can be incorrectly called
      before the new one, for example, f, _t_, f, t cannot be before f, _f_, f, t.

    Args:
        traits: previously accepted trait instances (dataclass instances of the
            same type as *trait*).
        trait: candidate trait instance to validate.

    Raises:
        Exception: on an exact duplicate, or when a previously registered
            kernel would shadow the candidate because of padding ordering.
        TypeError: when a padding field is neither ``str`` nor ``int``.
    """
    fields = [f.name for f in dataclasses.fields(trait)]
    pad_fields = [f for f in fields if "pad" in f]
    non_pad_fields = [f for f in fields if "pad" not in f]
    for prev_trait in traits:
        # Only entries matching on every non-padding field can conflict.
        if any(getattr(trait, f) != getattr(prev_trait, f) for f in non_pad_fields):
            continue
        if all(getattr(trait, f) == getattr(prev_trait, f) for f in pad_fields):
            raise Exception(f"Duplicate found {trait}")
        # Check if the previous kernel can be incorrectly used before the current one
        # for example, f, _t_, f, t cannot be before f, _f_, f, t
        is_prev_more_restrictive = False
        is_curr_more_restrictive = False
        for f in pad_fields:
            prev_pad = getattr(prev_trait, f)
            pad = getattr(trait, f)
            # Normalize both pads so that "no padding" ("f" / 0) compares as
            # the largest (most restrictive) value.
            if isinstance(prev_pad, str):
                prev_pad = 1000000 if prev_pad == "f" else 1
                pad = 1000000 if pad == "f" else 1
            elif isinstance(prev_pad, int):
                prev_pad = 1000000 if prev_pad == 0 else prev_pad
                pad = 1000000 if pad == 0 else pad
            else:
                # `assert False` here would be stripped under `python -O`;
                # raise explicitly so bad trait definitions always fail.
                raise TypeError(
                    f"Unsupported padding field type for {f!r}: {type(prev_pad)!r}"
                )
            if prev_pad < pad:
                is_prev_more_restrictive = True
            elif prev_pad > pad:
                is_curr_more_restrictive = True
        if is_prev_more_restrictive and not is_curr_more_restrictive:
            raise Exception(
                f"Kernel will never be used because paddings are not ordered correctly:\n{prev_trait} supersedes\n{trait}"
            )
|
||||
Reference in New Issue
Block a user