mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK_TILE] Add SageAttention v2 forward kernel with multi-granularity quantization (#6574)

## Summary

Add a CK_TILE forward kernel implementing [SageAttention v2](https://arxiv.org/abs/2411.10958) — an attention algorithm that applies multi-granularity quantization to Q/K/V before computing attention, trading minimal accuracy loss for higher throughput on low-precision hardware.

### Quantization design

| Tensor | Supported data types | Scale granularity options |
|--------|---------------------|--------------------------|
| Q | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (32 tokens), per-thread (4 tokens) |
| K | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (64 tokens), per-thread (16 tokens) |
| V | fp8 | per-channel (always) |
| O | bf16 | — |

Three precision combinations are supported: `fp8/bf16` (QKV fp8, O bf16), `i8/fp8/bf16` (QK int8, V fp8, O bf16), and `i4/fp8/bf16` (QK int4, V fp8, O bf16).

### Architecture support

- **gfx9** (CDNA2/3, e.g. gfx90a, gfx942) — full tile set
- **gfx950** (CDNA4) — restricted tile set (N-per-block capped at 64 for fp8-family dtypes)

### Implementation

- Two pipeline variants: `QRKSVS` (synchronous) and `QRKSVS_ASYNC` (async copy)
- Masking support: no mask, causal (top-left / bottom-right), and generic windowed
- Batch and group (variable-length) modes
- Head dimension: d=128, d_v=128
- Python codegen under `example/ck_tile/49_sageattention/codegen/` generates kernel instances per target/dtype/tile combination
- Smoke tests included via `tile_example_sageattn_fwd`

### Test commands

```bash
# fp8 QKV
./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=fp8bf16 -qscale=3 -init=3

# int8 QK, fp8 V
./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=i8fp8bf16 -qscale=3 -init=3
```

`-qscale` values: 1=per-tensor, 2=per-block, 3=per-warp, 4=per-thread
75 lines
1.9 KiB
C++
75 lines
1.9 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include <ostream>
#include <stdexcept>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
|
|
|
|
// Selector for the quantization-scale granularity.
// Keep in sync with BlockSageAttentionQuantScaleEnum.
// NOTE(review): the numeric values are presumably expected to match that enum
// one-to-one -- confirm against its definition before changing any value.
enum class quant_scale_enum
{
    no_scale   = 0, // no scale
    pertensor  = 1, // per-tensor scale
    blockscale = 2, // per-block scale
    perwarp    = 3, // per-warp scale
    perthread  = 4, // per-thread scale
};
|
|
|
|
struct quant_scale_info
|
|
{
|
|
quant_scale_enum type;
|
|
|
|
void serialize(std::ostream& os) const
|
|
{
|
|
if(type == quant_scale_enum::no_scale)
|
|
os << "n";
|
|
else if(type == quant_scale_enum::pertensor)
|
|
os << "pt";
|
|
else if(type == quant_scale_enum::blockscale)
|
|
os << "bs";
|
|
else if(type == quant_scale_enum::perwarp)
|
|
os << "pw";
|
|
else if(type == quant_scale_enum::perthread)
|
|
os << "pth";
|
|
}
|
|
|
|
static quant_scale_info decode(std::string str)
|
|
{
|
|
quant_scale_info info{quant_scale_enum::no_scale};
|
|
if(str == "n" || str == "0")
|
|
{
|
|
info.type = quant_scale_enum::no_scale;
|
|
}
|
|
else if(str == "pt" || str == "1")
|
|
{
|
|
info.type = quant_scale_enum::pertensor;
|
|
}
|
|
else if(str == "bs" || str == "2")
|
|
{
|
|
info.type = quant_scale_enum::blockscale;
|
|
}
|
|
else if(str == "pw" || str == "3")
|
|
{
|
|
info.type = quant_scale_enum::perwarp;
|
|
}
|
|
else if(str == "pth" || str == "4")
|
|
{
|
|
info.type = quant_scale_enum::perthread;
|
|
}
|
|
else
|
|
{
|
|
throw std::invalid_argument("invalid quant scale value: " + str);
|
|
}
|
|
return info;
|
|
}
|
|
|
|
friend std::ostream& operator<<(std::ostream& os, const quant_scale_info& qsi)
|
|
{
|
|
qsi.serialize(os);
|
|
return os;
|
|
}
|
|
};
|