mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#6574 (commit b3db057)
[CK_TILE] Add SageAttention v2 forward kernel with multi-granularity quantization (#6574) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Add a CK_TILE forward kernel implementing [SageAttention v2](https://arxiv.org/abs/2411.10958) — an attention algorithm that applies multi-granularity quantization to Q/K/V before computing attention, trading minimal accuracy loss for higher throughput on low-precision hardware. ### Quantization design | Tensor | Supported data types | Scale granularity options | |--------|---------------------|--------------------------| | Q | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (32 tokens), per-thread (4 tokens) | | K | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (64 tokens), per-thread (16 tokens) | | V | fp8 | per-channel (always) | | O | bf16 | — | Three precision combinations are supported: `fp8/bf16` (QKV fp8, O bf16), `i8/fp8/bf16` (QK int8, V fp8, O bf16), and `i4/fp8/bf16` (QK int4, V fp8, O bf16). ### Architecture support - **gfx9** (CDNA2/3, e.g. gfx90a, gfx942) — full tile set - **gfx950** (CDNA4) — restricted tile set (N-per-block capped at 64 for fp8-family dtypes) ### Implementation - Two pipeline variants: `QRKSVS` (synchronous) and `QRKSVS_ASYNC` (async copy) - Masking support: no mask, causal (top-left / bottom-right), and generic windowed - Batch and group (variable-length) modes - Head dimension: d=128, d_v=128 - Python codegen under `example/ck_tile/49_sageattention/codegen/` generates kernel instances per target/dtype/tile combination - Smoke tests included via `tile_example_sageattn_fwd` ### Test commands \`\`\`bash # fp8 QKV ./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=fp8bf16 -qscale=3 -init=3 # int8 QK, fp8 V ./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=i8fp8bf16 -qscale=3 -init=3 \`\`\` \`-qscale\` values: 1=per-tensor, 2=per-block, 3=per-warp, 4=per-thread
This commit is contained in:
committed by
assistant-librarian[bot]
parent
e8d64ad5c6
commit
de0a61e5c2
95
example/ck_tile/49_sageattention/CMakeLists.txt
Normal file
95
example/ck_tile/49_sageattention/CMakeLists.txt
Normal file
@@ -0,0 +1,95 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
# Currently only gfx9 arch is supported
|
||||
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
|
||||
if(NOT INST_TARGETS)
|
||||
message(WARNING "Skipping SageAttention compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# ====================================================================
|
||||
# SageAttention codegen - only FWD API, minimal instances
|
||||
# ====================================================================
|
||||
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
|
||||
)
|
||||
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")
|
||||
|
||||
list(JOIN INST_TARGETS , SAGEATTN_TARGETS_ARG)
|
||||
|
||||
# Only generate FWD API, only supported head dimension (128)
|
||||
# Note: Only d=128, d_v=128 has kernel tile definitions in sageattn_fwd.py
|
||||
set(SAGEATTN_FWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--targets ${SAGEATTN_TARGETS_ARG}
|
||||
--api fwd
|
||||
--optdim 128
|
||||
)
|
||||
|
||||
# Generate list of kernels to build
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS}
|
||||
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "SageAttention FAILED to generate kernel list via Python.")
|
||||
endif()
|
||||
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt SAGEATTN_FWD_GEN_BLOBS)
|
||||
|
||||
# Generate the kernel instance files
|
||||
add_custom_command(
|
||||
OUTPUT ${SAGEATTN_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS}
|
||||
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${CODE_GEN_SCRIPTS}
|
||||
COMMENT "Generate SageAttention FWD kernels"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
# Build the kernel instances library
|
||||
add_library(tile_sageattn_fwd_instances OBJECT EXCLUDE_FROM_ALL ${SAGEATTN_FWD_GEN_BLOBS})
|
||||
target_include_directories(tile_sageattn_fwd_instances PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
# Compile options for kernel instances
|
||||
set(SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS)
|
||||
list(APPEND SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS -Wno-undefined-func-template)
|
||||
list(APPEND SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS -Wno-float-equal)
|
||||
list(APPEND SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero)
|
||||
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
target_compile_options(tile_sageattn_fwd_instances PRIVATE ${SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS})
|
||||
set_property(TARGET tile_sageattn_fwd_instances PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
set_property(TARGET tile_sageattn_fwd_instances PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
# ====================================================================
|
||||
# SageAttention FWD Example
|
||||
# ====================================================================
|
||||
set(EXAMPLE_SAGEATTN_FWD "tile_example_sageattn_fwd")
|
||||
|
||||
message(DEBUG "adding example ${EXAMPLE_SAGEATTN_FWD}")
|
||||
|
||||
add_executable(${EXAMPLE_SAGEATTN_FWD} EXCLUDE_FROM_ALL example_sageattn_fwd.cpp)
|
||||
target_include_directories(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
# Link with our own minimal instances library (INDEPENDENT from FMHA!)
|
||||
target_link_libraries(${EXAMPLE_SAGEATTN_FWD} tile_sageattn_fwd_instances)
|
||||
|
||||
set(SAGEATTN_FWD_COMPILE_OPTIONS)
|
||||
list(APPEND SAGEATTN_FWD_COMPILE_OPTIONS -Wno-undefined-func-template)
|
||||
list(APPEND SAGEATTN_FWD_COMPILE_OPTIONS -Wno-float-equal)
|
||||
list(APPEND SAGEATTN_FWD_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero)
|
||||
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND SAGEATTN_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
target_compile_options(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${SAGEATTN_FWD_COMPILE_OPTIONS})
|
||||
set_property(TARGET ${EXAMPLE_SAGEATTN_FWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
2
example/ck_tile/49_sageattention/codegen/__init__.py
Normal file
2
example/ck_tile/49_sageattention/codegen/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
42
example/ck_tile/49_sageattention/codegen/arch.py
Normal file
42
example/ck_tile/49_sageattention/codegen/arch.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Callable
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ArchTrait:
|
||||
name: str
|
||||
preprocessor_check: str = field(default=None)
|
||||
device_name_check: str = field(default=None)
|
||||
tag: str = field(default=None)
|
||||
filename_suffix: str = field(default=None)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.preprocessor_check is None:
|
||||
object.__setattr__(self, "preprocessor_check", f"defined(__{self.name}__)")
|
||||
if self.device_name_check is None:
|
||||
object.__setattr__(
|
||||
self,
|
||||
"device_name_check",
|
||||
f'device_name.compare(0, {len(self.name)}, "{self.name}") == 0',
|
||||
)
|
||||
if self.tag is None:
|
||||
object.__setattr__(self, "tag", f"ck_tile::{self.name}_t")
|
||||
if self.filename_suffix is None:
|
||||
object.__setattr__(self, "filename_suffix", f"_{self.name}")
|
||||
|
||||
|
||||
def get_factories_for_targets(
|
||||
targets: List[str], get_factory: Callable[[str], Any]
|
||||
) -> List[Any]:
|
||||
factories = dict()
|
||||
for target in targets:
|
||||
factory = get_factory(target)
|
||||
factories[factory.arch.name] = factory
|
||||
# Place more specific architectures first
|
||||
factories = sorted(
|
||||
list(factories.values()), key=lambda f: len(f.arch.name), reverse=True
|
||||
)
|
||||
return factories
|
||||
4
example/ck_tile/49_sageattention/codegen/cmake_config.py
Normal file
4
example/ck_tile/49_sageattention/codegen/cmake_config.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
GEN_DIR = "" # in Cmake, have to generate files in same folder
|
||||
103
example/ck_tile/49_sageattention/codegen/cpp_symbol_map.py
Normal file
103
example/ck_tile/49_sageattention/codegen/cpp_symbol_map.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
FWD_DTYPE_MAP = {
|
||||
"fp16": "SageAttentionFwdFp16",
|
||||
"bf16": "SageAttentionFwdBf16",
|
||||
"fp8bf16": "SageAttentionFwdFp8Bf16",
|
||||
"i8fp8bf16": "SageAttentionFwdI8Fp8Bf16",
|
||||
"i4fp8bf16": "SageAttentionFwdI4Fp8Bf16",
|
||||
}
|
||||
|
||||
_MASK_SIMPLIFIED_MAP = {
|
||||
"s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
|
||||
"s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
|
||||
}
|
||||
|
||||
_MASK_MAP = {
|
||||
"no": "SageAttnMasks::NoMask",
|
||||
"causal": "SageAttnMasks::CausalMask",
|
||||
"generic": "SageAttnMasks::GenericMask",
|
||||
}
|
||||
|
||||
|
||||
def get_mask_map(mask_impl: str):
|
||||
if mask_impl == "generic":
|
||||
return _MASK_MAP
|
||||
elif mask_impl == "simplified":
|
||||
return _MASK_SIMPLIFIED_MAP
|
||||
else:
|
||||
assert False
|
||||
return None
|
||||
|
||||
|
||||
def get_mask_impl(mask: str) -> str:
|
||||
return "simplified" if mask.startswith("s_") else "generic"
|
||||
|
||||
|
||||
def get_mask_cpp_type(mask: str) -> str:
|
||||
return get_mask_map(get_mask_impl(mask))[mask]
|
||||
|
||||
|
||||
_MASK_CHECK_MAP = {
|
||||
"no": "t.mask_type == mask_enum::no_mask",
|
||||
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
|
||||
"generic": "t.mask_type == mask_enum::window_generic",
|
||||
}
|
||||
|
||||
_MASK_SIMPLIFIED_CHECK_MAP = {
|
||||
"s_no": "t.mask_type == mask_enum::no_mask",
|
||||
"s_mask": "t.mask_type != mask_enum::no_mask",
|
||||
}
|
||||
|
||||
|
||||
def get_mask_check_map(mask: str):
|
||||
if mask == "generic":
|
||||
return _MASK_CHECK_MAP
|
||||
elif mask == "simplified":
|
||||
return _MASK_SIMPLIFIED_CHECK_MAP
|
||||
else:
|
||||
assert False
|
||||
return None
|
||||
|
||||
|
||||
def get_mask_cpp_check_expr(mask: str) -> str:
|
||||
return get_mask_check_map(get_mask_impl(mask))[mask]
|
||||
|
||||
|
||||
QSCALE_MAP = {
|
||||
"no": "ck_tile::BlockSageAttentionQuantScaleEnum::NO_SCALE",
|
||||
"pertensor": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTENSOR",
|
||||
"blockscale": "ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE",
|
||||
"perwarp": "ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP",
|
||||
"perthread": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD",
|
||||
}
|
||||
|
||||
QSCALE_CHECK_MAP = {
|
||||
"no": "quant_scale_enum::no_scale",
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
"blockscale": "quant_scale_enum::blockscale",
|
||||
"perwarp": "quant_scale_enum::perwarp",
|
||||
"perthread": "quant_scale_enum::perthread",
|
||||
}
|
||||
|
||||
MODE_MAP = {"batch": "false", "group": "true"}
|
||||
|
||||
LAYOUT_MAP = {"row": "true", "col": "false"}
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"qr": "ck_tile::BlockSageAttentionPipelineQRKSVS",
|
||||
"qr_async": "ck_tile::BlockSageAttentionPipelineQRKSVSAsync",
|
||||
}
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS",
|
||||
"qr_async": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS_ASYNC",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
"t": "true",
|
||||
"f": "false",
|
||||
True: "true",
|
||||
False: "false",
|
||||
}
|
||||
2
example/ck_tile/49_sageattention/codegen/ops/__init__.py
Normal file
2
example/ck_tile/49_sageattention/codegen/ops/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
992
example/ck_tile/49_sageattention/codegen/ops/sageattn_fwd.py
Normal file
992
example/ck_tile/49_sageattention/codegen/ops/sageattn_fwd.py
Normal file
@@ -0,0 +1,992 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
import copy
|
||||
import fnmatch
|
||||
import itertools
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable, ClassVar, Iterable, List, Optional, Tuple
|
||||
|
||||
from codegen.arch import ArchTrait, get_factories_for_targets
|
||||
from codegen.cmake_config import GEN_DIR
|
||||
from codegen.cpp_symbol_map import (
|
||||
LAYOUT_MAP,
|
||||
BOOL_MAP,
|
||||
PIPELINE_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
MODE_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
get_mask_map,
|
||||
get_mask_cpp_type,
|
||||
get_mask_cpp_check_expr,
|
||||
QSCALE_CHECK_MAP,
|
||||
QSCALE_MAP,
|
||||
)
|
||||
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8bf16": 8,
|
||||
"i8fp8bf16": 8,
|
||||
"i4fp8bf16": 4,
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32: 32,
|
||||
48: 48,
|
||||
64: 64,
|
||||
80: 96,
|
||||
96: 128,
|
||||
128: 128,
|
||||
192: 192,
|
||||
256: 256,
|
||||
}
|
||||
|
||||
SAGEATTN_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp"
|
||||
#include "sageattn_fwd.hpp"
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_KERNEL_BODY_TEMPLATE = """
|
||||
#include <iostream>
|
||||
|
||||
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
|
||||
|
||||
using sageattn_dtype = {F_dtype};
|
||||
|
||||
using sageattn_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using sageattn_shape = ck_tile::TileSageAttnShape<sageattn_block_tile,
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using sageattn_traits = ck_tile::TileSageAttnTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_qscale},
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
|
||||
using sageattn_variant = ck_tile::ComposedAttention<false * ck_tile::LOGITS_SOFT_CAP, true>;
|
||||
|
||||
using sageattn_mask_type = {F_mask};
|
||||
|
||||
using sageattn_pipeline_problem = ck_tile::BlockSageAttnPipelineProblem<
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::QDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::KDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::VDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::SaccDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::SMPLComputeDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::PDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::OaccDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::ODataType,
|
||||
sageattn_shape,
|
||||
{F_mode},
|
||||
sageattn_variant,
|
||||
sageattn_mask_type,
|
||||
sageattn_traits>;
|
||||
|
||||
using sageattn_pipeline = {F_pipeline}<
|
||||
sageattn_pipeline_problem>;
|
||||
|
||||
using sageattn_epilogue =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename SageAttentionFwdTypeConfig<sageattn_dtype>::OaccDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using sageattn_kernel = {F_kernel}<sageattn_pipeline, sageattn_epilogue>;
|
||||
|
||||
|
||||
using trait = sageattn_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, sageattn_mask_type, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
|
||||
template<>
|
||||
float sageattn_fwd_<trait, {F_arch.tag}>(const ck_tile::stream_config& s, sageattn_fwd_args a)
|
||||
{{
|
||||
using k_ = sageattn_kernel;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", {F_kname}" << std::flush;
|
||||
auto [kargs, grids] = {F_kargs_creator}<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_FILENAME = "sageattn_fwd_api.cpp"
|
||||
SAGEATTN_FWD_API_HEADER = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include <cstdio>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "sageattn_fwd.hpp"
|
||||
|
||||
namespace {
|
||||
bool get_num_cus(unsigned& num_cus) {
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}
|
||||
|
||||
hipDeviceProp_t props{};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}
|
||||
|
||||
num_cus = props.multiProcessorCount;
|
||||
return true;
|
||||
}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}
|
||||
} // namespace
|
||||
"""
|
||||
SAGEATTN_FWD_API_FUNC_TEMPLATE = """
|
||||
namespace {{
|
||||
float {F_func_name}([[maybe_unused]] sageattn_fwd_traits t, [[maybe_unused]] sageattn_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if(!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
}} // namespace
|
||||
"""
|
||||
SAGEATTN_FWD_API_FOOTER_TEMPLATE = """
|
||||
// Public API entry point - unified for SageAttention
|
||||
float sageattn_fwd(sageattn_fwd_traits traits, sageattn_fwd_args args, const ck_tile::stream_config& config) {
|
||||
return sageattn_fwd_impl(traits, args, config);
|
||||
}
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{
|
||||
{F_dtype_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_PER_DTYPE = """{F_if}(t.data_type.compare(\"{F_dtype}\") == 0) {{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = sageattn_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
return sageattn_fwd_<trait_, {F_arch.tag}>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return "true"
|
||||
else:
|
||||
return f"{self.bool_expr}"
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f"({str(self)}) && ({str(other)})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SageAttnFwdApiTrait:
|
||||
arch: ArchTrait
|
||||
pipeline_tag: str
|
||||
# sync with sageattn_fwd_traits<>, to generate fallback calls
|
||||
hdim: str
|
||||
dtype: str # data type
|
||||
mode: str # value from MODE_MAP
|
||||
bm0: int # tile size along q seqlen (block size)
|
||||
bn0: int # tile size along qk seqlen
|
||||
bk0: int # tile size along qk gemm unroll
|
||||
bn1: int # tile size along v head_dim
|
||||
bk1: int # tile size along kv gemm unroll
|
||||
bk0max: int
|
||||
vlayout: str
|
||||
mask: str
|
||||
qscale: str #
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
skip: str
|
||||
constraint: CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.mask}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
|
||||
)
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == "qr_async":
|
||||
return "true"
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
if self.spad == "t":
|
||||
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.seqlen_q % {self.bm0} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
def seqtune(self, max_bm0: int) -> str:
|
||||
if self.bm0 == max_bm0 or self.bm0 == 64:
|
||||
return "true/*fall back to largest tile*/"
|
||||
else:
|
||||
return f"a.seqlen_q <= {self.bm0}"
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == "group":
|
||||
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == "qr_async":
|
||||
if self.skpad == "t":
|
||||
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
|
||||
else:
|
||||
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
if self.skpad == "t":
|
||||
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.pipeline_tag == "qr_async":
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == "t":
|
||||
return f"a.hdim_q % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == "t":
|
||||
return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.hdim_q % {bk0submax} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.pipeline_tag == "qr_async":
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == "t":
|
||||
return f"a.hdim_v % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr", "qs"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == "t":
|
||||
return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
# F_dvpad="f": Causal mask requires hdim_v <= kN1 (num_tile_n1 == 1 for tile reversal)
|
||||
# Non-causal requires hdim_v % kN1 == 0 (epilogue writes full tiles)
|
||||
if self.mask == "causal":
|
||||
return f"(a.hdim_v % {bk0submax} == 0) && (a.hdim_v <= {self.bn1})"
|
||||
else:
|
||||
return (
|
||||
f"(a.hdim_v % {bk0submax} == 0) && (a.hdim_v % {self.bn1} == 0)"
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
@dataclass
|
||||
class SageAttnFwdPipeline:
|
||||
tag: str
|
||||
|
||||
F_vlayout: str # row/col
|
||||
F_spad: str # true/false
|
||||
F_skpad: str #
|
||||
F_dpad: str #
|
||||
F_dvpad: str #
|
||||
F_qscale: str # no/pertensor/blockscale/perwarp/perthread
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_skip: str # true/false
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ""
|
||||
if self.F_spad == "t":
|
||||
n += "s"
|
||||
if self.F_skpad == "t":
|
||||
n += "sk"
|
||||
if self.F_dpad == "t":
|
||||
n += "d"
|
||||
if self.F_dvpad == "t":
|
||||
n += "dv"
|
||||
if n != "":
|
||||
n = "p" + n
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f"{self.tag}_v{self.F_vlayout[0]}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
else:
|
||||
n += "_npad"
|
||||
|
||||
if self.F_mask[0:2] == "s_":
|
||||
if self.F_mask == "s_mask":
|
||||
n += "_mask"
|
||||
else:
|
||||
n += "_nmask"
|
||||
else:
|
||||
if self.F_mask != "no":
|
||||
n += f"_m{self.F_mask[0]}"
|
||||
else:
|
||||
n += "_nmask"
|
||||
|
||||
if self.F_skip == "t":
|
||||
n += "_skip"
|
||||
else:
|
||||
n += "_nskip"
|
||||
|
||||
if self.F_qscale != "no":
|
||||
n += f"_{self.F_qscale}"
|
||||
else:
|
||||
n += "_nqscale"
|
||||
|
||||
return n
|
||||
|
||||
|
||||
class SageAttnFwdApiPool:
|
||||
def __init__(self):
|
||||
self.pool = OrderedDict()
|
||||
|
||||
def register_traits(self, trait: SageAttnFwdApiTrait) -> None:
|
||||
hdim = trait.hdim, trait.bn1
|
||||
ts = (
|
||||
self.pool.setdefault(trait.arch, OrderedDict())
|
||||
.setdefault(trait.dtype, OrderedDict())
|
||||
.setdefault(hdim, [])
|
||||
)
|
||||
check_duplicates_and_paddings(ts, trait)
|
||||
ts.append(copy.copy(trait))
|
||||
|
||||
def get_num_traits(
|
||||
self, filter_fn: Optional[Callable[[SageAttnFwdApiTrait], bool]] = None
|
||||
) -> int:
|
||||
if filter_fn is None:
|
||||
|
||||
def accept_all(trait: SageAttnFwdApiTrait) -> bool:
|
||||
return True
|
||||
|
||||
filter_fn = accept_all
|
||||
|
||||
return sum(
|
||||
sum(1 for trait in pool_by_hdim if filter_fn(trait))
|
||||
for pool_by_arch in self.pool.values()
|
||||
for pool_by_dtype in pool_by_arch.values()
|
||||
for pool_by_hdim in pool_by_dtype.values()
|
||||
)
|
||||
|
||||
def render(
|
||||
self,
|
||||
func_name,
|
||||
filter_fn: Optional[Callable[[SageAttnFwdApiTrait], bool]] = None,
|
||||
) -> str:
|
||||
if filter_fn is None:
|
||||
|
||||
def accept_all(trait: SageAttnFwdApiTrait) -> bool:
|
||||
return True
|
||||
|
||||
filter_fn = accept_all
|
||||
|
||||
def has_traits(node) -> bool:
|
||||
"""Recursively traverse nested OrderedDicts and lists to determine if any SageAttnFwdApiTrait satisfies filter_fn()."""
|
||||
if isinstance(node, list):
|
||||
return any(filter_fn(elem) for elem in node)
|
||||
elif isinstance(node, OrderedDict):
|
||||
return any(has_traits(val) for val in node.values())
|
||||
return False
|
||||
|
||||
per_arch = str()
|
||||
for i_arch, (arch, pool_by_arch) in enumerate(
|
||||
item for item in self.pool.items() if has_traits(item[1])
|
||||
):
|
||||
per_dtypes = str()
|
||||
for i_dtype, (dtype, pool_by_dtype) in enumerate(
|
||||
item for item in pool_by_arch.items() if has_traits(item[1])
|
||||
):
|
||||
per_hdim_case = str()
|
||||
for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate(
|
||||
item for item in pool_by_dtype.items() if has_traits(item[1])
|
||||
):
|
||||
max_bm0 = max(
|
||||
(t.bm0 for t in pool_by_hdim if filter_fn(t)), default=0
|
||||
)
|
||||
inners = str()
|
||||
for i_trait, trait in enumerate(
|
||||
[trait for trait in pool_by_hdim if filter_fn(trait)]
|
||||
):
|
||||
inners += SAGEATTN_FWD_API_INNER_DISPATCH.format(
|
||||
F_if=if_(i_trait),
|
||||
F_arch=arch,
|
||||
F_mode=MODE_MAP[trait.mode],
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
|
||||
F_mask=get_mask_cpp_type(trait.mask),
|
||||
F_mask_check=get_mask_cpp_check_expr(trait.mask),
|
||||
F_skip=BOOL_MAP[trait.skip],
|
||||
F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
|
||||
F_qscale=QSCALE_MAP[trait.qscale],
|
||||
F_scheck=trait.scheck,
|
||||
F_seqtune=trait.seqtune(max_bm0),
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0,
|
||||
F_bn0=trait.bn0,
|
||||
F_bk0=trait.bk0,
|
||||
F_bn1=trait.bn1,
|
||||
F_bk1=trait.bk1,
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
per_hdim_case += SAGEATTN_FWD_API_PER_HDIM_CASE.format(
|
||||
F_if=if_(i_hdim),
|
||||
F_hdim=hdim,
|
||||
F_hdim_v=hdim_v,
|
||||
F_inner_dispatch=indent(inners),
|
||||
)
|
||||
per_dtypes += SAGEATTN_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
|
||||
)
|
||||
per_arch += SAGEATTN_FWD_API_PER_ARCH.format(
|
||||
F_if=if_(i_arch),
|
||||
F_arch=arch,
|
||||
F_dtype_case=indent(per_dtypes),
|
||||
)
|
||||
return SAGEATTN_FWD_API_FUNC_TEMPLATE.format(
|
||||
F_func_name=func_name, F_dispatch=indent(per_arch)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SageAttnFwdTileSize:
|
||||
F_bm0: int # tile size along q seqlen (block size)
|
||||
F_bn0: int # tile size along k seqlen
|
||||
F_bk0: int # tile size along qk gemm unroll
|
||||
F_bn1: int # tile size along v head_dim
|
||||
F_bk1: int # tile size along kv gemm unroll
|
||||
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0: int # number of warps for gemm0 along q seqlen
|
||||
F_rn0: int # number of warps for gemm0 along k seqlen
|
||||
F_rk0: int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1: int # number of warps for gemm1 along q seqlen
|
||||
F_rn1: int # number of warps for gemm1 along head dim v
|
||||
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0: int # gemm0 warp size along m
|
||||
F_wn0: int # gemm0 warp size along n
|
||||
F_wk0: int # gemm0 warp size along k
|
||||
F_wm1: int # gemm1 warp size along m
|
||||
F_wn1: int # gemm1 warp size along n
|
||||
F_wk1: int # gemm1 warp size along k
|
||||
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
|
||||
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
|
||||
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
|
||||
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SageAttnFwdKernel:
|
||||
F_arch: ArchTrait
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_mode: str # value from MODE_MAP
|
||||
F_tile: SageAttnFwdTileSize
|
||||
F_pipeline: SageAttnFwdPipeline
|
||||
|
||||
_KERNEL_HEADER: ClassVar[str] = SAGEATTN_FWD_KERNEL_HEADER
|
||||
_KERNEL_BODY_TEMPLATE: ClassVar[str] = SAGEATTN_FWD_KERNEL_BODY_TEMPLATE
|
||||
|
||||
@classmethod
|
||||
def _get_cpp_kernel_class_name(cls, pipeline_tag):
|
||||
return "ck_tile::SageAttnFwdKernel"
|
||||
|
||||
@classmethod
|
||||
def _get_cpp_kargs_creator_func_name(cls, pipeline_tag):
|
||||
return "sageattn_fwd_create_kargs_and_grids"
|
||||
|
||||
def render(self) -> str:
|
||||
return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format(
|
||||
F_kname=self.name,
|
||||
F_arch=self.F_arch,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0=self.F_tile.F_bm0,
|
||||
F_bn0=self.F_tile.F_bn0,
|
||||
F_bk0=self.F_tile.F_bk0,
|
||||
F_bn1=self.F_tile.F_bn1,
|
||||
F_bk1=self.F_tile.F_bk1,
|
||||
F_bk0max=self.F_tile.F_bk0max,
|
||||
F_rm0=self.F_tile.F_rm0,
|
||||
F_rn0=self.F_tile.F_rn0,
|
||||
F_rk0=self.F_tile.F_rk0,
|
||||
F_rm1=self.F_tile.F_rm1,
|
||||
F_rn1=self.F_tile.F_rn1,
|
||||
F_rk1=self.F_tile.F_rk1,
|
||||
F_wm0=self.F_tile.F_wm0,
|
||||
F_wn0=self.F_tile.F_wn0,
|
||||
F_wk0=self.F_tile.F_wk0,
|
||||
F_wm1=self.F_tile.F_wm1,
|
||||
F_wn1=self.F_tile.F_wn1,
|
||||
F_wk1=self.F_tile.F_wk1,
|
||||
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
|
||||
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_cpp_type(self.F_pipeline.F_mask),
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag),
|
||||
F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag),
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"sageattn_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return f"{self.name}{self.F_arch.filename_suffix}.cpp"
|
||||
|
||||
def api_trait(self) -> SageAttnFwdApiTrait:
|
||||
return SageAttnFwdApiTrait(
|
||||
arch=self.F_arch,
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
qscale=self.F_pipeline.F_qscale,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProblemContext:
|
||||
dtype: str
|
||||
mode: str
|
||||
hdim: int
|
||||
hdim_v: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class KernelContext:
|
||||
tile: SageAttnFwdTileSize
|
||||
pipeline: SageAttnFwdPipeline
|
||||
mask_impl: str
|
||||
|
||||
|
||||
CompatibilityRule = Callable[[ProblemContext, KernelContext], bool]
|
||||
|
||||
|
||||
def is_compatible(
|
||||
problem_ctx: ProblemContext,
|
||||
kernel_ctx: KernelContext,
|
||||
rules: Iterable[CompatibilityRule],
|
||||
) -> bool:
|
||||
return all(rule(problem_ctx, kernel_ctx) for rule in rules)
|
||||
|
||||
|
||||
def create_kernel(
|
||||
arch: ArchTrait, problem_ctx: ProblemContext, kernel_ctx: KernelContext
|
||||
) -> SageAttnFwdKernel:
|
||||
return SageAttnFwdKernel(
|
||||
F_arch=arch,
|
||||
F_dtype=problem_ctx.dtype,
|
||||
F_mode=problem_ctx.mode,
|
||||
F_hdim=problem_ctx.hdim,
|
||||
F_tile=kernel_ctx.tile,
|
||||
F_pipeline=kernel_ctx.pipeline,
|
||||
)
|
||||
|
||||
|
||||
class CompatibilityRuleFactory:
|
||||
@staticmethod
|
||||
def get_rules() -> List[CompatibilityRule]:
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
|
||||
if problem_ctx.mode == "group":
|
||||
if (
|
||||
kernel_ctx.pipeline.F_spad != "t"
|
||||
or kernel_ctx.pipeline.F_skpad != "t"
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
return [check_mode]
|
||||
|
||||
|
||||
class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
|
||||
_AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"})
|
||||
|
||||
@classmethod
|
||||
def get_rules(cls) -> List[CompatibilityRule]:
|
||||
rules = CompatibilityRuleFactory.get_rules()
|
||||
|
||||
return rules
|
||||
|
||||
|
||||
class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9):
|
||||
pass
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
arch = ArchTrait(
|
||||
"gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)"
|
||||
)
|
||||
|
||||
# Note: fp16 is not supported by SageAttention (only bf16 + fp8/int quantization)
|
||||
_DT_BF16 = ("bf16",)
|
||||
_DT_FP8BF16 = ("fp8bf16",)
|
||||
_DT_I8FP8BF16 = ("i8fp8bf16",)
|
||||
_DT_I4FP8BF16 = ("i4fp8bf16",)
|
||||
|
||||
@classmethod
|
||||
def supported_dtypes(cls) -> Tuple[str]:
|
||||
return cls._DT_BF16 + cls._DT_FP8BF16 + cls._DT_I8FP8BF16 + cls._DT_I4FP8BF16
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
@classmethod
|
||||
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
|
||||
if dtype in cls._DT_BF16:
|
||||
return {
|
||||
(128, 128) : [SageAttnFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
} # fmt: skip
|
||||
elif (
|
||||
dtype in cls._DT_FP8BF16
|
||||
or dtype in cls._DT_I8FP8BF16
|
||||
or dtype in cls._DT_I4FP8BF16
|
||||
):
|
||||
# gfx9 (non-gfx950): only F_bn0=64; F_bn0=128 variant is gfx950-only (see Gfx950).
|
||||
return {
|
||||
(128, 128): [
|
||||
SageAttnFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), # fmt: skip
|
||||
],
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"unsupported dtype={dtype}")
|
||||
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
@classmethod
|
||||
def get_pipelines(
|
||||
cls, dtype, hdim, hdim_v, receipt, mask_impl
|
||||
) -> List[SageAttnFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let "t" padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
pipelines = []
|
||||
if dtype in cls._DT_BF16:
|
||||
qscale = "no"
|
||||
skip = "f" # skip: only false
|
||||
for mask, vlayout in itertools.product(
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["row", "col"],
|
||||
):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "f", "f", "f", "f", qscale, mask, skip)) # fmt: skip
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
|
||||
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip
|
||||
else:
|
||||
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "f", "t", "t", qscale, mask, skip)) # fmt: skip
|
||||
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip
|
||||
elif (
|
||||
dtype in cls._DT_FP8BF16
|
||||
or dtype in cls._DT_I8FP8BF16
|
||||
or dtype in cls._DT_I4FP8BF16
|
||||
):
|
||||
# no need lse kernels
|
||||
skip = "f" # skip: only false
|
||||
for mask, qscale, vlayout in itertools.product(
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no", "pertensor", "blockscale", "perwarp", "perthread"],
|
||||
["row", "col"], # Support both row and col major layouts
|
||||
):
|
||||
if dtype in cls._DT_I4FP8BF16:
|
||||
# int4 only uses sync pipeline (qr), pad_d="f" because packed types
|
||||
# require alignment >= PackedSize which conflicts with kPadHeadDimQ=true
|
||||
# forcing alignment to 1. Safe since hdim always matches tile size.
|
||||
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "f", "f", "f", qscale, mask, skip)) # fmt: skip
|
||||
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
|
||||
elif hdim == 64:
|
||||
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "f", "f", "f", qscale, mask, skip)) # fmt: skip
|
||||
pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
|
||||
else:
|
||||
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "f", "t", "t", qscale, mask, skip)) # fmt: skip
|
||||
pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip
|
||||
|
||||
# Packed types (int4) cannot use head-dim padding: the tile_window infrastructure
|
||||
# forces alignment=1 when padding is enabled, but packed types need alignment >= PackedSize.
|
||||
if dtype in cls._DT_I4FP8BF16:
|
||||
for p in pipelines:
|
||||
assert p.F_dpad == "f", (
|
||||
f"int4 dtype '{dtype}' requires pad_d=false, got '{p.F_dpad}'"
|
||||
)
|
||||
assert p.F_dvpad == "f", (
|
||||
f"int4 dtype '{dtype}' requires pad_dv=false, got '{p.F_dvpad}'"
|
||||
)
|
||||
|
||||
return pipelines
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx950(
|
||||
KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx950
|
||||
):
|
||||
arch = ArchTrait("gfx950")
|
||||
|
||||
@classmethod
|
||||
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
|
||||
if (
|
||||
dtype in cls._DT_FP8BF16
|
||||
or dtype in cls._DT_I8FP8BF16
|
||||
or dtype in cls._DT_I4FP8BF16
|
||||
):
|
||||
# gfx950 fp8-family: F_bn0=128 tile only (gfx9 uses F_bn0=64 only).
|
||||
return {
|
||||
(128, 128): [
|
||||
SageAttnFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), # fmt: skip
|
||||
],
|
||||
}
|
||||
return super().get_hdim_tile_size_dict(dtype)
|
||||
|
||||
|
||||
class CustomFactory(KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9):
|
||||
@classmethod
|
||||
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
|
||||
result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
|
||||
if dtype in cls._DT_BF16:
|
||||
if (128, 128) in result.keys():
|
||||
result[(128, 128)].insert(0, SageAttnFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
|
||||
return result
|
||||
|
||||
|
||||
def get_factory(target: str):
|
||||
if os.environ.get("CK_TILE_SAGEATTN_FWD_CUSTOM_FACTORY", "0") == "1":
|
||||
return CustomFactory
|
||||
|
||||
# Place more specific architectures first
|
||||
|
||||
if target.startswith("gfx950"):
|
||||
return KernelComponentFactoryGfx950
|
||||
if target.startswith("gfx9"):
|
||||
return KernelComponentFactoryGfx9
|
||||
|
||||
raise Exception(f"Unsupported device target {target}")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Product:
|
||||
name: str
|
||||
rule: CompatibilityRule
|
||||
|
||||
def __call__(self, problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
|
||||
return self.rule(problem_ctx, kernel_ctx)
|
||||
|
||||
|
||||
def get_product(receipt: int) -> Product:
|
||||
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
|
||||
# bf16 (no quantization) should not have qscale
|
||||
if problem_ctx.dtype == "bf16":
|
||||
if kernel_ctx.pipeline.F_qscale != "no":
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
return Product(name="All tiles", rule=fit)
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
) -> Tuple[SageAttnFwdApiPool, List[SageAttnFwdKernel]]:
|
||||
gen = list()
|
||||
api_pool = SageAttnFwdApiPool()
|
||||
|
||||
factories = get_factories_for_targets(targets, get_factory)
|
||||
|
||||
for factory, dtype in ((f, t) for f in factories for t in f.supported_dtypes()):
|
||||
d = factory.get_hdim_tile_size_dict(dtype)
|
||||
# for hdim_str, mode, mask, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for ((hdim, hdim_v), tiles), mode in itertools.product(
|
||||
d.items(), MODE_MAP.keys()
|
||||
):
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
for tile, next_tile in zip(tiles, tiles[1:]):
|
||||
assert next_tile.F_bm0 >= tile.F_bm0, (
|
||||
"Tiles must be ordered by increasing bm0"
|
||||
)
|
||||
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
|
||||
):
|
||||
problem_ctx = ProblemContext(
|
||||
dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v
|
||||
)
|
||||
kernel_ctx = KernelContext(
|
||||
tile=tile, pipeline=pipeline, mask_impl=mask_impl
|
||||
)
|
||||
rules = factory.get_rules()
|
||||
product = get_product(receipt)
|
||||
|
||||
if not is_compatible(problem_ctx, kernel_ctx, [*rules, product]):
|
||||
continue
|
||||
|
||||
k = create_kernel(factory.arch, problem_ctx, kernel_ctx)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: SageAttnFwdKernel, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / kernel.filename, kernel.render())
|
||||
|
||||
|
||||
def write_fwd_api(
|
||||
api_pool: SageAttnFwdApiPool,
|
||||
autogen_dir: Path,
|
||||
) -> None:
|
||||
content = "".join(
|
||||
[
|
||||
SAGEATTN_FWD_API_HEADER,
|
||||
api_pool.render("sageattn_fwd_impl"),
|
||||
SAGEATTN_FWD_API_FOOTER_TEMPLATE,
|
||||
]
|
||||
)
|
||||
update_file(autogen_dir / SAGEATTN_FWD_API_FILENAME, content)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
targets: List[str],
|
||||
output_dir: Path,
|
||||
kernel_filter: str,
|
||||
receipt,
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(
|
||||
targets, kernel_filter, receipt, optdim_list, mask_impl
|
||||
)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
|
||||
def list_blobs(
|
||||
targets: List[str],
|
||||
file_path: Path,
|
||||
kernel_filter: str,
|
||||
receipt,
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(
|
||||
targets, kernel_filter, receipt, optdim_list, mask_impl
|
||||
)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write(
|
||||
(file_path.parent / GEN_DIR / SAGEATTN_FWD_API_FILENAME).as_posix() + "\n"
|
||||
)
|
||||
70
example/ck_tile/49_sageattention/codegen/utils.py
Normal file
70
example/ck_tile/49_sageattention/codegen/utils.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
import dataclasses
|
||||
import os.path as path
|
||||
import textwrap
|
||||
|
||||
|
||||
def update_file(file_path, content):
|
||||
"""Update the file at file_path with the given content if it differs from the existing content.
|
||||
|
||||
It avoids unnecessary touching of the file which triggers rebuilds
|
||||
"""
|
||||
|
||||
existing_content = ""
|
||||
if path.exists(file_path):
|
||||
with open(file_path, "r") as file:
|
||||
existing_content = file.read()
|
||||
if existing_content == content:
|
||||
return
|
||||
with open(file_path, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
|
||||
def indent(code: str, indent: str = " ") -> str:
|
||||
return textwrap.indent(code, indent)
|
||||
|
||||
|
||||
def if_(i: int) -> str:
|
||||
return "if" if i == 0 else "else if"
|
||||
|
||||
|
||||
def check_duplicates_and_paddings(traits, trait):
|
||||
"""Check
|
||||
* if the traits list does not contain a trait with the same parameters;
|
||||
* if paddings are consitent: the previous kernel can be incorrectly called before the new one,
|
||||
for example, f, _t_, f, t cannot be before f, _f_, f, t.
|
||||
"""
|
||||
|
||||
fields = [f.name for f in dataclasses.fields(trait)]
|
||||
pad_fields = [f for f in fields if "pad" in f]
|
||||
non_pad_fields = [f for f in fields if "pad" not in f]
|
||||
for prev_trait in traits:
|
||||
if any(getattr(trait, f) != getattr(prev_trait, f) for f in non_pad_fields):
|
||||
continue
|
||||
if all(getattr(trait, f) == getattr(prev_trait, f) for f in pad_fields):
|
||||
raise Exception(f"Duplicate found {trait}")
|
||||
# Check if the previous kernel can be incorrectly used before the current one
|
||||
# for example, f, _t_, f, t cannot be before f, _f_, f, t
|
||||
is_prev_more_restrictive = False
|
||||
is_curr_more_restrictive = False
|
||||
for f in pad_fields:
|
||||
prev_pad = getattr(prev_trait, f)
|
||||
pad = getattr(trait, f)
|
||||
if isinstance(prev_pad, str):
|
||||
prev_pad = 1000000 if prev_pad == "f" else 1
|
||||
pad = 1000000 if pad == "f" else 1
|
||||
elif isinstance(prev_pad, int):
|
||||
prev_pad = 1000000 if prev_pad == 0 else prev_pad
|
||||
pad = 1000000 if pad == 0 else pad
|
||||
else:
|
||||
assert False
|
||||
if prev_pad < pad:
|
||||
is_prev_more_restrictive = True
|
||||
elif prev_pad > pad:
|
||||
is_curr_more_restrictive = True
|
||||
if is_prev_more_restrictive and not is_curr_more_restrictive:
|
||||
raise Exception(
|
||||
f"Kernel will never be used because paddings are not ordered correctly:\n{prev_trait} supersedes\n{trait}"
|
||||
)
|
||||
202
example/ck_tile/49_sageattention/example_sageattn_fwd.cpp
Normal file
202
example/ck_tile/49_sageattention/example_sageattn_fwd.cpp
Normal file
@@ -0,0 +1,202 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "sageattn_fwd.hpp"
|
||||
#include "sageattn_fwd_runner.hpp"
|
||||
|
||||
#include <string>
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
|
||||
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
|
||||
.insert("b", "2", "batch size")
|
||||
.insert("h", "8", "num of head, for q")
|
||||
.insert("h_k",
|
||||
"-1",
|
||||
"num of head, for k/v, -1 means equal to h\n"
|
||||
"if not equal to h, then this is GQA/MQA case")
|
||||
.insert("s",
|
||||
"3328",
|
||||
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
|
||||
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
|
||||
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
|
||||
"(group mode)")
|
||||
.insert("s_k",
|
||||
"-1",
|
||||
"seqlen_k (including new key/value), -1 means equal to s\n"
|
||||
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
|
||||
"(group mode)")
|
||||
.insert("s_qpad",
|
||||
"-1",
|
||||
"seqlen_q stride between 2 batches (group-mode optional).\n"
|
||||
"Provide positive strides per-batch to simulate physical padding on Q.")
|
||||
.insert("s_kpad",
|
||||
"-1",
|
||||
"seqlen_k stride between 2 batches, currently used in group-mode only\n"
|
||||
"for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
|
||||
"along seqlen, instead of packed, same as xformer kv_padding,\n"
|
||||
"must be greater than or equal to s_k")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("qscale",
|
||||
"n",
|
||||
"n or 0, no scale\n"
|
||||
"pt or 1, per-tensor scale\n"
|
||||
"bs or 2, block scale (Q:128, KV:128)\n"
|
||||
"pw or 3, per-warp scale (Q:32, KV:64)\n"
|
||||
"pth or 4, per-thread scale (Q:4, KV:16)\n")
|
||||
.insert("iperm",
|
||||
"1",
|
||||
"permute input\n"
|
||||
"if true, will be b*h*s*d, else b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("prec",
|
||||
"fp8bf16",
|
||||
"Primary: fp8bf16, i8fp8bf16, i4fp8bf16. Also bf16 (keep): pipeline validation "
|
||||
"with qscale=n (no quant); not the quantized Sage product path.")
|
||||
.insert("mask",
|
||||
"0",
|
||||
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
|
||||
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
|
||||
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
|
||||
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
|
||||
"'xt:window_size', xformer style masking from top-left, window_size negative is "
|
||||
"causal, positive is swa\n"
|
||||
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
|
||||
"causal, positive is swa\n"
|
||||
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
|
||||
"now)")
|
||||
.insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
|
||||
.insert("kname", "0", "if set to 1 will print kernel name")
|
||||
.insert("init",
|
||||
"uf",
|
||||
"init method:\n ui or 0 - uniform random int\n ni - normalized random int"
|
||||
"\n uf or 1 - uniform random float\n nf - normalized random float"
|
||||
"\n tf or 2 - trig float"
|
||||
"\n tf or 3 - uniform random float, min max is the max of the type\n")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
"non-deterministic seed")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "sageattn_fwd.json", "json file name to dump results")
|
||||
.insert("q_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.")
|
||||
.insert("kv_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataTypeConfig>
|
||||
auto run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
mode_enum mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
auto seqlen_qs = arg_parser.get_int_vec("s");
|
||||
auto seqlen_ks = arg_parser.get_int_vec("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
|
||||
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
|
||||
auto q_eff_lens_per_batch = arg_parser.get_int_vec("q_eff_lens");
|
||||
auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
float scale_s = arg_parser.get_float("scale_s");
|
||||
bool is_v_rowmajor = arg_parser.get_str("vlayout") == "r";
|
||||
std::string qscale_str = arg_parser.get_str("qscale");
|
||||
std::string mask_str = arg_parser.get_str("mask");
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
/* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0),
|
||||
arg_parser.get_int("warmup"),
|
||||
arg_parser.get_int("repeat"),
|
||||
arg_parser.get_str("timer") == std::string("gpu")};
|
||||
|
||||
auto json = arg_parser.get_int("json") == 1
|
||||
? std::optional<std::string>{arg_parser.get_str("jsonfile")}
|
||||
: std::nullopt;
|
||||
|
||||
return sageattn_fwd_run<DataTypeConfig>(mode,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_qs,
|
||||
seqlen_ks,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
seqlen_qpads,
|
||||
seqlen_kpads,
|
||||
q_eff_lens_per_batch,
|
||||
kv_eff_lens_per_batch,
|
||||
i_perm,
|
||||
o_perm,
|
||||
scale_s,
|
||||
is_v_rowmajor,
|
||||
mask_str,
|
||||
qscale_str,
|
||||
init_method,
|
||||
seed,
|
||||
do_validation,
|
||||
stream_config,
|
||||
json);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
return run<SageAttentionFwdBf16>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8bf16")
|
||||
{
|
||||
return run<SageAttentionFwdFp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "i8fp8bf16")
|
||||
{
|
||||
return run<SageAttentionFwdI8Fp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "i4fp8bf16")
|
||||
{
|
||||
return run<SageAttentionFwdI4Fp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
std::cerr << "Unsupported precision: " << data_type << std::endl;
|
||||
return -1;
|
||||
}
|
||||
catch(const std::invalid_argument& e)
|
||||
{
|
||||
std::cerr << "Invalid argument: " << e.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << std::endl;
|
||||
return -2;
|
||||
}
|
||||
}
|
||||
173
example/ck_tile/49_sageattention/generate.py
Normal file
173
example/ck_tile/49_sageattention/generate.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import pkgutil
|
||||
from typing import List, Optional
|
||||
|
||||
import codegen.ops
|
||||
from codegen.cmake_config import GEN_DIR
|
||||
|
||||
|
||||
class HandlerId(IntEnum):
|
||||
LIST_BLOBS = 0
|
||||
WRITE_BLOBS = 1
|
||||
|
||||
|
||||
# inspect all modules under 'codegen.ops' and register API handlers
|
||||
ops = []
|
||||
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
|
||||
full_module_name = "%s.%s" % (codegen.ops.__name__, module_name)
|
||||
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
|
||||
# Strip "sageattn_" so module sageattn_fwd registers as CLI key "fwd".
|
||||
unwanted_prefix = "sageattn_"
|
||||
handlers = dict(
|
||||
[
|
||||
(
|
||||
(
|
||||
op.__name__[len(unwanted_prefix) :]
|
||||
if op.__name__.startswith(unwanted_prefix)
|
||||
else op.__name__
|
||||
),
|
||||
(op.list_blobs, op.write_blobs),
|
||||
)
|
||||
for op in ops
|
||||
]
|
||||
)
|
||||
assert 0 < len(handlers)
|
||||
|
||||
|
||||
def write_blobs(
|
||||
targets: List[str],
|
||||
output_dir: Optional[str],
|
||||
api_list: List[str],
|
||||
filters_list: List[str],
|
||||
optdim_list: List[int],
|
||||
receipt,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
if output_dir is None:
|
||||
output_dir = Path(__file__).parent
|
||||
else:
|
||||
output_dir = Path(output_dir) / GEN_DIR
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for api, kernel_filter in zip(api_list, filters_list):
|
||||
handler = handlers[api][HandlerId.WRITE_BLOBS]
|
||||
handler(targets, output_dir, kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
|
||||
# list all the files that will be generated
|
||||
def list_blobs(
|
||||
targets: List[str],
|
||||
output_file: Optional[str],
|
||||
api_list: List[str],
|
||||
filters_list: List[str],
|
||||
optdim_list: List[int],
|
||||
receipt,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
assert output_file is not None
|
||||
file_path = Path(output_file)
|
||||
|
||||
# create an empty file / drop its contents if it exists
|
||||
open(file_path, "w").close()
|
||||
|
||||
for api, kernel_filter in zip(api_list, filters_list):
|
||||
handler = handlers[api][HandlerId.LIST_BLOBS]
|
||||
handler(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate",
|
||||
description="Generate SageAttention CK_tile kernel/API blobs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--targets",
|
||||
default="gfx9,gfx950",
|
||||
required=False,
|
||||
help="list of GPU targets, separated by comma.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a",
|
||||
"--api",
|
||||
default="fwd",
|
||||
required=False,
|
||||
help="Codegen API key(s), comma-separated (e.g. fwd -> module codegen.ops.sageattn_fwd).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output_dir",
|
||||
required=False,
|
||||
help="write all the blobs into a directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l", "--list_blobs", required=False, help="list all the kernels to a file"
|
||||
)
|
||||
# TODO: if using filter, must apply same value to output_dir and list_blobs
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--filter",
|
||||
default="",
|
||||
required=False,
|
||||
help="filter out kernels that need to generate, using fnmatch module",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--mask",
|
||||
default="simplified",
|
||||
required=False,
|
||||
help="mask implementation, simplified/generic",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--receipt",
|
||||
default=0,
|
||||
required=False,
|
||||
help="Codegen receipt index. SageAttention forward currently uses receipt 0 only; "
|
||||
"the value is passed through to ops (see get_product in sageattn_fwd.py).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--optdim",
|
||||
default="-1",
|
||||
required=False,
|
||||
help="only optimize the hdim in the list. separated by comma. -1 is the default choice. "
|
||||
"e.g. --optdim=32,64,128,256",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
targets = args.targets.split(",")
|
||||
api_list = args.api.split(",")
|
||||
filter_list = args.filter.split(",")
|
||||
filter_list.extend([""] * (len(api_list) - len(filter_list)))
|
||||
optdim_list = [int(hdim) for hdim in args.optdim.split(",")]
|
||||
|
||||
if args.list_blobs is not None:
|
||||
list_blobs(
|
||||
targets,
|
||||
args.list_blobs,
|
||||
api_list,
|
||||
filter_list,
|
||||
optdim_list,
|
||||
int(args.receipt),
|
||||
mask_impl=args.mask,
|
||||
)
|
||||
else:
|
||||
write_blobs(
|
||||
targets,
|
||||
args.output_dir,
|
||||
api_list,
|
||||
filter_list,
|
||||
optdim_list,
|
||||
int(args.receipt),
|
||||
mask_impl=args.mask,
|
||||
)
|
||||
169
example/ck_tile/49_sageattention/mask.hpp
Normal file
169
example/ck_tile/49_sageattention/mask.hpp
Normal file
@@ -0,0 +1,169 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
|
||||
// keep this in sync with ck_tile::GenericAttentionMaskEnum
|
||||
enum class mask_enum
|
||||
{
|
||||
no_mask = 0,
|
||||
mask_top_left,
|
||||
mask_bottom_right,
|
||||
window_generic,
|
||||
};
|
||||
|
||||
struct mask_info
|
||||
{
|
||||
mask_enum type;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t y, x;
|
||||
ck_tile::index_t left, right; // FA style SWA left/right
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == mask_enum::no_mask)
|
||||
os << "n";
|
||||
else if(type == mask_enum::mask_top_left)
|
||||
os << "t(" << left << ":" << right << ")";
|
||||
else if(type == mask_enum::mask_bottom_right)
|
||||
os << "b(" << left << ":" << right << ")";
|
||||
else
|
||||
{
|
||||
os << "g(" << y << ":" << x << ")";
|
||||
}
|
||||
}
|
||||
|
||||
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
|
||||
{
|
||||
ck_tile::index_t x_total = seqlen_k;
|
||||
ck_tile::index_t y_total = seqlen_q;
|
||||
mask_info tmp;
|
||||
tmp.seqlen_q = seqlen_q;
|
||||
tmp.seqlen_k = seqlen_k;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string t = str.substr(0, found_0);
|
||||
std::string v = str.substr(found_0 + 1);
|
||||
if(t == "xt" || t == "xb")
|
||||
{
|
||||
// xformer style sliding window attn from top-left
|
||||
ck_tile::index_t window_size = std::stoi(v);
|
||||
ck_tile::index_t left_size = -1;
|
||||
ck_tile::index_t right_size = 0;
|
||||
if(window_size > 0)
|
||||
{
|
||||
left_size = window_size / 2;
|
||||
right_size = window_size - 1 - left_size;
|
||||
}
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, 0, y_total, x_total, t == "xt");
|
||||
|
||||
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = left_size;
|
||||
tmp.right = right_size;
|
||||
}
|
||||
else if(t == "t" || t == "b" || t == "g")
|
||||
{
|
||||
auto found_1 = v.find(",");
|
||||
if(found_1 == std::string::npos)
|
||||
{
|
||||
throw std::invalid_argument("invalid mask value: " + str);
|
||||
}
|
||||
tmp.type = mask_enum::window_generic;
|
||||
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
|
||||
ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||
// TODO: some validation
|
||||
if(t == "t")
|
||||
{
|
||||
tmp.type = mask_enum::mask_top_left;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, 0, y_total, x_total, true);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
}
|
||||
else if(t == "b")
|
||||
{
|
||||
tmp.type = mask_enum::mask_bottom_right;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, 0, y_total, x_total, false);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
}
|
||||
else if(t == "g")
|
||||
{
|
||||
tmp.type = mask_enum::window_generic;
|
||||
tmp.y = v0;
|
||||
tmp.x = v1;
|
||||
tmp.left = v0; // TODO: don't use this?
|
||||
tmp.right = v1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid mask value: " + str);
|
||||
}
|
||||
}
|
||||
else if(str == "0")
|
||||
{
|
||||
tmp.type = mask_enum::no_mask;
|
||||
}
|
||||
else if(str == "1" || str == "t")
|
||||
{
|
||||
tmp.type = mask_enum::mask_top_left;
|
||||
tmp.y = seqlen_q;
|
||||
tmp.x = 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
}
|
||||
else if(str == "2" || str == "b")
|
||||
{
|
||||
tmp.type = mask_enum::mask_bottom_right;
|
||||
tmp.y = seqlen_q;
|
||||
tmp.x = seqlen_k - seqlen_q + 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid mask value: " + str);
|
||||
}
|
||||
return tmp;
|
||||
}
|
||||
|
||||
std::size_t get_unmaskarea() const
|
||||
{
|
||||
if(type == mask_enum::no_mask)
|
||||
return static_cast<std::size_t>(seqlen_q) * seqlen_k;
|
||||
std::size_t area = 0;
|
||||
for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
|
||||
{
|
||||
ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
|
||||
ck_tile::index_t x_end = std::min(i_y + x, seqlen_k);
|
||||
if(x_end > x_start)
|
||||
{
|
||||
area += (x_end - x_start);
|
||||
}
|
||||
}
|
||||
return area;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
|
||||
{
|
||||
mi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
74
example/ck_tile/49_sageattention/quant.hpp
Normal file
74
example/ck_tile/49_sageattention/quant.hpp
Normal file
@@ -0,0 +1,74 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
|
||||
|
||||
// keep sync with BlockSageAttentionQuantScaleEnum
|
||||
enum class quant_scale_enum
|
||||
{
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale = 2,
|
||||
perwarp = 3,
|
||||
perthread = 4,
|
||||
};
|
||||
|
||||
struct quant_scale_info
|
||||
{
|
||||
quant_scale_enum type;
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == quant_scale_enum::no_scale)
|
||||
os << "n";
|
||||
else if(type == quant_scale_enum::pertensor)
|
||||
os << "pt";
|
||||
else if(type == quant_scale_enum::blockscale)
|
||||
os << "bs";
|
||||
else if(type == quant_scale_enum::perwarp)
|
||||
os << "pw";
|
||||
else if(type == quant_scale_enum::perthread)
|
||||
os << "pth";
|
||||
}
|
||||
|
||||
static quant_scale_info decode(std::string str)
|
||||
{
|
||||
quant_scale_info info{quant_scale_enum::no_scale};
|
||||
if(str == "n" || str == "0")
|
||||
{
|
||||
info.type = quant_scale_enum::no_scale;
|
||||
}
|
||||
else if(str == "pt" || str == "1")
|
||||
{
|
||||
info.type = quant_scale_enum::pertensor;
|
||||
}
|
||||
else if(str == "bs" || str == "2")
|
||||
{
|
||||
info.type = quant_scale_enum::blockscale;
|
||||
}
|
||||
else if(str == "pw" || str == "3")
|
||||
{
|
||||
info.type = quant_scale_enum::perwarp;
|
||||
}
|
||||
else if(str == "pth" || str == "4")
|
||||
{
|
||||
info.type = quant_scale_enum::perthread;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid quant scale value: " + str);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const quant_scale_info& qsi)
|
||||
{
|
||||
qsi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
384
example/ck_tile/49_sageattention/sageattn_fwd.hpp
Normal file
384
example/ck_tile/49_sageattention/sageattn_fwd.hpp
Normal file
@@ -0,0 +1,384 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/sageattn.hpp"
|
||||
|
||||
#include "mask.hpp"
|
||||
#include "quant.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
// SageAttention data type configs (must match codegen FWD_DTYPE_MAP + SageAttentionFwdTypeConfig)
|
||||
struct SageAttentionFwdFp16
|
||||
{
|
||||
};
|
||||
|
||||
struct SageAttentionFwdBf16
|
||||
{
|
||||
};
|
||||
|
||||
struct SageAttentionFwdFp8Bf16
|
||||
{
|
||||
};
|
||||
|
||||
struct SageAttentionFwdI8Fp8Bf16
|
||||
{
|
||||
};
|
||||
|
||||
struct SageAttentionFwdI4Fp8Bf16
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct SageAttentionFwdTypeConfig;
|
||||
|
||||
// fp16/bf16 are not Sage product dtypes; bf16 is intentionally kept in tile_example_sageattn_fwd
|
||||
// for pipeline validation with qscale=n (no quant).
|
||||
template <>
|
||||
struct SageAttentionFwdTypeConfig<SageAttentionFwdFp16>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using ScaleType = float; // scale type for quantized inputs
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct SageAttentionFwdTypeConfig<SageAttentionFwdBf16>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using ScaleType = float; // scale type for quantized inputs
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct SageAttentionFwdTypeConfig<SageAttentionFwdFp8Bf16>
|
||||
{
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using ScaleType = float; // scale type for quantized inputs
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct SageAttentionFwdTypeConfig<SageAttentionFwdI8Fp8Bf16>
|
||||
{
|
||||
using QDataType = ck_tile::int8_t;
|
||||
using KDataType = ck_tile::int8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using ScaleType = float; // scale type for Q and K
|
||||
using SaccDataType = float; // Keep as float for softmax computation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // P in FP8 for 2nd gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct SageAttentionFwdTypeConfig<SageAttentionFwdI4Fp8Bf16>
|
||||
{
|
||||
using QDataType = ck_tile::pk_int4_t;
|
||||
using KDataType = ck_tile::pk_int4_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using ScaleType = float;
|
||||
using SaccDataType = float;
|
||||
using SMPLComputeDataType = float;
|
||||
using PDataType = ck_tile::fp8_t;
|
||||
using OaccDataType = float;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
struct SageAttnMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
|
||||
// runtime args, some will passed to karg, some will used to compute grids/blocks
|
||||
struct sageattn_fwd_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* q_descale_ptr;
|
||||
const void* k_descale_ptr;
|
||||
const void* v_descale_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
// Usage notes for sequence length pointer parameters:
|
||||
//
|
||||
// [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles
|
||||
// MQA/GQA..."]
|
||||
//
|
||||
// With padding:
|
||||
// Group mode:
|
||||
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
|
||||
// lengths. [array size: batch + 1]
|
||||
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
|
||||
// sequence. [array size: batch]
|
||||
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
|
||||
// sequence lengths. [array size: batch + 1]
|
||||
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
|
||||
// exclusive. Use one set, not both.
|
||||
//
|
||||
// Batch mode:
|
||||
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
|
||||
// sequence lengths. [array size: batch + 1]
|
||||
// - seqstart_* and seqlen_* pointers must be nullptr.
|
||||
//
|
||||
// Without padding:
|
||||
// (Note: Physical length equals logical length)
|
||||
//
|
||||
// Group mode:
|
||||
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
|
||||
// size: batch + 1]
|
||||
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
|
||||
//
|
||||
// Batch mode:
|
||||
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
|
||||
//
|
||||
const void* seqstart_q_ptr =
|
||||
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
|
||||
const void* seqstart_k_ptr =
|
||||
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
|
||||
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
|
||||
// [batch]. (Used in Group mode with padding)
|
||||
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
|
||||
// [batch]. (Used in Group mode with padding)
|
||||
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
float scale_s;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
ck_tile::index_t min_seqlen_q;
|
||||
|
||||
// BLOCKSCALE parameters
|
||||
ck_tile::index_t nhead_stride_q_descale = 0;
|
||||
ck_tile::index_t nhead_stride_k_descale = 0;
|
||||
ck_tile::index_t nhead_stride_v_descale = 0;
|
||||
ck_tile::index_t batch_stride_q_descale = 0;
|
||||
ck_tile::index_t batch_stride_k_descale = 0;
|
||||
ck_tile::index_t batch_stride_v_descale = 0;
|
||||
ck_tile::index_t block_scale_size_q = 0;
|
||||
ck_tile::index_t block_scale_size_k = 0;
|
||||
const void* block_scale_seqstart_q_ptr = nullptr;
|
||||
const void* block_scale_seqstart_k_ptr = nullptr;
|
||||
};
|
||||
|
||||
template <typename SageAttnKernel>
|
||||
auto sageattn_fwd_create_kargs_and_grids(sageattn_fwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(SageAttnKernel::kIsGroupMode)
|
||||
{
|
||||
return SageAttnKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.q_descale_ptr,
|
||||
args.k_descale_ptr,
|
||||
args.v_descale_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale,
|
||||
args.batch_stride_v_descale,
|
||||
args.block_scale_size_q,
|
||||
args.block_scale_size_k,
|
||||
args.block_scale_seqstart_q_ptr,
|
||||
args.block_scale_seqstart_k_ptr,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return SageAttnKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.q_descale_ptr,
|
||||
args.k_descale_ptr,
|
||||
args.v_descale_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_o,
|
||||
args.batch_stride_q_descale,
|
||||
args.batch_stride_k_descale,
|
||||
args.batch_stride_v_descale,
|
||||
args.block_scale_size_q,
|
||||
args.block_scale_size_k,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(SageAttnKernel::kIsGroupMode)
|
||||
{
|
||||
dim3 grids = SageAttnKernel::GridSize(
|
||||
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
else
|
||||
{
|
||||
dim3 grids = SageAttnKernel::GridSize(
|
||||
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
}
|
||||
|
||||
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockSageAttnPipelineEnum SageAttnPipelineEnum_,
|
||||
typename AttnMask_,
|
||||
ck_tile::BlockSageAttentionQuantScaleEnum QScaleEnum_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kSkipMinSeqlenQ_ = false>
|
||||
struct sageattn_fwd_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN0 = kN0_;
|
||||
static constexpr ck_tile::index_t kK0 = kK0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
static constexpr ck_tile::index_t kK1 = kK1_;
|
||||
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr auto SageAttnPipelineEnum = SageAttnPipelineEnum_;
|
||||
using AttnMask = ck_tile::remove_cvref_t<AttnMask_>;
|
||||
static constexpr auto QScaleEnum = QScaleEnum_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
float sageattn_fwd_(const ck_tile::stream_config&, sageattn_fwd_args);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct sageattn_fwd_traits
|
||||
{
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
bool is_v_rowmajor;
|
||||
mask_enum mask_type;
|
||||
quant_scale_enum qscale_type;
|
||||
bool skip_min_seqlen_q = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float sageattn_fwd(sageattn_fwd_traits, sageattn_fwd_args, const ck_tile::stream_config&);
|
||||
1154
example/ck_tile/49_sageattention/sageattn_fwd_runner.hpp
Normal file
1154
example/ck_tile/49_sageattention/sageattn_fwd_runner.hpp
Normal file
File diff suppressed because it is too large
Load Diff
162
example/ck_tile/49_sageattention/script/smoke_test_sageattn_fwd.sh
Executable file
162
example/ck_tile/49_sageattention/script/smoke_test_sageattn_fwd.sh
Executable file
@@ -0,0 +1,162 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
# SageAttention forward smoke tests - structure mirrors
|
||||
# example/ck_tile/01_fmha/script/smoke_test_fwd.sh
|
||||
#
|
||||
# Run from the ComposableKernel *build* directory (after ninja), same as FMHA:
|
||||
# cd build && ninja tile_example_sageattn_fwd
|
||||
# bash ../example/ck_tile/49_sageattention/script/smoke_test_sageattn_fwd.sh
|
||||
#
|
||||
# Optional: VERBOSE=1 enables bash -x. CURR_FAILS_FILE / KNOWN_FAILS_FILE override fail logs.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
||||
EXE_NAME=tile_example_sageattn_fwd
|
||||
EXE="$(find . -name "$EXE_NAME" -type f 2>/dev/null | head -n 1)"
|
||||
KNAME=1
|
||||
GPU_arch=${GPU_arch:-}
|
||||
if [ -z "$GPU_arch" ]; then
|
||||
GPU_arch=$(rocminfo 2>/dev/null | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}' || echo "unknown")
|
||||
fi
|
||||
|
||||
export CK_WARMUP=0
|
||||
export CK_REPEAT=1
|
||||
|
||||
CURR_FAILS_FILE=${CURR_FAILS_FILE:-"sageattn_fwd_fails_${GPU_arch}.txt"}
|
||||
rm -f "$CURR_FAILS_FILE"
|
||||
touch "$CURR_FAILS_FILE"
|
||||
KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/sageattn_fwd_known_fails_${GPU_arch}.txt"}
|
||||
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
|
||||
if [ -z "${EXE:-}" ] || [ ! -x "$EXE" ]; then
|
||||
echo "ERROR: $EXE_NAME not found under cwd ($(pwd)). Build with: ninja $EXE_NAME" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
run_exe() {
|
||||
set +e
|
||||
$EXE "$@"
|
||||
local ret=$?
|
||||
if [ $ret -ne 0 ]; then
|
||||
echo "$EXE_NAME $*" >>"$CURR_FAILS_FILE"
|
||||
fi
|
||||
set -e
|
||||
}
|
||||
|
||||
# Core FP8xBF16 cases aligned with FMHA smoke_test_fwd.sh (lines 80-87): batch/group shapes,
|
||||
# masks, GQA, short seqlen, k-only pad. Sweeps blockscale (2) vs per-warp (3) and layouts.
|
||||
run_fp8bf16_smoke() {
|
||||
local qscale
|
||||
local perm
|
||||
for qscale in 2 3; do
|
||||
for perm in 0 1; do
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=2 -h_k=1 -d=128 -d_v=128 -s=55 -s_k=256 \
|
||||
-mask=1
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=3 -d=128 -s=100 -s_k=51 -mask=0
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=1 -d=128 -d_v=128 -s=99 -s_k=256 \
|
||||
-mask=1
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1024 -s_k=256 \
|
||||
-mask=2
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=1 -d=128 -d_v=128 -s=3 -s_k=99 -mask=2
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=3 -h=2 -h_k=1 -d=128 -s=200 -s_k=520 \
|
||||
-mask=t:128,30
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=1 -d=128 -s=99 -s_k=32 -mask=b:4,35
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=33 -s_k=0 -mask=2
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 \
|
||||
-s_kpad=32 -mask=2
|
||||
done
|
||||
done
|
||||
}
|
||||
|
||||
# Extra FP8: explicit causal string, xformer window, per-tensor / per-thread quant, V col-major.
|
||||
run_fp8bf16_extras() {
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=4 -h=8 -d=128 -s=1024 -s_k=1024 -mask=t:-1,0
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=1 -operm=1 -vlayout=c -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=2 -h=4 -d=128 -s=256 -s_k=256 -mask=t
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=256 -s_k=256 -mask=xt:64
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=128 -s_k=128 -mask=0
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=4 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=64 -s_k=64 -mask=0
|
||||
}
|
||||
|
||||
# Group mode + physical padding (same intent as FMHA run_padding_smoke_tests, Sage-only flags).
|
||||
run_group_and_padding_smoke() {
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=1 -b=3 -h=2 -h_k=1 -d=128 -s=50,60,40 -s_k=128,256,192 -mask=1
|
||||
# group + PERTHREAD: block_scale_seqstart_* must be allocated (same as bs/pw)
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=4 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=1 -b=3 -h=2 -h_k=1 -d=128 -s=50,60,40 -s_k=128,256,192 -mask=1
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=1 -b=4 -h=8 -h_k=8 -d=128 -s=1024,768,512,256 -s_k=1024,768,512,256 \
|
||||
-mask=0 -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=4 -h=8 -d=128 -s=1024 -s_k=1024 -mask=0 \
|
||||
-q_eff_lens=960,512,384,256 -kv_eff_lens=960,512,384,256
|
||||
}
|
||||
|
||||
# BF16 (no quant): pipeline sanity only; not a shipped Sage mode (see example --help prec).
|
||||
run_bf16_pipeline_smoke() {
|
||||
run_exe -prec=bf16 -init=1 -qscale=n -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=2 -h=2 -d=128 -s=128 -s_k=128 -mask=1
|
||||
run_exe -prec=bf16 -init=1 -qscale=n -iperm=1 -operm=1 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=1 -h=4 -h_k=1 -d=128 -s=256 -s_k=128 -mask=t:32,32
|
||||
}
|
||||
|
||||
# int8 / int4 x fp8xbf16 (hdim divisible by 8 for int4)
|
||||
run_int_quant_smoke() {
|
||||
run_exe -prec=i8fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=2 -h=2 -d=128 -s=128 -s_k=128 -mask=1
|
||||
run_exe -prec=i4fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=128 -s_k=128 -mask=t
|
||||
}
|
||||
|
||||
if [ "${VERBOSE:-0}" = 1 ]; then
|
||||
set -x
|
||||
fi
|
||||
|
||||
run_fp8bf16_smoke
|
||||
run_fp8bf16_extras
|
||||
run_group_and_padding_smoke
|
||||
run_bf16_pipeline_smoke
|
||||
run_int_quant_smoke
|
||||
|
||||
set +x
|
||||
|
||||
new_fails_count=0
|
||||
known_fails_count=0
|
||||
if [ -f "$KNOWN_FAILS_FILE" ]; then
|
||||
echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):"
|
||||
while IFS= read -r line; do
|
||||
if grep -Fxq "$line" "$KNOWN_FAILS_FILE"; then
|
||||
echo "Known fail: $line"
|
||||
known_fails_count=$((known_fails_count + 1))
|
||||
else
|
||||
echo "New fail: $line"
|
||||
new_fails_count=$((new_fails_count + 1))
|
||||
fi
|
||||
done <"$CURR_FAILS_FILE"
|
||||
else
|
||||
new_fails_count=$(wc -l <"$CURR_FAILS_FILE")
|
||||
echo "No known fails file, all fails ($new_fails_count) are new:"
|
||||
if [ "$new_fails_count" -gt 0 ]; then
|
||||
cat "$CURR_FAILS_FILE"
|
||||
fi
|
||||
fi
|
||||
echo "New fails count: $new_fails_count; Known fails count: $known_fails_count"
|
||||
exit $((new_fails_count != 0))
|
||||
254
example/ck_tile/49_sageattention/utils.hpp
Normal file
254
example/ck_tile/49_sageattention/utils.hpp
Normal file
@@ -0,0 +1,254 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/core/container/span.hpp"
|
||||
|
||||
enum class mode_enum
|
||||
{
|
||||
batch = 0,
|
||||
group
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& stream, mode_enum mode)
|
||||
{
|
||||
return stream << (mode == mode_enum::batch ? "batch" : "group");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
using size_type = typename std::vector<T>::size_type;
|
||||
|
||||
os << "[";
|
||||
for(size_type idx = 0; idx < v.size(); ++idx)
|
||||
{
|
||||
if(0 < idx)
|
||||
{
|
||||
os << ", ";
|
||||
}
|
||||
os << v[idx];
|
||||
}
|
||||
return os << "]";
|
||||
}
|
||||
|
||||
inline std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
|
||||
{
|
||||
std::vector<int32_t> seqstarts = {0};
|
||||
for(int32_t seqlen : seqlens)
|
||||
{
|
||||
seqstarts.push_back(seqstarts.back() + seqlen);
|
||||
}
|
||||
assert(seqstarts.size() == seqlens.size() + 1);
|
||||
return seqstarts;
|
||||
}
|
||||
|
||||
template <typename RandomEngine>
|
||||
std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
unsigned count,
|
||||
int32_t seqlen_avg,
|
||||
int32_t seqlen_min, // if not negative, clamp min
|
||||
int32_t seqlen_max, // if not negative, clamp max
|
||||
RandomEngine& random_engine)
|
||||
{
|
||||
assert(0 < count);
|
||||
|
||||
seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
|
||||
seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
|
||||
assert(seqlen_min <= seqlen_max);
|
||||
|
||||
std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));
|
||||
|
||||
if(mode == mode_enum::group && 1 < count)
|
||||
{
|
||||
using size_type = std::vector<int32_t>::size_type;
|
||||
|
||||
std::uniform_int_distribution<size_type> idx_dist(0, count - 1);
|
||||
auto next_idx = std::bind(idx_dist, std::ref(random_engine));
|
||||
|
||||
std::uniform_int_distribution<size_type> step_dist(1, count - 1);
|
||||
auto next_step = std::bind(step_dist, std::ref(random_engine));
|
||||
|
||||
for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
|
||||
{
|
||||
const size_type to_decrease = next_idx();
|
||||
// make sure each elements of seqlens is in range [seqlen_min, seqlen_max]
|
||||
if(seqlens[to_decrease] == seqlen_min)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
const size_type to_increase = (to_decrease + next_step()) % count;
|
||||
|
||||
if(seqlens[to_increase] >= seqlen_max)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
--seqlens[to_decrease];
|
||||
++seqlens[to_increase];
|
||||
}
|
||||
}
|
||||
|
||||
return seqlens;
|
||||
}
|
||||
|
||||
// return random integer generated uniformly in range [low, high]
|
||||
template <typename Int = int, typename RandomEngine>
|
||||
auto randint(Int low,
|
||||
Int high,
|
||||
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>, Int>
|
||||
{
|
||||
std::uniform_int_distribution<Int> dist(low, high);
|
||||
return dist(random_engine);
|
||||
}
|
||||
|
||||
// return random integers generated uniformly in range [low, high]
|
||||
template <typename Int, typename ForwardIterator, typename RandomEngine>
|
||||
auto randints(ForwardIterator first,
|
||||
ForwardIterator last,
|
||||
Int low,
|
||||
Int high,
|
||||
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>>
|
||||
{
|
||||
std::uniform_int_distribution<Int> dist(low, high);
|
||||
|
||||
std::generate(first, last, [&] { return dist(random_engine); });
|
||||
}
|
||||
|
||||
/*
|
||||
* generate missing values in *_val randomly when the number of values is smaller than batch
|
||||
* example (assume batch=3)
|
||||
* q_val=1,2,3 k_val=4,5,6 -> OK
|
||||
* q_val=1,2,3 -> OK, k same as q
|
||||
* q_val=1,2 -> OK, q will rand remaining 1 element, k same as q
|
||||
* q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element
|
||||
* q_val=1,2,3,4 -> OK, but ignore exceed one
|
||||
*
|
||||
* q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q
|
||||
* q_val=1,2 k_val=4 -> not OK, k must have same splits with q
|
||||
*/
|
||||
template <typename RandomEngine>
|
||||
std::tuple<std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>>
|
||||
generate_missing_seqlens(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
const std::vector<ck_tile::index_t>& q_val,
|
||||
const std::vector<ck_tile::index_t>& k_val,
|
||||
const std::vector<ck_tile::index_t>& q_pad_val,
|
||||
const std::vector<ck_tile::index_t>& k_pad_val,
|
||||
ck_tile::index_t seqlen_k_min,
|
||||
bool need_append_kvcache,
|
||||
RandomEngine& random_engine)
|
||||
{
|
||||
if(mode == mode_enum::batch)
|
||||
{
|
||||
ck_tile::index_t q = q_val[0];
|
||||
ck_tile::index_t k = k_val[0];
|
||||
|
||||
auto s_q = std::vector<ck_tile::index_t>(batch, q);
|
||||
auto s_k = [&] {
|
||||
const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
|
||||
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
|
||||
|
||||
if(1 < batch && need_append_kvcache)
|
||||
{
|
||||
// to keep the original s_k value, we always use seqlen_k_max in first batch
|
||||
randints(std::next(seqlen_ks.begin()),
|
||||
seqlen_ks.end(),
|
||||
seqlen_k_min,
|
||||
seqlen_k_max,
|
||||
random_engine);
|
||||
return seqlen_ks;
|
||||
}
|
||||
|
||||
return seqlen_ks;
|
||||
}();
|
||||
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
|
||||
auto s_qpad = std::vector<ck_tile::index_t>(batch, -1);
|
||||
// s_k should be greater than or equal to seqlen_k_min if provided
|
||||
if(s_k.back() < seqlen_k_min)
|
||||
{
|
||||
std::ostringstream msg;
|
||||
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
|
||||
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<ck_tile::index_t> s_q;
|
||||
std::vector<ck_tile::index_t> s_k;
|
||||
std::vector<ck_tile::index_t> s_kpad;
|
||||
std::vector<ck_tile::index_t> s_qpad;
|
||||
ck_tile::index_t idx = 0;
|
||||
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
|
||||
{
|
||||
ck_tile::index_t q = q_val[idx];
|
||||
ck_tile::index_t k =
|
||||
k_val[std::min(idx, static_cast<ck_tile::index_t>(k_val.size()) - 1)];
|
||||
ck_tile::index_t kp =
|
||||
k_pad_val.empty()
|
||||
? -1
|
||||
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
|
||||
|
||||
ck_tile::index_t qp =
|
||||
q_pad_val.empty()
|
||||
? -1
|
||||
: q_pad_val[std::min(idx, static_cast<ck_tile::index_t>(q_pad_val.size()) - 1)];
|
||||
|
||||
s_q.push_back(q);
|
||||
s_k.push_back(k < 0 ? q : k);
|
||||
s_kpad.push_back(kp);
|
||||
s_qpad.push_back(qp);
|
||||
|
||||
// s_k should be greater than or equal to seqlen_k_min
|
||||
if(s_k.back() < seqlen_k_min)
|
||||
{
|
||||
std::ostringstream msg;
|
||||
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
|
||||
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
if(idx < batch)
|
||||
{
|
||||
auto rem_q =
|
||||
generate_seqlens(mode, batch - idx, s_q.back(), 1, s_q.back(), random_engine);
|
||||
auto rem_k = generate_seqlens(
|
||||
mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), random_engine);
|
||||
|
||||
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
|
||||
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
|
||||
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
|
||||
s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back());
|
||||
}
|
||||
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RandomAccessIterator, typename Int, typename RandomEngine>
|
||||
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
|
||||
RandomAccessIterator last,
|
||||
Int value,
|
||||
RandomEngine& random_engine)
|
||||
{
|
||||
std::iota(first, last, value);
|
||||
std::shuffle(first, last, random_engine);
|
||||
}
|
||||
@@ -31,6 +31,7 @@ add_subdirectory(38_block_scale_gemm)
|
||||
add_subdirectory(40_streamk_gemm)
|
||||
add_subdirectory(41_batched_contraction)
|
||||
add_subdirectory(42_mx_gemm)
|
||||
add_subdirectory(49_sageattention)
|
||||
add_subdirectory(50_sparse_attn)
|
||||
add_subdirectory(51_tile_distr_enc_reg_map)
|
||||
if(BUILD_CK_TILE_CSHUFFLE_LDS_BENCHMARKS)
|
||||
|
||||
Reference in New Issue
Block a user