From c1bf3f6972e33a98ef3af0b7f921bcccb4d74c82 Mon Sep 17 00:00:00 2001 From: ltqin Date: Fri, 1 May 2026 02:32:23 +0800 Subject: [PATCH] [CK_TILE] Add SageAttention v2 forward kernel with multi-granularity quantization (#6574) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Add a CK_TILE forward kernel implementing [SageAttention v2](https://arxiv.org/abs/2411.10958) — an attention algorithm that applies multi-granularity quantization to Q/K/V before computing attention, trading minimal accuracy loss for higher throughput on low-precision hardware. ### Quantization design | Tensor | Supported data types | Scale granularity options | |--------|---------------------|--------------------------| | Q | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (32 tokens), per-thread (4 tokens) | | K | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (64 tokens), per-thread (16 tokens) | | V | fp8 | per-channel (always) | | O | bf16 | — | Three precision combinations are supported: `fp8/bf16` (QKV fp8, O bf16), `i8/fp8/bf16` (QK int8, V fp8, O bf16), and `i4/fp8/bf16` (QK int4, V fp8, O bf16). ### Architecture support - **gfx9** (CDNA2/3, e.g. 
gfx90a, gfx942) — full tile set - **gfx950** (CDNA4) — restricted tile set (N-per-block capped at 64 for fp8-family dtypes) ### Implementation - Two pipeline variants: `QRKSVS` (synchronous) and `QRKSVS_ASYNC` (async copy) - Masking support: no mask, causal (top-left / bottom-right), and generic windowed - Batch and group (variable-length) modes - Head dimension: d=128, d_v=128 - Python codegen under `example/ck_tile/49_sageattention/codegen/` generates kernel instances per target/dtype/tile combination - Smoke tests included via `tile_example_sageattn_fwd` ### Test commands \`\`\`bash # fp8 QKV ./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=fp8bf16 -qscale=3 -init=3 # int8 QK, fp8 V ./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=i8fp8bf16 -qscale=3 -init=3 \`\`\` \`-qscale\` values: 1=per-tensor, 2=per-block, 3=per-warp, 4=per-thread --- .../ck_tile/49_sageattention/CMakeLists.txt | 95 ++ .../49_sageattention/codegen/__init__.py | 2 + .../ck_tile/49_sageattention/codegen/arch.py | 42 + .../49_sageattention/codegen/cmake_config.py | 4 + .../codegen/cpp_symbol_map.py | 103 ++ .../49_sageattention/codegen/ops/__init__.py | 2 + .../codegen/ops/sageattn_fwd.py | 992 ++++++++++++++ .../ck_tile/49_sageattention/codegen/utils.py | 70 + .../49_sageattention/example_sageattn_fwd.cpp | 202 +++ example/ck_tile/49_sageattention/generate.py | 173 +++ example/ck_tile/49_sageattention/mask.hpp | 169 +++ example/ck_tile/49_sageattention/quant.hpp | 74 ++ .../ck_tile/49_sageattention/sageattn_fwd.hpp | 384 ++++++ .../49_sageattention/sageattn_fwd_runner.hpp | 1154 +++++++++++++++++ .../script/smoke_test_sageattn_fwd.sh | 162 +++ example/ck_tile/49_sageattention/utils.hpp | 254 ++++ example/ck_tile/CMakeLists.txt | 1 + include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 6 + .../block_sageattention_quant_scale_enum.hpp | 49 + .../kernel/sageattn_fwd_kernel.hpp | 1026 +++++++++++++++ 
.../pipeline/block_sageattn_pipeline_enum.hpp | 29 + .../block_sageattn_pipeline_problem.hpp | 60 + .../block_sageattn_pipeline_qr_ks_vs.hpp | 861 ++++++++++++ ...block_sageattn_pipeline_qr_ks_vs_async.hpp | 873 +++++++++++++ ...pipeline_qr_ks_vs_async_default_policy.hpp | 18 + ...geattn_pipeline_qr_ks_vs_custom_policy.hpp | 857 ++++++++++++ ...eattn_pipeline_qr_ks_vs_default_policy.hpp | 17 + .../pipeline/tile_sageattn_shape.hpp | 71 + .../pipeline/tile_sageattn_traits.hpp | 42 + include/ck_tile/ops/sageattn.hpp | 17 + 30 files changed, 7809 insertions(+) create mode 100644 example/ck_tile/49_sageattention/CMakeLists.txt create mode 100644 example/ck_tile/49_sageattention/codegen/__init__.py create mode 100644 example/ck_tile/49_sageattention/codegen/arch.py create mode 100644 example/ck_tile/49_sageattention/codegen/cmake_config.py create mode 100644 example/ck_tile/49_sageattention/codegen/cpp_symbol_map.py create mode 100644 example/ck_tile/49_sageattention/codegen/ops/__init__.py create mode 100644 example/ck_tile/49_sageattention/codegen/ops/sageattn_fwd.py create mode 100644 example/ck_tile/49_sageattention/codegen/utils.py create mode 100644 example/ck_tile/49_sageattention/example_sageattn_fwd.cpp create mode 100644 example/ck_tile/49_sageattention/generate.py create mode 100644 example/ck_tile/49_sageattention/mask.hpp create mode 100644 example/ck_tile/49_sageattention/quant.hpp create mode 100644 example/ck_tile/49_sageattention/sageattn_fwd.hpp create mode 100644 example/ck_tile/49_sageattention/sageattn_fwd_runner.hpp create mode 100755 example/ck_tile/49_sageattention/script/smoke_test_sageattn_fwd.sh create mode 100644 example/ck_tile/49_sageattention/utils.hpp create mode 100644 include/ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp create mode 100644 include/ck_tile/ops/sageattention/kernel/sageattn_fwd_kernel.hpp create mode 100644 include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_enum.hpp create mode 
100644 include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_problem.hpp create mode 100644 include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs.hpp create mode 100644 include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp create mode 100644 include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async_default_policy.hpp create mode 100644 include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_custom_policy.hpp create mode 100644 include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_default_policy.hpp create mode 100644 include/ck_tile/ops/sageattention/pipeline/tile_sageattn_shape.hpp create mode 100644 include/ck_tile/ops/sageattention/pipeline/tile_sageattn_traits.hpp create mode 100644 include/ck_tile/ops/sageattn.hpp diff --git a/example/ck_tile/49_sageattention/CMakeLists.txt b/example/ck_tile/49_sageattention/CMakeLists.txt new file mode 100644 index 0000000000..67671f3cf4 --- /dev/null +++ b/example/ck_tile/49_sageattention/CMakeLists.txt @@ -0,0 +1,95 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) +# Currently only gfx9 arch is supported +list(FILTER INST_TARGETS INCLUDE REGEX "gfx9") +if(NOT INST_TARGETS) + message(WARNING "Skipping SageAttention compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + return() +endif() + +# ==================================================================== +# SageAttention codegen - only FWD API, minimal instances +# ==================================================================== +file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py +) +set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}") + +list(JOIN INST_TARGETS , SAGEATTN_TARGETS_ARG) + +# Only generate FWD API, only supported head dimension (128) +# Note: Only d=128, d_v=128 has kernel tile definitions in sageattn_fwd.py +set(SAGEATTN_FWD_CODE_GEN_COMMON_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --targets ${SAGEATTN_TARGETS_ARG} + --api fwd + --optdim 128 +) + +# Generate list of kernels to build +execute_process( + COMMAND ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "SageAttention FAILED to generate kernel list via Python.") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt SAGEATTN_FWD_GEN_BLOBS) + +# Generate the kernel instance files +add_custom_command( + OUTPUT ${SAGEATTN_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} + COMMENT "Generate SageAttention FWD kernels" + VERBATIM +) + +# Build the kernel instances library +add_library(tile_sageattn_fwd_instances OBJECT EXCLUDE_FROM_ALL ${SAGEATTN_FWD_GEN_BLOBS}) 
+target_include_directories(tile_sageattn_fwd_instances PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + +# Compile options for kernel instances +set(SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS) +list(APPEND SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS -Wno-undefined-func-template) +list(APPEND SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS -Wno-float-equal) +list(APPEND SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero) + +if(CK_USE_OCP_FP8) + list(APPEND SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +target_compile_options(tile_sageattn_fwd_instances PRIVATE ${SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS}) +set_property(TARGET tile_sageattn_fwd_instances PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) +set_property(TARGET tile_sageattn_fwd_instances PROPERTY POSITION_INDEPENDENT_CODE ON) + +# ==================================================================== +# SageAttention FWD Example +# ==================================================================== +set(EXAMPLE_SAGEATTN_FWD "tile_example_sageattn_fwd") + +message(DEBUG "adding example ${EXAMPLE_SAGEATTN_FWD}") + +add_executable(${EXAMPLE_SAGEATTN_FWD} EXCLUDE_FROM_ALL example_sageattn_fwd.cpp) +target_include_directories(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + +# Link with our own minimal instances library (INDEPENDENT from FMHA!) 
+target_link_libraries(${EXAMPLE_SAGEATTN_FWD} tile_sageattn_fwd_instances) + +set(SAGEATTN_FWD_COMPILE_OPTIONS) +list(APPEND SAGEATTN_FWD_COMPILE_OPTIONS -Wno-undefined-func-template) +list(APPEND SAGEATTN_FWD_COMPILE_OPTIONS -Wno-float-equal) +list(APPEND SAGEATTN_FWD_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero) + +if(CK_USE_OCP_FP8) + list(APPEND SAGEATTN_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +target_compile_options(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${SAGEATTN_FWD_COMPILE_OPTIONS}) +set_property(TARGET ${EXAMPLE_SAGEATTN_FWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) diff --git a/example/ck_tile/49_sageattention/codegen/__init__.py b/example/ck_tile/49_sageattention/codegen/__init__.py new file mode 100644 index 0000000000..1df4857184 --- /dev/null +++ b/example/ck_tile/49_sageattention/codegen/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT diff --git a/example/ck_tile/49_sageattention/codegen/arch.py b/example/ck_tile/49_sageattention/codegen/arch.py new file mode 100644 index 0000000000..aeb9a98bbb --- /dev/null +++ b/example/ck_tile/49_sageattention/codegen/arch.py @@ -0,0 +1,42 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +from dataclasses import dataclass, field +from typing import Any, List, Callable + + +@dataclass(frozen=True) +class ArchTrait: + name: str + preprocessor_check: str = field(default=None) + device_name_check: str = field(default=None) + tag: str = field(default=None) + filename_suffix: str = field(default=None) + + def __post_init__(self): + if self.preprocessor_check is None: + object.__setattr__(self, "preprocessor_check", f"defined(__{self.name}__)") + if self.device_name_check is None: + object.__setattr__( + self, + "device_name_check", + f'device_name.compare(0, {len(self.name)}, "{self.name}") == 0', + ) + if self.tag is None: + object.__setattr__(self, "tag", f"ck_tile::{self.name}_t") + if self.filename_suffix is None: + object.__setattr__(self, "filename_suffix", f"_{self.name}") + + +def get_factories_for_targets( + targets: List[str], get_factory: Callable[[str], Any] +) -> List[Any]: + factories = dict() + for target in targets: + factory = get_factory(target) + factories[factory.arch.name] = factory + # Place more specific architectures first + factories = sorted( + list(factories.values()), key=lambda f: len(f.arch.name), reverse=True + ) + return factories diff --git a/example/ck_tile/49_sageattention/codegen/cmake_config.py b/example/ck_tile/49_sageattention/codegen/cmake_config.py new file mode 100644 index 0000000000..3399f58947 --- /dev/null +++ b/example/ck_tile/49_sageattention/codegen/cmake_config.py @@ -0,0 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT +# generate kernel instances to speed up compilation +GEN_DIR = "" # in Cmake, have to generate files in same folder diff --git a/example/ck_tile/49_sageattention/codegen/cpp_symbol_map.py b/example/ck_tile/49_sageattention/codegen/cpp_symbol_map.py new file mode 100644 index 0000000000..77b0c262fd --- /dev/null +++ b/example/ck_tile/49_sageattention/codegen/cpp_symbol_map.py @@ -0,0 +1,103 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +# generate kernel instances to speed up compilation +FWD_DTYPE_MAP = { + "fp16": "SageAttentionFwdFp16", + "bf16": "SageAttentionFwdBf16", + "fp8bf16": "SageAttentionFwdFp8Bf16", + "i8fp8bf16": "SageAttentionFwdI8Fp8Bf16", + "i4fp8bf16": "SageAttentionFwdI4Fp8Bf16", +} + +_MASK_SIMPLIFIED_MAP = { + "s_no": "ck_tile::SimplifiedGenericAttentionMask", + "s_mask": "ck_tile::SimplifiedGenericAttentionMask", +} + +_MASK_MAP = { + "no": "SageAttnMasks::NoMask", + "causal": "SageAttnMasks::CausalMask", + "generic": "SageAttnMasks::GenericMask", +} + + +def get_mask_map(mask_impl: str): + if mask_impl == "generic": + return _MASK_MAP + elif mask_impl == "simplified": + return _MASK_SIMPLIFIED_MAP + else: + assert False + return None + + +def get_mask_impl(mask: str) -> str: + return "simplified" if mask.startswith("s_") else "generic" + + +def get_mask_cpp_type(mask: str) -> str: + return get_mask_map(get_mask_impl(mask))[mask] + + +_MASK_CHECK_MAP = { + "no": "t.mask_type == mask_enum::no_mask", + "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", + "generic": "t.mask_type == mask_enum::window_generic", +} + +_MASK_SIMPLIFIED_CHECK_MAP = { + "s_no": "t.mask_type == mask_enum::no_mask", + "s_mask": "t.mask_type != mask_enum::no_mask", +} + + +def get_mask_check_map(mask: str): + if mask == "generic": + return _MASK_CHECK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_CHECK_MAP + else: + assert 
False + return None + + +def get_mask_cpp_check_expr(mask: str) -> str: + return get_mask_check_map(get_mask_impl(mask))[mask] + + +QSCALE_MAP = { + "no": "ck_tile::BlockSageAttentionQuantScaleEnum::NO_SCALE", + "pertensor": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTENSOR", + "blockscale": "ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE", + "perwarp": "ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP", + "perthread": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD", +} + +QSCALE_CHECK_MAP = { + "no": "quant_scale_enum::no_scale", + "pertensor": "quant_scale_enum::pertensor", + "blockscale": "quant_scale_enum::blockscale", + "perwarp": "quant_scale_enum::perwarp", + "perthread": "quant_scale_enum::perthread", +} + +MODE_MAP = {"batch": "false", "group": "true"} + +LAYOUT_MAP = {"row": "true", "col": "false"} + +PIPELINE_MAP = { + "qr": "ck_tile::BlockSageAttentionPipelineQRKSVS", + "qr_async": "ck_tile::BlockSageAttentionPipelineQRKSVSAsync", +} + +PIPELINE_ENUM_MAP = { + "qr": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS", + "qr_async": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS_ASYNC", +} + +BOOL_MAP = { + "t": "true", + "f": "false", + True: "true", + False: "false", +} diff --git a/example/ck_tile/49_sageattention/codegen/ops/__init__.py b/example/ck_tile/49_sageattention/codegen/ops/__init__.py new file mode 100644 index 0000000000..1df4857184 --- /dev/null +++ b/example/ck_tile/49_sageattention/codegen/ops/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT diff --git a/example/ck_tile/49_sageattention/codegen/ops/sageattn_fwd.py b/example/ck_tile/49_sageattention/codegen/ops/sageattn_fwd.py new file mode 100644 index 0000000000..8956594090 --- /dev/null +++ b/example/ck_tile/49_sageattention/codegen/ops/sageattn_fwd.py @@ -0,0 +1,992 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT +# generate kernel instances to speed up compilation +import copy +import fnmatch +import itertools +import os +from collections import OrderedDict +from dataclasses import dataclass, field +from pathlib import Path +from typing import Callable, ClassVar, Iterable, List, Optional, Tuple + +from codegen.arch import ArchTrait, get_factories_for_targets +from codegen.cmake_config import GEN_DIR +from codegen.cpp_symbol_map import ( + LAYOUT_MAP, + BOOL_MAP, + PIPELINE_MAP, + PIPELINE_ENUM_MAP, + MODE_MAP, + FWD_DTYPE_MAP, + get_mask_map, + get_mask_cpp_type, + get_mask_cpp_check_expr, + QSCALE_CHECK_MAP, + QSCALE_MAP, +) +from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file + +DTYPE_BITS = { + "fp16": 16, + "bf16": 16, + "fp8bf16": 8, + "i8fp8bf16": 8, + "i4fp8bf16": 4, +} + +K0_MAX_SUBMAX_MAP = { + 32: 32, + 48: 48, + 64: 64, + 80: 96, + 96: 128, + 128: 128, + 192: 192, + 256: 256, +} + +SAGEATTN_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. 
All rights reserved.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp" +#include "sageattn_fwd.hpp" +""" + +SAGEATTN_FWD_KERNEL_BODY_TEMPLATE = """ +#include <iostream> + +#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) + +using sageattn_dtype = {F_dtype}; + +using sageattn_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using sageattn_shape = ck_tile::TileSageAttnShape<sageattn_block_tile, + ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using sageattn_traits = ck_tile::TileSageAttnTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_qscale}, + {F_occupancy}, + {F_skip}>; + +using sageattn_variant = ck_tile::ComposedAttention; + +using sageattn_mask_type = {F_mask}; + +using sageattn_pipeline_problem = ck_tile::BlockSageAttnPipelineProblem< + typename SageAttentionFwdTypeConfig<sageattn_dtype>::QDataType, + typename SageAttentionFwdTypeConfig<sageattn_dtype>::KDataType, + typename SageAttentionFwdTypeConfig<sageattn_dtype>::VDataType, + typename SageAttentionFwdTypeConfig<sageattn_dtype>::SaccDataType, + typename SageAttentionFwdTypeConfig<sageattn_dtype>::SMPLComputeDataType, + typename SageAttentionFwdTypeConfig<sageattn_dtype>::PDataType, + typename SageAttentionFwdTypeConfig<sageattn_dtype>::OaccDataType, + typename SageAttentionFwdTypeConfig<sageattn_dtype>::ODataType, + sageattn_shape, + {F_mode}, + sageattn_variant, + sageattn_mask_type, + sageattn_traits>; + +using sageattn_pipeline = {F_pipeline}< + sageattn_pipeline_problem>; + +using sageattn_epilogue = + ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename SageAttentionFwdTypeConfig<sageattn_dtype>::OaccDataType, + typename SageAttentionFwdTypeConfig<sageattn_dtype>::ODataType, + {F_spad}, {F_dvpad}>>; + +using sageattn_kernel = {F_kernel}<sageattn_pipeline, sageattn_epilogue>; + + +using trait = sageattn_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, sageattn_mask_type, {F_qscale}, 
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + +template<> +float sageattn_fwd_<trait>(const ck_tile::stream_config& s, sageattn_fwd_args a) +{{ + using k_ = sageattn_kernel; + if(s.log_level_ > 0) + std::cout << ", {F_kname}" << std::flush; + auto [kargs, grids] = {F_kargs_creator}(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)); +}} + +#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) +""" + +SAGEATTN_FWD_API_FILENAME = "sageattn_fwd_api.cpp" +SAGEATTN_FWD_API_HEADER = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include <hip/hip_runtime.h> + +#include <cstdio> + +#include "sageattn_fwd.hpp" + +namespace { +bool get_num_cus(unsigned& num_cus) { + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) { + fprintf(stderr, "failed to get device"); + return false; + } + + hipDeviceProp_t props{}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) { + fprintf(stderr, "failed to get device properties"); + return false; + } + + num_cus = props.multiProcessorCount; + return true; +} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) { + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +} +} // namespace +""" +SAGEATTN_FWD_API_FUNC_TEMPLATE = """ +namespace {{ +float {F_func_name}([[maybe_unused]] sageattn_fwd_traits t, [[maybe_unused]] sageattn_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{ + float r = -1; + + [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if(!get_num_cus(num_cus)) {{ + return r; + 
}} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + + [[maybe_unused]] const std::string device_name = ck_tile::get_device_name(); + +{F_dispatch} + return r; +}} +}} // namespace +""" +SAGEATTN_FWD_API_FOOTER_TEMPLATE = """ +// Public API entry point - unified for SageAttention +float sageattn_fwd(sageattn_fwd_traits traits, sageattn_fwd_args args, const ck_tile::stream_config& config) { + return sageattn_fwd_impl(traits, args, config); +} +""" + +SAGEATTN_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{ +{F_dtype_case} +}} +""" + +SAGEATTN_FWD_API_PER_DTYPE = """{F_if}(t.data_type.compare(\"{F_dtype}\") == 0) {{ +{F_hdim_case} +}} +""" + +SAGEATTN_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} +}} +""" + +SAGEATTN_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = sageattn_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + return sageattn_fwd_<trait_>(s, a); +}} +""" + + +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return "true" + else: + return f"{self.bool_expr}" + + def __and__(self, other): + return CppConstraint(f"({str(self)}) && ({str(other)})") + + +@dataclass +class SageAttnFwdApiTrait: + arch: ArchTrait + pipeline_tag: str + # sync with sageattn_fwd_traits<>, to generate fallback calls + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: 
int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + mask: str + qscale: str # + spad: str + skpad: str + dpad: str + dvpad: str + skip: str + constraint: CppConstraint + + @property + def name(self) -> str: + return ( + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-" + + f"{self.vlayout}-{self.mask}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + ) + + @property + def scheck(self) -> str: + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.pipeline_tag == "qr_async": + return "true" + elif self.pipeline_tag in ["qr", "qs"]: + if self.spad == "t": + return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.seqlen_q % {self.bm0} == 0" + else: + assert False + + def seqtune(self, max_bm0: int) -> str: + if self.bm0 == max_bm0 or self.bm0 == 64: + return "true/*fall back to largest tile*/" + else: + return f"a.seqlen_q <= {self.bm0}" + + @property + def skcheck(self) -> str: + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.pipeline_tag == "qr_async": + if self.skpad == "t": + return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)" + else: + return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" + elif self.pipeline_tag in ["qr", "qs"]: + if self.skpad == "t": + return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" + else: + assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == "qr_async": + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + else: + assert False + elif self.pipeline_tag in ["qr", "qs"]: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == "t": + return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.hdim_q % {bk0submax} == 0" + else: + assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == "qr_async": + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + else: + assert False + elif self.pipeline_tag in ["qr", "qs"]: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == "t": + return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + # F_dvpad="f": Causal mask requires hdim_v <= kN1 (num_tile_n1 == 1 for tile reversal) + # Non-causal requires hdim_v % kN1 == 0 (epilogue writes full tiles) + if self.mask == "causal": + return f"(a.hdim_v % {bk0submax} == 0) && (a.hdim_v <= {self.bn1})" + else: + return ( + f"(a.hdim_v % {bk0submax} == 0) && (a.hdim_v % {self.bn1} == 0)" + ) + else: + assert False + + +@dataclass +class SageAttnFwdPipeline: + tag: str + + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_qscale: str # no/pertensor/blockscale/perwarp/perthread + F_mask: str # value from MASK_MAP + F_skip: str # true/false + F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) + + @property + def name(self) -> str: + def pad_name() -> str: + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n + return n + + pn = pad_name() + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" + else: + n += "_npad" + + if self.F_mask[0:2] == "s_": + if self.F_mask == "s_mask": + n += "_mask" + else: + n += "_nmask" + else: + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" + + if self.F_skip == "t": + n += "_skip" + else: + n += "_nskip" + + if self.F_qscale != "no": + n += f"_{self.F_qscale}" + else: + n += "_nqscale" + + return n + + +class SageAttnFwdApiPool: + def __init__(self): + self.pool = OrderedDict() + + def register_traits(self, trait: SageAttnFwdApiTrait) -> None: + hdim = trait.hdim, trait.bn1 + ts = ( + self.pool.setdefault(trait.arch, OrderedDict()) + .setdefault(trait.dtype, OrderedDict()) + .setdefault(hdim, []) + ) + check_duplicates_and_paddings(ts, trait) + ts.append(copy.copy(trait)) + + def get_num_traits( + self, filter_fn: Optional[Callable[[SageAttnFwdApiTrait], bool]] = None + ) -> int: + if filter_fn is None: + + def 
accept_all(trait: SageAttnFwdApiTrait) -> bool: + return True + + filter_fn = accept_all + + return sum( + sum(1 for trait in pool_by_hdim if filter_fn(trait)) + for pool_by_arch in self.pool.values() + for pool_by_dtype in pool_by_arch.values() + for pool_by_hdim in pool_by_dtype.values() + ) + + def render( + self, + func_name, + filter_fn: Optional[Callable[[SageAttnFwdApiTrait], bool]] = None, + ) -> str: + if filter_fn is None: + + def accept_all(trait: SageAttnFwdApiTrait) -> bool: + return True + + filter_fn = accept_all + + def has_traits(node) -> bool: + """Recursively traverse nested OrderedDicts and lists to determine if any SageAttnFwdApiTrait satisfies filter_fn().""" + if isinstance(node, list): + return any(filter_fn(elem) for elem in node) + elif isinstance(node, OrderedDict): + return any(has_traits(val) for val in node.values()) + return False + + per_arch = str() + for i_arch, (arch, pool_by_arch) in enumerate( + item for item in self.pool.items() if has_traits(item[1]) + ): + per_dtypes = str() + for i_dtype, (dtype, pool_by_dtype) in enumerate( + item for item in pool_by_arch.items() if has_traits(item[1]) + ): + per_hdim_case = str() + for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate( + item for item in pool_by_dtype.items() if has_traits(item[1]) + ): + max_bm0 = max( + (t.bm0 for t in pool_by_hdim if filter_fn(t)), default=0 + ) + inners = str() + for i_trait, trait in enumerate( + [trait for trait in pool_by_hdim if filter_fn(trait)] + ): + inners += SAGEATTN_FWD_API_INNER_DISPATCH.format( + F_if=if_(i_trait), + F_arch=arch, + F_mode=MODE_MAP[trait.mode], + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_mask=get_mask_cpp_type(trait.mask), + F_mask_check=get_mask_cpp_check_expr(trait.mask), + F_skip=BOOL_MAP[trait.skip], + F_qscale_check=QSCALE_CHECK_MAP[trait.qscale], + F_qscale=QSCALE_MAP[trait.qscale], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune(max_bm0), + 
F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + per_hdim_case += SAGEATTN_FWD_API_PER_HDIM_CASE.format( + F_if=if_(i_hdim), + F_hdim=hdim, + F_hdim_v=hdim_v, + F_inner_dispatch=indent(inners), + ) + per_dtypes += SAGEATTN_FWD_API_PER_DTYPE.format( + F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case) + ) + per_arch += SAGEATTN_FWD_API_PER_ARCH.format( + F_if=if_(i_arch), + F_arch=arch, + F_dtype_case=indent(per_dtypes), + ) + return SAGEATTN_FWD_API_FUNC_TEMPLATE.format( + F_func_name=func_name, F_dispatch=indent(per_arch) + ) + + +@dataclass +class SageAttnFwdTileSize: + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along qk gemm unroll + F_bn1: int # tile size along v head_dim + F_bk1: int # tile size along kv gemm unroll + F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0: int # number of warps for gemm0 along q seqlen + F_rn0: int # number of warps for gemm0 along k seqlen + F_rk0: int # number of warps for gemm0 along head dim q (not used) + F_rm1: int # number of warps for gemm1 along q seqlen + F_rn1: int # number of warps for gemm1 along head dim v + F_rk1: int # number of warps for gemm1 along k seqlen (not used) + F_wm0: int # gemm0 warp size along m + F_wn0: int # gemm0 warp size along n + F_wk0: int # gemm0 warp size along k + F_wm1: int # gemm1 warp size along m + F_wn1: int # gemm1 warp size along n + F_wk1: int # gemm1 warp size along k + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite 
occupancy + F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) + + @property + def name(self) -> str: + return ( + f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" + + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + ) + + +@dataclass +class SageAttnFwdKernel: + F_arch: ArchTrait + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: SageAttnFwdTileSize + F_pipeline: SageAttnFwdPipeline + + _KERNEL_HEADER: ClassVar[str] = SAGEATTN_FWD_KERNEL_HEADER + _KERNEL_BODY_TEMPLATE: ClassVar[str] = SAGEATTN_FWD_KERNEL_BODY_TEMPLATE + + @classmethod + def _get_cpp_kernel_class_name(cls, pipeline_tag): + return "ck_tile::SageAttnFwdKernel" + + @classmethod + def _get_cpp_kargs_creator_func_name(cls, pipeline_tag): + return "sageattn_fwd_create_kargs_and_grids" + + def render(self) -> str: + return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format( + F_kname=self.name, + F_arch=self.F_arch, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + 
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale], + F_skip=BOOL_MAP[self.F_pipeline.F_skip], + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_cpp_type(self.F_pipeline.F_mask), + F_mode=MODE_MAP[self.F_mode], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag), + F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag), + ) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return ( + f"sageattn_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) + + @property + def filename(self) -> str: + return f"{self.name}{self.F_arch.filename_suffix}.cpp" + + def api_trait(self) -> SageAttnFwdApiTrait: + return SageAttnFwdApiTrait( + arch=self.F_arch, + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + qscale=self.F_pipeline.F_qscale, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + ) + + +@dataclass +class ProblemContext: + dtype: str + mode: str + hdim: int + hdim_v: int + + +@dataclass +class KernelContext: + tile: SageAttnFwdTileSize + pipeline: SageAttnFwdPipeline + mask_impl: str + + +CompatibilityRule = Callable[[ProblemContext, KernelContext], bool] + + +def is_compatible( + problem_ctx: ProblemContext, + kernel_ctx: KernelContext, + rules: Iterable[CompatibilityRule], +) -> bool: + return all(rule(problem_ctx, kernel_ctx) for rule in rules) + + +def create_kernel( + arch: ArchTrait, problem_ctx: 
ProblemContext, kernel_ctx: KernelContext +) -> SageAttnFwdKernel: + return SageAttnFwdKernel( + F_arch=arch, + F_dtype=problem_ctx.dtype, + F_mode=problem_ctx.mode, + F_hdim=problem_ctx.hdim, + F_tile=kernel_ctx.tile, + F_pipeline=kernel_ctx.pipeline, + ) + + +class CompatibilityRuleFactory: + @staticmethod + def get_rules() -> List[CompatibilityRule]: + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + if problem_ctx.mode == "group": + if ( + kernel_ctx.pipeline.F_spad != "t" + or kernel_ctx.pipeline.F_skpad != "t" + ): + return False + return True + + return [check_mode] + + +class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): + _AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"}) + + @classmethod + def get_rules(cls) -> List[CompatibilityRule]: + rules = CompatibilityRuleFactory.get_rules() + + return rules + + +class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9): + pass + + +class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): + arch = ArchTrait( + "gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)" + ) + + # Note: fp16 is not supported by SageAttention (only bf16 + fp8/int quantization) + _DT_BF16 = ("bf16",) + _DT_FP8BF16 = ("fp8bf16",) + _DT_I8FP8BF16 = ("i8fp8bf16",) + _DT_I4FP8BF16 = ("i4fp8bf16",) + + @classmethod + def supported_dtypes(cls) -> Tuple[str]: + return cls._DT_BF16 + cls._DT_FP8BF16 + cls._DT_I8FP8BF16 + cls._DT_I4FP8BF16 + + # TODO: design a more practical way to do it + # this is current supported tile size per hdim + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if dtype in cls._DT_BF16: + return { + (128, 128) : [SageAttnFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } # fmt: skip + elif ( + dtype in cls._DT_FP8BF16 + or dtype in cls._DT_I8FP8BF16 + or dtype in 
cls._DT_I4FP8BF16 + ): + # gfx9 (non-gfx950): only F_bn0=64; F_bn0=128 variant is gfx950-only (see Gfx950). + return { + (128, 128): [ + SageAttnFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), # fmt: skip + ], + } + else: + raise ValueError(f"unsupported dtype={dtype}") + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[SageAttnFwdPipeline]: + # this function will populate a list of possible pipelines + # TODO: the order of List matters! entries later in this list will also be checked later + # TODO: currently for qr pipeline, let "t" padding appear later!! + # TODO: how to design this more generically? + pipelines = [] + if dtype in cls._DT_BF16: + qscale = "no" + skip = "f" # skip: only false + for mask, vlayout in itertools.product( + get_mask_map(mask_impl).keys(), + ["row", "col"], + ): + if hdim == 256 and hdim_v == 256: + pipelines.append(SageAttnFwdPipeline("qr", vlayout, "f", "f", "f", "f", qscale, mask, skip)) # fmt: skip + # the below two are used for hdim vectorize load + pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip + pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip + else: + pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "f", "t", "t", qscale, mask, skip)) # fmt: skip + pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip + elif ( + dtype in cls._DT_FP8BF16 + or dtype in cls._DT_I8FP8BF16 + or dtype in cls._DT_I4FP8BF16 + ): + # no need lse kernels + skip = "f" # skip: only false + for mask, qscale, vlayout in itertools.product( + get_mask_map(mask_impl).keys(), + ["no", "pertensor", "blockscale", "perwarp", "perthread"], + ["row", "col"], # Support both row and col major layouts
+ ): + if dtype in cls._DT_I4FP8BF16: + # int4 only uses sync pipeline (qr), pad_d="f" because packed types + # require alignment >= PackedSize which conflicts with kPadHeadDimQ=true + # forcing alignment to 1. Safe since hdim always matches tile size. + pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "f", "f", "f", qscale, mask, skip)) # fmt: skip + pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip + elif hdim == 64: + pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "f", "f", "f", qscale, mask, skip)) # fmt: skip + pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip + else: + pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "f", "t", "t", qscale, mask, skip)) # fmt: skip + pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip + + # Packed types (int4) cannot use head-dim padding: the tile_window infrastructure + # forces alignment=1 when padding is enabled, but packed types need alignment >= PackedSize. + if dtype in cls._DT_I4FP8BF16: + for p in pipelines: + assert p.F_dpad == "f", ( + f"int4 dtype '{dtype}' requires pad_d=false, got '{p.F_dpad}'" + ) + assert p.F_dvpad == "f", ( + f"int4 dtype '{dtype}' requires pad_dv=false, got '{p.F_dvpad}'" + ) + + return pipelines + + +class KernelComponentFactoryGfx950( + KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx950 +): + arch = ArchTrait("gfx950") + + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if ( + dtype in cls._DT_FP8BF16 + or dtype in cls._DT_I8FP8BF16 + or dtype in cls._DT_I4FP8BF16 + ): + # gfx950 fp8-family: F_bn0=128 tile only (gfx9 uses F_bn0=64 only). 
+ return { + (128, 128): [ + SageAttnFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), # fmt: skip + ], + } + return super().get_hdim_tile_size_dict(dtype) + + +class CustomFactory(KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9): + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) + if dtype in cls._DT_BF16: + if (128, 128) in result.keys(): + result[(128, 128)].insert(0, SageAttnFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip + return result + + +def get_factory(target: str): + if os.environ.get("CK_TILE_SAGEATTN_FWD_CUSTOM_FACTORY", "0") == "1": + return CustomFactory + + # Place more specific architectures first + + if target.startswith("gfx950"): + return KernelComponentFactoryGfx950 + if target.startswith("gfx9"): + return KernelComponentFactoryGfx9 + + raise Exception(f"Unsupported device target {target}") + + +@dataclass(frozen=True) +class Product: + name: str + rule: CompatibilityRule + + def __call__(self, problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + return self.rule(problem_ctx, kernel_ctx) + + +def get_product(receipt: int) -> Product: + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + # bf16 (no quantization) should not have qscale + if problem_ctx.dtype == "bf16": + if kernel_ctx.pipeline.F_qscale != "no": + return False + + return True + + return Product(name="All tiles", rule=fit) + + +def get_fwd_blobs( + targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[SageAttnFwdApiPool, List[SageAttnFwdKernel]]: + gen = list() + api_pool = SageAttnFwdApiPool() + + factories = get_factories_for_targets(targets, get_factory) + + for factory, dtype in ((f, t) for f in factories for t in f.supported_dtypes()): + d = 
factory.get_hdim_tile_size_dict(dtype) + # for hdim_str, mode, mask, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for ((hdim, hdim_v), tiles), mode in itertools.product( + d.items(), MODE_MAP.keys() + ): + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + for tile, next_tile in zip(tiles, tiles[1:]): + assert next_tile.F_bm0 >= tile.F_bm0, ( + "Tiles must be ordered by increasing bm0" + ) + + for tile, pipeline in itertools.product( + tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) + ): + problem_ctx = ProblemContext( + dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v + ) + kernel_ctx = KernelContext( + tile=tile, pipeline=pipeline, mask_impl=mask_impl + ) + rules = factory.get_rules() + product = get_product(receipt) + + if not is_compatible(problem_ctx, kernel_ctx, [*rules, product]): + continue + + k = create_kernel(factory.arch, problem_ctx, kernel_ctx) + if kernel_filter != "": + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + + +def write_single_fwd_kernel(kernel: SageAttnFwdKernel, autogen_dir: Path) -> None: + update_file(autogen_dir / kernel.filename, kernel.render()) + + +def write_fwd_api( + api_pool: SageAttnFwdApiPool, + autogen_dir: Path, +) -> None: + content = "".join( + [ + SAGEATTN_FWD_API_HEADER, + api_pool.render("sageattn_fwd_impl"), + SAGEATTN_FWD_API_FOOTER_TEMPLATE, + ] + ) + update_file(autogen_dir / SAGEATTN_FWD_API_FILENAME, content) + + +def write_blobs( + targets: List[str], + output_dir: Path, + kernel_filter: str, + receipt, + optdim_list, + mask_impl, +) -> None: + api_pool, kernels = get_fwd_blobs( + targets, kernel_filter, receipt, optdim_list, mask_impl + ) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + + +def list_blobs( + targets: List[str], + file_path: Path, + kernel_filter: 
str, + receipt, + optdim_list, + mask_impl, +) -> None: + with file_path.open("a") as f: + _, kernels = get_fwd_blobs( + targets, kernel_filter, receipt, optdim_list, mask_impl + ) + for kernel in kernels: + f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") + f.write( + (file_path.parent / GEN_DIR / SAGEATTN_FWD_API_FILENAME).as_posix() + "\n" + ) diff --git a/example/ck_tile/49_sageattention/codegen/utils.py b/example/ck_tile/49_sageattention/codegen/utils.py new file mode 100644 index 0000000000..3fefe73ad9 --- /dev/null +++ b/example/ck_tile/49_sageattention/codegen/utils.py @@ -0,0 +1,70 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +# generate kernel instances to speed up compilation +import dataclasses +import os.path as path +import textwrap + + +def update_file(file_path, content): + """Update the file at file_path with the given content if it differs from the existing content. + + It avoids unnecessary touching of the file which triggers rebuilds + """ + + existing_content = "" + if path.exists(file_path): + with open(file_path, "r") as file: + existing_content = file.read() + if existing_content == content: + return + with open(file_path, "w") as file: + file.write(content) + + +def indent(code: str, indent: str = " ") -> str: + return textwrap.indent(code, indent) + + +def if_(i: int) -> str: + return "if" if i == 0 else "else if" + + +def check_duplicates_and_paddings(traits, trait): + """Check + * if the traits list does not contain a trait with the same parameters; + * if paddings are consistent: the previous kernel can be incorrectly called before the new one, + for example, f, _t_, f, t cannot be before f, _f_, f, t.
+ """ + + fields = [f.name for f in dataclasses.fields(trait)] + pad_fields = [f for f in fields if "pad" in f] + non_pad_fields = [f for f in fields if "pad" not in f] + for prev_trait in traits: + if any(getattr(trait, f) != getattr(prev_trait, f) for f in non_pad_fields): + continue + if all(getattr(trait, f) == getattr(prev_trait, f) for f in pad_fields): + raise Exception(f"Duplicate found {trait}") + # Check if the previous kernel can be incorrectly used before the current one + # for example, f, _t_, f, t cannot be before f, _f_, f, t + is_prev_more_restrictive = False + is_curr_more_restrictive = False + for f in pad_fields: + prev_pad = getattr(prev_trait, f) + pad = getattr(trait, f) + if isinstance(prev_pad, str): + prev_pad = 1000000 if prev_pad == "f" else 1 + pad = 1000000 if pad == "f" else 1 + elif isinstance(prev_pad, int): + prev_pad = 1000000 if prev_pad == 0 else prev_pad + pad = 1000000 if pad == 0 else pad + else: + assert False + if prev_pad < pad: + is_prev_more_restrictive = True + elif prev_pad > pad: + is_curr_more_restrictive = True + if is_prev_more_restrictive and not is_curr_more_restrictive: + raise Exception( + f"Kernel will never be used because paddings are not ordered correctly:\n{prev_trait} supersedes\n{trait}" + ) diff --git a/example/ck_tile/49_sageattention/example_sageattn_fwd.cpp b/example/ck_tile/49_sageattention/example_sageattn_fwd.cpp new file mode 100644 index 0000000000..3ef85d39db --- /dev/null +++ b/example/ck_tile/49_sageattention/example_sageattn_fwd.cpp @@ -0,0 +1,202 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "sageattn_fwd.hpp" +#include "sageattn_fwd_runner.hpp" + +#include + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") + .insert("mode", "0", "kernel mode. 
0:batch, 1:group") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "-1", + "num of head, for k/v, -1 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert("s", + "3328", + "seqlen_q. if group-mode, means the average value of seqlen_q\n" + "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n" + "also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch " + "(group mode)") + .insert("s_k", + "-1", + "seqlen_k (including new key/value), -1 means equal to s\n" + "also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch " + "(group mode)") + .insert("s_qpad", + "-1", + "seqlen_q stride between 2 batches (group-mode optional).\n" + "Provide positive strides per-batch to simulate physical padding on Q.") + .insert("s_kpad", + "-1", + "seqlen_k stride between 2 batches, currently used in group-mode only\n" + "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n" + "along seqlen, instead of packed, same as xformer kv_padding,\n" + "must be greater than or equal to s_k") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)") + .insert("qscale", + "n", + "n or 0, no scale\n" + "pt or 1, per-tensor scale\n" + "bs or 2, block scale (Q:128, KV:128)\n" + "pw or 3, per-warp scale (Q:32, KV:64)\n" + "pth or 4, per-thread scale (Q:4, KV:16)\n") + .insert("iperm", + "1", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "1", "permute output") + .insert("prec", + "fp8bf16", + "Primary: fp8bf16, i8fp8bf16, i4fp8bf16. 
Also bf16 (keep): pipeline validation " + "with qscale=n (no quant); not the quantized Sage product path.") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, window_size negative is " + "causal, positive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, window_size negative is " + "causal, positive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " + "now)") + .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") + .insert("kname", "0", "if set to 1 will print kernel name") + .insert("init", + "uf", + "init method:\n ui or 0 - uniform random int\n ni - normalized random int" + "\n uf or 1 - uniform random float\n nf - normalized random float" + "\n tf or 2 - trig float" + "\n tf or 3 - uniform random float, min max is the max of the type\n") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "sageattn_fwd.json", "json file name to dump results") + .insert("q_eff_lens", + "", + "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override.") + .insert("kv_eff_lens", + "", + "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" + "Comma-separated list of length 'b'. 
If empty, no override."); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +auto run(const ck_tile::ArgParser& arg_parser) +{ + int do_validation = arg_parser.get_int("v"); + mode_enum mode = static_cast(arg_parser.get_uint32("mode")); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + auto seqlen_qs = arg_parser.get_int_vec("s"); + auto seqlen_ks = arg_parser.get_int_vec("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + auto seqlen_kpads = arg_parser.get_int_vec("s_kpad"); + auto seqlen_qpads = arg_parser.get_int_vec("s_qpad"); + auto q_eff_lens_per_batch = arg_parser.get_int_vec("q_eff_lens"); + auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + float scale_s = arg_parser.get_float("scale_s"); + bool is_v_rowmajor = arg_parser.get_str("vlayout") == "r"; + std::string qscale_str = arg_parser.get_str("qscale"); + std::string mask_str = arg_parser.get_str("mask"); + std::string init_method = arg_parser.get_str("init"); + uint32_t seed = arg_parser.get_uint32("seed"); + + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0), + arg_parser.get_int("warmup"), + arg_parser.get_int("repeat"), + arg_parser.get_str("timer") == std::string("gpu")}; + + auto json = arg_parser.get_int("json") == 1 + ? 
std::optional{arg_parser.get_str("jsonfile")} + : std::nullopt; + + return sageattn_fwd_run(mode, + batch, + nhead, + nhead_k, + seqlen_qs, + seqlen_ks, + hdim_q, + hdim_v, + seqlen_qpads, + seqlen_kpads, + q_eff_lens_per_batch, + kv_eff_lens_per_batch, + i_perm, + o_perm, + scale_s, + is_v_rowmajor, + mask_str, + qscale_str, + init_method, + seed, + do_validation, + stream_config, + json); +} + +int main(int argc, char* argv[]) +{ + try + { + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "bf16") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "fp8bf16") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "i8fp8bf16") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "i4fp8bf16") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + std::cerr << "Unsupported precision: " << data_type << std::endl; + return -1; + } + catch(const std::invalid_argument& e) + { + std::cerr << "Invalid argument: " << e.what() << std::endl; + return -1; + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << std::endl; + return -2; + } +} diff --git a/example/ck_tile/49_sageattention/generate.py b/example/ck_tile/49_sageattention/generate.py new file mode 100644 index 0000000000..c2d011869c --- /dev/null +++ b/example/ck_tile/49_sageattention/generate.py @@ -0,0 +1,173 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +# generate kernel instances to speed up compilation + +import argparse +from enum import IntEnum +from pathlib import Path +import pkgutil +from typing import List, Optional + +import codegen.ops +from codegen.cmake_config import GEN_DIR + + +class HandlerId(IntEnum): + LIST_BLOBS = 0 + WRITE_BLOBS = 1 + + +# inspect all modules under 'codegen.ops' and register API handlers +ops = [] +for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): + full_module_name = "%s.%s" % (codegen.ops.__name__, module_name) + ops.append(importer.find_spec(module_name).loader.load_module(module_name)) +# Strip "sageattn_" so module sageattn_fwd registers as CLI key "fwd". +unwanted_prefix = "sageattn_" +handlers = dict( + [ + ( + ( + op.__name__[len(unwanted_prefix) :] + if op.__name__.startswith(unwanted_prefix) + else op.__name__ + ), + (op.list_blobs, op.write_blobs), + ) + for op in ops + ] +) +assert 0 < len(handlers) + + +def write_blobs( + targets: List[str], + output_dir: Optional[str], + api_list: List[str], + filters_list: List[str], + optdim_list: List[int], + receipt, + mask_impl, +) -> None: + if output_dir is None: + output_dir = Path(__file__).parent + else: + output_dir = Path(output_dir) / GEN_DIR + + output_dir.mkdir(parents=True, exist_ok=True) + + for api, kernel_filter in zip(api_list, filters_list): + handler = handlers[api][HandlerId.WRITE_BLOBS] + handler(targets, output_dir, kernel_filter, receipt, optdim_list, mask_impl) + + +# list all the files that will be generated +def list_blobs( + targets: List[str], + output_file: Optional[str], + api_list: List[str], + filters_list: List[str], + optdim_list: List[int], + receipt, + mask_impl, +) -> None: + assert output_file is not None + file_path = Path(output_file) + + # create an empty file / drop its contents if it exists + open(file_path, "w").close() + + for api, kernel_filter in zip(api_list, filters_list): + handler = handlers[api][HandlerId.LIST_BLOBS] 
+ handler(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="Generate SageAttention CK_tile kernel/API blobs.", + ) + parser.add_argument( + "--targets", + default="gfx9,gfx950", + required=False, + help="list of GPU targets, separated by comma.", + ) + parser.add_argument( + "-a", + "--api", + default="fwd", + required=False, + help="Codegen API key(s), comma-separated (e.g. fwd -> module codegen.ops.sageattn_fwd).", + ) + parser.add_argument( + "-o", + "--output_dir", + required=False, + help="write all the blobs into a directory", + ) + parser.add_argument( + "-l", "--list_blobs", required=False, help="list all the kernels to a file" + ) + # TODO: if using filter, must apply same value to output_dir and list_blobs + parser.add_argument( + "-f", + "--filter", + default="", + required=False, + help="filter out kernels that need to generate, using fnmatch module", + ) + + parser.add_argument( + "-m", + "--mask", + default="simplified", + required=False, + help="mask implementation, simplified/generic", + ) + + parser.add_argument( + "-r", + "--receipt", + default=0, + required=False, + help="Codegen receipt index. SageAttention forward currently uses receipt 0 only; " + "the value is passed through to ops (see get_product in sageattn_fwd.py).", + ) + + parser.add_argument( + "--optdim", + default="-1", + required=False, + help="only optimize the hdim in the list. separated by comma. -1 is the default choice. " + "e.g. 
--optdim=32,64,128,256", + ) + + args = parser.parse_args() + targets = args.targets.split(",") + api_list = args.api.split(",") + filter_list = args.filter.split(",") + filter_list.extend([""] * (len(api_list) - len(filter_list))) + optdim_list = [int(hdim) for hdim in args.optdim.split(",")] + + if args.list_blobs is not None: + list_blobs( + targets, + args.list_blobs, + api_list, + filter_list, + optdim_list, + int(args.receipt), + mask_impl=args.mask, + ) + else: + write_blobs( + targets, + args.output_dir, + api_list, + filter_list, + optdim_list, + int(args.receipt), + mask_impl=args.mask, + ) diff --git a/example/ck_tile/49_sageattention/mask.hpp b/example/ck_tile/49_sageattention/mask.hpp new file mode 100644 index 0000000000..9d3da2fb8f --- /dev/null +++ b/example/ck_tile/49_sageattention/mask.hpp @@ -0,0 +1,169 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" + +// keep this in sync with ck_tile::GenericAttentionMaskEnum +enum class mask_enum +{ + no_mask = 0, + mask_top_left, + mask_bottom_right, + window_generic, +}; + +struct mask_info +{ + mask_enum type; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t y, x; + ck_tile::index_t left, right; // FA style SWA left/right + + void serialize(std::ostream& os) const + { + if(type == mask_enum::no_mask) + os << "n"; + else if(type == mask_enum::mask_top_left) + os << "t(" << left << ":" << right << ")"; + else if(type == mask_enum::mask_bottom_right) + os << "b(" << left << ":" << right << ")"; + else + { + os << "g(" << y << ":" << x << ")"; + } + } + + static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) + { + ck_tile::index_t x_total = seqlen_k; + ck_tile::index_t y_total = seqlen_q; + mask_info tmp; + tmp.seqlen_q = seqlen_q; + tmp.seqlen_k = seqlen_k; + auto 
found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + if(t == "xt" || t == "xb") + { + // xformer style sliding window attn from top-left + ck_tile::index_t window_size = std::stoi(v); + ck_tile::index_t left_size = -1; + ck_tile::index_t right_size = 0; + if(window_size > 0) + { + left_size = window_size / 2; + right_size = window_size - 1 - left_size; + } + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, 0, y_total, x_total, t == "xt"); + + tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right; + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = left_size; + tmp.right = right_size; + } + else if(t == "t" || t == "b" || t == "g") + { + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + throw std::invalid_argument("invalid mask value: " + str); + } + tmp.type = mask_enum::window_generic; + ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); + ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation + if(t == "t") + { + tmp.type = mask_enum::mask_top_left; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, 0, y_total, x_total, true); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "b") + { + tmp.type = mask_enum::mask_bottom_right; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, 0, y_total, x_total, false); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "g") + { + tmp.type = mask_enum::window_generic; + tmp.y = v0; + tmp.x = v1; + tmp.left = v0; // TODO: don't use this? 
+ tmp.right = v1; + } + } + else + { + throw std::invalid_argument("invalid mask value: " + str); + } + } + else if(str == "0") + { + tmp.type = mask_enum::no_mask; + } + else if(str == "1" || str == "t") + { + tmp.type = mask_enum::mask_top_left; + tmp.y = seqlen_q; + tmp.x = 1; + tmp.left = -1; + tmp.right = 0; + } + else if(str == "2" || str == "b") + { + tmp.type = mask_enum::mask_bottom_right; + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + tmp.left = -1; + tmp.right = 0; + } + else + { + throw std::invalid_argument("invalid mask value: " + str); + } + return tmp; + } + + std::size_t get_unmaskarea() const + { + if(type == mask_enum::no_mask) + return static_cast(seqlen_q) * seqlen_k; + std::size_t area = 0; + for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y) + { + ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast(0)); + ck_tile::index_t x_end = std::min(i_y + x, seqlen_k); + if(x_end > x_start) + { + area += (x_end - x_start); + } + } + return area; + } + + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) + { + mi.serialize(os); + return os; + } +}; diff --git a/example/ck_tile/49_sageattention/quant.hpp b/example/ck_tile/49_sageattention/quant.hpp new file mode 100644 index 0000000000..b44149d3b6 --- /dev/null +++ b/example/ck_tile/49_sageattention/quant.hpp @@ -0,0 +1,74 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <array>
#include <ostream>
#include <stdexcept>
#include <string>
#include <utility>
// NOTE(review): the original header also pulled in "ck_tile/core.hpp" and
// "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp";
// keep those includes when merging this back into the tree.

// Host-side mirror of the kernel-side quant-scale granularity selector.
// keep sync with BlockSageAttentionQuantScaleEnum — the kernel enum is not
// visible here, so the numeric values below must be checked manually.
enum class quant_scale_enum
{
    no_scale   = 0, // inputs are not quantized (plain fp16/bf16 validation path)
    pertensor  = 1, // one de-scale factor for the whole tensor
    blockscale = 2, // one de-scale factor per block of tokens
    perwarp    = 3, // one de-scale factor per warp-sized token group
    perthread  = 4, // one de-scale factor per thread-sized token group
};

// Parses/prints the command-line quant-scale option ("-qscale").
// Accepted spellings: short name ("n","pt","bs","pw","pth") or the numeric
// code ("0".."4") matching the enum value.
struct quant_scale_info
{
    quant_scale_enum type;

    // Single source of truth for the enum <-> short-name mapping, shared by
    // serialize() and decode() so the two directions cannot drift apart.
    // Invariant: table index == numeric enum value (decode relies on this to
    // accept "0".."4").
    static const std::array<std::pair<quant_scale_enum, const char*>, 5>& name_table()
    {
        static const std::array<std::pair<quant_scale_enum, const char*>, 5> table{
            {{quant_scale_enum::no_scale, "n"},
             {quant_scale_enum::pertensor, "pt"},
             {quant_scale_enum::blockscale, "bs"},
             {quant_scale_enum::perwarp, "pw"},
             {quant_scale_enum::perthread, "pth"}}};
        return table;
    }

    // Write the short name to 'os'. Emits nothing for a value outside the
    // table (unreachable when 'type' was produced by decode()).
    void serialize(std::ostream& os) const
    {
        for(const auto& entry : name_table())
        {
            if(entry.first == type)
            {
                os << entry.second;
                return;
            }
        }
    }

    // Parse a short name or numeric code into a quant_scale_info.
    // Takes the string by const reference (no copy needed).
    // Throws std::invalid_argument on any unrecognized input.
    static quant_scale_info decode(const std::string& str)
    {
        const auto& table = name_table();
        for(std::size_t i = 0; i < table.size(); ++i)
        {
            if(str == table[i].second || str == std::to_string(i))
            {
                return quant_scale_info{table[i].first};
            }
        }
        throw std::invalid_argument("invalid quant scale value: " + str);
    }

    friend std::ostream& operator<<(std::ostream& os, const quant_scale_info& qsi)
    {
        qsi.serialize(os);
        return os;
    }
};
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/device_prop.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/sageattn.hpp" + +#include "mask.hpp" +#include "quant.hpp" + +#include +#include +#include + +// SageAttention data type configs (must match codegen FWD_DTYPE_MAP + SageAttentionFwdTypeConfig) +struct SageAttentionFwdFp16 +{ +}; + +struct SageAttentionFwdBf16 +{ +}; + +struct SageAttentionFwdFp8Bf16 +{ +}; + +struct SageAttentionFwdI8Fp8Bf16 +{ +}; + +struct SageAttentionFwdI4Fp8Bf16 +{ +}; + +template +struct SageAttentionFwdTypeConfig; + +// fp16/bf16 are not Sage product dtypes; bf16 is intentionally kept in tile_example_sageattn_fwd +// for pipeline validation with qscale=n (no quant). +template <> +struct SageAttentionFwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using ScaleType = float; // scale type for quantized inputs + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::half_t; +}; + +template <> +struct SageAttentionFwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using ScaleType = float; // scale type for quantized inputs + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct SageAttentionFwdTypeConfig +{ + 
using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using ScaleType = float; // scale type for quantized inputs + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct SageAttentionFwdTypeConfig +{ + using QDataType = ck_tile::int8_t; + using KDataType = ck_tile::int8_t; + using VDataType = ck_tile::fp8_t; + using ScaleType = float; // scale type for Q and K + using SaccDataType = float; // Keep as float for softmax computation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // P in FP8 for 2nd gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct SageAttentionFwdTypeConfig +{ + using QDataType = ck_tile::pk_int4_t; + using KDataType = ck_tile::pk_int4_t; + using VDataType = ck_tile::fp8_t; + using ScaleType = float; + using SaccDataType = float; + using SMPLComputeDataType = float; + using PDataType = ck_tile::fp8_t; + using OaccDataType = float; + using ODataType = ck_tile::bf16_t; +}; + +struct SageAttnMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct sageattn_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* q_descale_ptr; + const void* k_descale_ptr; + const void* v_descale_ptr; + void* o_ptr; + + // Usage notes for sequence length pointer parameters: + // + // [Note: Define "Group mode" vs "Batch mode" here if possible, 
e.g., "Group mode handles + // MQA/GQA..."] + // + // With padding: + // Group mode: + // - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence + // lengths. [array size: batch + 1] + // - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each + // sequence. [array size: batch] + // - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding) + // sequence lengths. [array size: batch + 1] + // - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually + // exclusive. Use one set, not both. + // + // Batch mode: + // - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding) + // sequence lengths. [array size: batch + 1] + // - seqstart_* and seqlen_* pointers must be nullptr. + // + // Without padding: + // (Note: Physical length equals logical length) + // + // Group mode: + // - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array + // size: batch + 1] + // - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr. + // + // Batch mode: + // - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr. + // + const void* seqstart_q_ptr = + nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode) + const void* seqstart_k_ptr = + nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode) + const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array + // [batch]. (Used in Group mode with padding) + const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array + // [batch]. (Used in Group mode with padding) + const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length + // array [batch + 1]. 
(Used with padding) + const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length + // array [batch + 1]. (Used with padding) + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + ck_tile::index_t min_seqlen_q; + + // BLOCKSCALE parameters + ck_tile::index_t nhead_stride_q_descale = 0; + ck_tile::index_t nhead_stride_k_descale = 0; + ck_tile::index_t nhead_stride_v_descale = 0; + ck_tile::index_t batch_stride_q_descale = 0; + ck_tile::index_t batch_stride_k_descale = 0; + ck_tile::index_t batch_stride_v_descale = 0; + ck_tile::index_t block_scale_size_q = 0; + ck_tile::index_t block_scale_size_k = 0; + const void* block_scale_seqstart_q_ptr = nullptr; + const void* block_scale_seqstart_k_ptr = nullptr; +}; + +template +auto sageattn_fwd_create_kargs_and_grids(sageattn_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(SageAttnKernel::kIsGroupMode) + { + return SageAttnKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_q_ptr, + 
args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_o, + args.nhead_stride_q_descale, + args.nhead_stride_k_descale, + args.nhead_stride_v_descale, + args.batch_stride_v_descale, + args.block_scale_size_q, + args.block_scale_size_k, + args.block_scale_seqstart_q_ptr, + args.block_scale_seqstart_k_ptr, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.min_seqlen_q, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr); + } + else + { // create batch mode kernel arguments + return SageAttnKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_o, + args.nhead_stride_q_descale, + args.nhead_stride_k_descale, + args.nhead_stride_v_descale, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_o, + args.batch_stride_q_descale, + args.batch_stride_k_descale, + args.batch_stride_v_descale, + args.block_scale_size_q, + args.block_scale_size_k, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr); + } + }(); + + if constexpr(SageAttnKernel::kIsGroupMode) + { + dim3 grids = SageAttnKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); + return ck_tile::make_tuple(kargs, grids); + } + else + { + dim3 grids = SageAttnKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); + return 
ck_tile::make_tuple(kargs, grids); + } +} + +// this is used to pattern-match internal kernel implementation, not to instantiate kernel +template +struct sageattn_fwd_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto SageAttnPipelineEnum = SageAttnPipelineEnum_; + using AttnMask = ck_tile::remove_cvref_t; + static constexpr auto QScaleEnum = QScaleEnum_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; +}; + +template +float sageattn_fwd_(const ck_tile::stream_config&, sageattn_fwd_args); + +// This is the public API, will be generated by script +struct sageattn_fwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + quant_scale_enum qscale_type; + bool skip_min_seqlen_q = false; + // TODO: padding check is inside this api +}; +float sageattn_fwd(sageattn_fwd_traits, sageattn_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/49_sageattention/sageattn_fwd_runner.hpp b/example/ck_tile/49_sageattention/sageattn_fwd_runner.hpp new file mode 100644 index 0000000000..a28731648c --- /dev/null +++ b/example/ck_tile/49_sageattention/sageattn_fwd_runner.hpp @@ -0,0 +1,1154 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/host.hpp" +#include "ck_tile/ref/naive_attention.hpp" +#include "sageattn_fwd.hpp" +#include "utils.hpp" +#include "ck_tile/utility/json_dump.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template +inline void dump_sageattn_fwd_json_results(Args&&... args) +{ + dump_fmha_fwd_json_results(std::forward(args)...); +} + +enum class fwd_result +{ + success, + failure, + invalid_args, + no_instance, +}; + +// different threshold for different dtype +template +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + // atol=0.18: Q, K, V quantization (FP8 E4M3 ~0.0625/element) + 2 GEMM accumulations + // + softmax sensitivity. Empirically tuned; tightening below 0.15 causes false positives. + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + // atol=0.18: K, V still FP8 (dominant error source). Matches FP8xFP8 despite + // lower Q quantization error (int8 ~0.0078 vs fp8 ~0.0625) to avoid test fragility. + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + // atol=0.19: +0.01 over FP8 due to coarse Q quantization (int4 ~0.125, only 16 levels). + // Attention pattern becomes "blocky"; softmax amplifies logit clustering. 
+ double rtol = 1e-2; + double atol = 1.9e-1; + return ck_tile::make_tuple(rtol, atol); +} + +template +fwd_result sageattn_fwd_run(mode_enum mode, + ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t nhead_k, + std::vector seqlen_qs, + std::vector seqlen_ks, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + std::vector seqlen_qpads, + std::vector seqlen_kpads, + std::vector q_eff_lens_per_batch, + std::vector kv_eff_lens_per_batch, + bool i_perm, + bool o_perm, + float scale_s, + bool is_v_rowmajor, + std::string mask_str, + std::string qscale_str, + std::string init_method, + uint32_t seed, + int do_validation, + const ck_tile::stream_config& stream_config, + std::optional json = std::nullopt) +{ + const std::string data_type = []() { + if constexpr(std::is_same_v) + return "fp16"; + else if constexpr(std::is_same_v) + return "bf16"; + else if constexpr(std::is_same_v) + return "fp8bf16"; + else if constexpr(std::is_same_v) + return "i8fp8bf16"; + else if constexpr(std::is_same_v) + return "i4fp8bf16"; + else + static_assert(false); + }(); + + if(nhead_k < 0) + nhead_k = nhead; + if(nhead % nhead_k != 0) + { + std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; + return fwd_result::invalid_args; + } + + std::mt19937 random_engine(seed != 0 ? 
seed : std::random_device{}()); + auto next_seed = [&random_engine]() { return static_cast(random_engine()); }; + + if(hdim_v < 0) + hdim_v = hdim_q; + + // Check padding usage + const bool has_group_q_padding = + mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] > 0); + const bool has_group_k_padding = + mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] > 0); + const bool has_group_padding = has_group_q_padding || has_group_k_padding; + const bool has_batch_q_padding = mode == mode_enum::batch && !q_eff_lens_per_batch.empty(); + const bool has_batch_k_padding = mode == mode_enum::batch && !kv_eff_lens_per_batch.empty(); + const bool has_batch_padding = has_batch_q_padding || has_batch_k_padding; + + std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) = + generate_missing_seqlens(mode, + batch, + seqlen_qs, + seqlen_ks, + seqlen_qpads, + seqlen_kpads, + /*seqlen_k_min=*/0, + false, // need_append_kvcache not supported + random_engine); + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + if(seqlen_kpads[wb] > 0 && seqlen_kpads[wb] < seqlen_ks[wb]) + { + std::cerr << "kpad must be greater than or equal to seqlen for k" << std::endl; + return fwd_result::invalid_args; + } + if(seqlen_qpads[wb] > 0 && seqlen_qpads[wb] < seqlen_qs[wb]) + { + std::cerr << "qpad must be greater than or equal to seqlen for q" << std::endl; + return fwd_result::invalid_args; + } + } + + if(scale_s == .0f) + scale_s = 1.0f / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? 
+ + mask_info mask = + mask_info::decode(mask_str, seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore + + quant_scale_info qscale = quant_scale_info::decode(qscale_str); + + // PERWARP mode: Q=32 (warp size), K=64 (2x warp size) + // BLOCKSCALE mode: Q=128 (tile size), K=128 + // PERTHREAD mode: Q=4 (tokens/scale), K=16 (tokens/scale) + // Note: V uses per-channel scale, not block scale + const ck_tile::index_t block_scale_size_q_ = (qscale.type == quant_scale_enum::perwarp) ? 32 + : (qscale.type == quant_scale_enum::perthread) + ? 4 + : 128; + const ck_tile::index_t block_scale_size_k_ = (qscale.type == quant_scale_enum::perthread) ? 16 + : (qscale.type == quant_scale_enum::perwarp) ? 64 + : 128; + + // blockscale, perwarp, or perthread + const bool qscale_uses_bwp = qscale.type == quant_scale_enum::blockscale || + qscale.type == quant_scale_enum::perwarp || + qscale.type == quant_scale_enum::perthread; + + const auto seqstart_q_host = to_seqstarts(seqlen_qs); + const auto seqstart_k_host = to_seqstarts(seqlen_ks); + const auto seqstart_q_with_padding_host = to_seqstarts(seqlen_qpads); + const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); + + // Optional batch-mode cumulative seqlen overrides + std::vector cuq_cum, cukv_cum; + if(mode == mode_enum::batch) + { + auto calculate_cumulative = [&](std::vector& per_batch_vec, + std::vector& cum_vec) { + if(!per_batch_vec.empty() && per_batch_vec[0] != -1) + { + if(per_batch_vec.size() < static_cast(batch)) + { + per_batch_vec.resize(batch, per_batch_vec.back()); + } + cum_vec.resize(batch + 1); + cum_vec[0] = 0; + for(int i = 0; i < batch; ++i) + cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; + } + }; + + calculate_cumulative(q_eff_lens_per_batch, cuq_cum); + calculate_cumulative(kv_eff_lens_per_batch, cukv_cum); + } + + using TypeConfig = SageAttentionFwdTypeConfig; + + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType 
= typename TypeConfig::VDataType; + using SaccDataType = typename TypeConfig::SaccDataType; + using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; + using PDataType = typename TypeConfig::PDataType; + using OaccDataType = typename TypeConfig::OaccDataType; + using ODataType = typename TypeConfig::ODataType; + + constexpr ck_tile::index_t q_packed_size = + ck_tile::is_packed_type_v ? ck_tile::numeric_traits::PackedSize : 1; + constexpr ck_tile::index_t k_packed_size = + ck_tile::is_packed_type_v ? ck_tile::numeric_traits::PackedSize : 1; + constexpr bool is_q_i4 = std::is_same_v; + constexpr bool is_k_i4 = std::is_same_v; + constexpr bool need_q_i4_permute = is_q_i4 && !is_k_i4; + constexpr bool need_k_i4_permute = is_k_i4 && !is_q_i4; + const ck_tile::index_t hdim_q_storage_q = hdim_q / q_packed_size; + const ck_tile::index_t hdim_q_storage_k = hdim_q / k_packed_size; + if constexpr(ck_tile::is_packed_type_v) + { + if(hdim_q % q_packed_size != 0) + { + std::cerr << "hdim_q must be divisible by packed size for QDataType, got hdim_q=" + << hdim_q << ", packed_size=" << q_packed_size << std::endl; + return fwd_result::invalid_args; + } + if constexpr(need_q_i4_permute) + { + if(hdim_q % 8 != 0) + { + std::cerr << "hdim_q must be divisible by 8 for pk_int4_t QDataType, got hdim_q=" + << hdim_q << std::endl; + return fwd_result::invalid_args; + } + } + } + if constexpr(ck_tile::is_packed_type_v) + { + if(hdim_q % k_packed_size != 0) + { + std::cerr << "hdim_q must be divisible by packed size for KDataType, got hdim_q=" + << hdim_q << ", packed_size=" << k_packed_size << std::endl; + return fwd_result::invalid_args; + } + if constexpr(need_k_i4_permute) + { + if(hdim_q % 8 != 0) + { + std::cerr << "hdim_q must be divisible by 8 for pk_int4_t KDataType, got hdim_q=" + << hdim_q << std::endl; + return fwd_result::invalid_args; + } + } + } + + // accumulation numbers for performance evaluation + std::size_t flop = 0, num_byte = 0; + auto max_seqlen_q = + 
std::numeric_limits::min(); // we will use max seqlen to decide grid size + auto max_seqlen_k = std::numeric_limits::min(); + { + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + if(max_seqlen_q < real_seqlen_q) + { + max_seqlen_q = real_seqlen_q; + } + + if(max_seqlen_k < real_seqlen_k) + { + max_seqlen_k = real_seqlen_k; + } + + flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + + static_cast(2) * mask.get_unmaskarea() * hdim_v); + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q_storage_q + + sizeof(ODataType) * real_seqlen_q * hdim_v); + num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q_storage_k + + sizeof(VDataType) * hdim_v * real_seqlen_k); + } + } + + static const auto get_lengths = [](bool permute, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + }; + + // host memory for storing all the tensor elements + const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); + // physical(padded) total seqlen_q for group when s_qpad is provided; else use logical + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch ? seqlen_qs[0] + : (has_group_q_padding && !seqstart_q_with_padding_host.empty() + ? seqstart_q_with_padding_host.back() + : seqstart_q_host.back())); + const ck_tile::index_t shape_seqlen_k = + (mode == mode_enum::batch ? seqlen_ks[0] + : (has_group_k_padding && !seqstart_k_with_padding_host.empty() + ? 
seqstart_k_with_padding_host.back() + : seqstart_k_host.back())); + + // Calculate number of blocks for blockscale mode + ck_tile::index_t i_block_scale_q = 0; + ck_tile::index_t i_block_scale_k = 0; + std::vector block_scale_seqstart_q_host{0}; + std::vector block_scale_seqstart_k_host{0}; + + if(mode == mode_enum::group) + { + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_); + i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_k_); + block_scale_seqstart_q_host.push_back(i_block_scale_q); + block_scale_seqstart_k_host.push_back(i_block_scale_k); + } + } + + const ck_tile::index_t num_block_scale_q = + (mode == mode_enum::batch) + ? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_size_q_) + : i_block_scale_q; + const ck_tile::index_t num_block_scale_k = + (mode == mode_enum::batch) + ? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_size_k_) + : i_block_scale_k; + + ck_tile::HostTensor q_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor k_host( + get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + ck_tile::HostTensor v_host( + is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) + : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)); + + ck_tile::HostTensor q_descale_host( + qscale_uses_bwp ? std::array{shape_batch, nhead, num_block_scale_q} + : std::array{1, 1, 1}); + ck_tile::HostTensor k_descale_host( + qscale_uses_bwp ? std::array{shape_batch, nhead_k, num_block_scale_k} + : std::array{1, 1, 1}); + // BLOCKSCALE, PERWARP, and PERTHREAD V all use per-channel scale (col-major layout) + ck_tile::HostTensor v_descale_host( + qscale_uses_bwp ? 
std::array{batch, nhead_k, hdim_v} + : std::array{1, 1, 1}); + + ck_tile::HostTensor o_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + + const auto get_dtype_max = []() { + if constexpr(ck_tile::is_packed_type_v) + return 7.0f; + else + return ck_tile::type_convert(ck_tile::numeric::max()); + }; + + if(init_method == "ui" || init_method == "0") + { + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(v_host); + } + + else if(init_method == "ni") + { + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(k_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(v_host); + } + else if(init_method == "uf" || init_method == "1") + { + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(v_host); + } + else if(init_method == "nf") + { + ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(q_host); + ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(k_host); + ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(v_host); + } + else if(init_method == "tf" || init_method == "2") + { + ck_tile::FillTrigValue{}(q_host); + ck_tile::FillTrigValue{}(k_host); + ck_tile::FillTrigValue{}(v_host); + } + else if(init_method == "3") + { + float q_dtype_max = get_dtype_max.template operator()(); + float k_dtype_max = get_dtype_max.template operator()(); + float v_dtype_max = get_dtype_max.template operator()(); + + ck_tile::FillUniformDistribution{-q_dtype_max, q_dtype_max, next_seed()}(q_host); + ck_tile::FillUniformDistribution{-k_dtype_max, k_dtype_max, next_seed()}(k_host); + ck_tile::FillUniformDistribution{-v_dtype_max, 
v_dtype_max, next_seed()}(v_host); + } + if(qscale.type == quant_scale_enum::pertensor) + { + float q_dtype_max = get_dtype_max.template operator()(); + float k_dtype_max = get_dtype_max.template operator()(); + float v_dtype_max = get_dtype_max.template operator()(); + + float qkv_max = 3.f; + q_descale_host(0) = qkv_max / q_dtype_max; + k_descale_host(0) = qkv_max / k_dtype_max; + v_descale_host(0) = qkv_max / v_dtype_max; + } + else if(qscale_uses_bwp) + { + float q_dtype_max = get_dtype_max.template operator()(); + float k_dtype_max = get_dtype_max.template operator()(); + float v_dtype_max = get_dtype_max.template operator()(); + + float qkv_max = 3.f; + float max_descale_q = qkv_max / q_dtype_max; + float max_descale_k = qkv_max / k_dtype_max; + float max_descale_v = qkv_max / v_dtype_max; + + ck_tile::FillUniformDistribution{max_descale_q * 0.8f, max_descale_q, next_seed()}( + q_descale_host); + ck_tile::FillUniformDistribution{max_descale_k * 0.8f, max_descale_k, next_seed()}( + k_descale_host); + + // BLOCKSCALE, PERWARP, and PERTHREAD V all use per-channel scale (shape: [batch, nhead_k, + // hdim_v]) + ck_tile::FillUniformDistribution{max_descale_v * 0.8f, max_descale_v, next_seed()}( + v_descale_host); + } + + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem 
seqstart_q_padded_buf(seqstart_q_with_padding_host.empty() + ? 0 + : seqstart_q_with_padding_host.size() * + sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k_padded_buf( + seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t)); + // Buffers for query per-sequence logical (unpadded) lengths (used in group mode with padding + // enabled) + ck_tile::DeviceMem seqlen_q_buf(has_group_q_padding ? seqlen_qs.size() * sizeof(int32_t) : 0); + // Buffers for key/value per-sequence logical (unpadded) lengths (used in group mode with + // padding enabled) + ck_tile::DeviceMem seqlen_k_buf(has_group_k_padding ? seqlen_ks.size() * sizeof(int32_t) : 0); + ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 + : cuq_cum.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem cu_seqlen_kv_buf( + cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t)); + // Must match args.block_scale_seqstart_* (group + bs/pw/pth only). bf16 validation (qscale=n) + // never binds these pointers; allocating only when the kernel uses them avoids empty uploads. + const bool need_block_scale_seqstart_buf = mode == mode_enum::group && qscale_uses_bwp; + ck_tile::DeviceMem block_scale_seqstart_q_buf( + need_block_scale_seqstart_buf ? block_scale_seqstart_q_host.size() * sizeof(int32_t) : 0); + ck_tile::DeviceMem block_scale_seqstart_k_buf( + need_block_scale_seqstart_buf ? 
block_scale_seqstart_k_host.size() * sizeof(int32_t) : 0); + + if constexpr(need_q_i4_permute) + { + auto q_host_dev = q_host; + ck_tile::permute_vectors_i4x4_b(q_host_dev); + q_buf.ToDevice(q_host_dev.data()); + } + else + { + q_buf.ToDevice(q_host.data()); + } + if constexpr(need_k_i4_permute) + { + auto k_host_dev = k_host; + ck_tile::permute_vectors_i4x4_b(k_host_dev); + k_buf.ToDevice(k_host_dev.data()); + } + else + { + k_buf.ToDevice(k_host.data()); + } + v_buf.ToDevice(v_host.data()); + q_descale_buf.ToDevice(q_descale_host.data()); + k_descale_buf.ToDevice(k_descale_host.data()); + v_descale_buf.ToDevice(v_descale_host.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + // Keep logical starts in seqstart_k; pass padded K via separate pointer + seqstart_k.ToDevice(seqstart_k_host.data()); + seqstart_q_padded_buf.ToDevice( + seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data()); + seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr + : seqstart_k_with_padding_host.data()); + cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data()); + cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data()); + seqlen_q_buf.ToDevice(has_group_q_padding ? seqlen_qs.data() : nullptr); + seqlen_k_buf.ToDevice(has_group_k_padding ? seqlen_ks.data() : nullptr); + block_scale_seqstart_q_buf.ToDevice( + need_block_scale_seqstart_buf ? block_scale_seqstart_q_host.data() : nullptr); + block_scale_seqstart_k_buf.ToDevice( + need_block_scale_seqstart_buf ? 
block_scale_seqstart_k_host.data() : nullptr); + + // clang-format off + auto layout_str = [&](bool permute){ + if(permute) return std::string("bhsd"); + else return std::string("bshd"); + }; + auto io_layout = [&](bool iperm_, bool operm_) { + if(iperm_ == operm_) return layout_str(iperm_); + else return layout_str(iperm_) + std::string("-") + layout_str(operm_); + }; + // clang-format on + + std::cout << "[" << data_type << "|" << mode << "|" << io_layout(i_perm, o_perm) + << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] + << "/" << seqlen_ks[0] + << (seqlen_kpads[0] < 0 ? "" + : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) + << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s + << ", qscale:" << qscale << ", mask:" << mask + << ", v:" << (is_v_rowmajor ? "r" : "c"); + // Padding / effective length diagnostic logging + auto print_vec = [&](const char* label, const std::vector& v) { + if(v.empty()) + return; + std::cout << ", " << label << ":["; + for(std::size_t i = 0; i < v.size(); ++i) + { + if(i) + std::cout << ","; + std::cout << v[i]; + } + std::cout << "]"; + }; + + if(has_group_padding) + { + bool has_qpad = !seqstart_q_with_padding_host.empty(); + bool has_kpad = (seqlen_kpads[0] >= 0); + if(has_qpad) + { + print_vec("q_logical", seqlen_qs); + print_vec("q_padded", seqlen_qpads); + } + if(has_kpad) + { + print_vec("k_logical", seqlen_ks); + print_vec("k_padded", seqlen_kpads); + } + } + else if(has_batch_padding) + { + // derive effective lengths from cumulative arrays if present + if(!cuq_cum.empty()) + { + std::vector eff_q(batch); + for(int b_i = 0; b_i < batch; ++b_i) + eff_q[b_i] = static_cast(cuq_cum[b_i + 1] - cuq_cum[b_i]); + print_vec("q_eff", eff_q); + } + if(!cukv_cum.empty()) + { + std::vector eff_kv(batch); + for(int b_i = 0; b_i < batch; ++b_i) + eff_kv[b_i] = static_cast(cukv_cum[b_i + 1] - cukv_cum[b_i]); + print_vec("kv_eff", eff_kv); + } + } + + std::cout << std::flush; + + 
const auto init_traits = [&](auto& traits) { + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type; + traits.is_v_rowmajor = is_v_rowmajor; + traits.is_group_mode = (mode == mode_enum::group); + traits.mask_type = mask.type; + traits.qscale_type = qscale.type; + }; + + const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) { + // setup stride_* arguments + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k; + }(); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + // setup nhead_stride_* arguments + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? shape_seqlen_k * hdim_v : hdim_v; + else + return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; + }(); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); + // setup batch_stride_* arguments + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + // setup split_stride_* arguments (only used in split-kv kernel) + + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); + args.batch = batch; + args.seqlen_q = shape_seqlen_q; // unused in group mode + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; + + // Setup sageattn_fwd_args + args.o_ptr = o_buf.GetDeviceBuffer(); + + args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) + args.max_seqlen_q = max_seqlen_q; + + args.scale_s = scale_s; + + args.stride_o = stride_o; + args.nhead_stride_lse = nhead_stride_lse; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_lse = batch_stride_lse; + args.batch_stride_o = batch_stride_o; + + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.mask_type = static_cast(mask.type); + + args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); + args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); + args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); + + // BLOCKSCALE/PERWARP/PERTHREAD parameters + if(qscale_uses_bwp) + { + args.nhead_stride_q_descale = num_block_scale_q; + 
args.nhead_stride_k_descale = num_block_scale_k; + // BLOCKSCALE, PERWARP, and PERTHREAD V all use per-channel scale: stride = hdim_v + args.nhead_stride_v_descale = hdim_v; + + if(mode == mode_enum::batch) + { + args.batch_stride_q_descale = nhead * num_block_scale_q; + args.batch_stride_k_descale = nhead_k * num_block_scale_k; + // BLOCKSCALE, PERWARP, and PERTHREAD V all use per-channel scale: batch_stride = + // nhead_k * hdim_v + args.batch_stride_v_descale = nhead_k * hdim_v; + } + else // group mode + { + // BLOCKSCALE, PERWARP, and PERTHREAD all use block_scale_seqstart in group mode + // They differ only in block size: BLOCKSCALE (Q:128, K:128), PERWARP (Q:32, K:64), + // PERTHREAD (Q:4, K:16) + args.block_scale_seqstart_q_ptr = block_scale_seqstart_q_buf.GetDeviceBuffer(); + args.block_scale_seqstart_k_ptr = block_scale_seqstart_k_buf.GetDeviceBuffer(); + // BLOCKSCALE, PERWARP, and PERTHREAD V all use per-channel scale: batch_stride = + // nhead_k * hdim_v + args.batch_stride_v_descale = nhead_k * hdim_v; + } + + args.block_scale_size_q = block_scale_size_q_; + args.block_scale_size_k = block_scale_size_k_; + } + + // Sequence length and padding parameters (mode-specific) + if(mode == mode_enum::group) + { + // Group mode: use physical (padded) cumulative starts + logical per-sequence + // lengths + + // Physical cumulative starts (including padding) + args.seqstart_q_ptr = has_group_q_padding && !seqstart_q_with_padding_host.empty() + ? seqstart_q_padded_buf.GetDeviceBuffer() + : seqstart_q.GetDeviceBuffer(); + args.seqstart_k_ptr = has_group_k_padding && !seqstart_k_with_padding_host.empty() + ? seqstart_k_padded_buf.GetDeviceBuffer() + : seqstart_k.GetDeviceBuffer(); + + // Logical (unpadded) per-sequence lengths, used when padding is enabled + args.seqlen_q_ptr = (has_group_q_padding && !seqstart_q_with_padding_host.empty()) + ? 
seqlen_q_buf.GetDeviceBuffer() + : nullptr; + args.seqlen_k_ptr = (has_group_k_padding && !seqstart_k_with_padding_host.empty()) + ? seqlen_k_buf.GetDeviceBuffer() + : nullptr; + // Cumulative lengths not used in group mode + args.cu_seqlen_q_ptr = nullptr; + args.cu_seqlen_k_ptr = nullptr; + } + else // mode == mode_enum::batch + { + // Batch mode: use cumulative logical lengths for tail padding + + // seqstart pointers not used in batch mode + args.seqstart_q_ptr = nullptr; + args.seqstart_k_ptr = nullptr; + + // seqlen_q_ptr/seqlen_k_ptr not used in batch mode + args.seqlen_q_ptr = nullptr; + args.seqlen_k_ptr = nullptr; + + // Cumulative logical lengths for effective length handling + args.cu_seqlen_q_ptr = has_batch_q_padding && !cuq_cum.empty() + ? cu_seqlen_q_buf.GetDeviceBuffer() + : nullptr; + args.cu_seqlen_k_ptr = has_batch_k_padding && !cukv_cum.empty() + ? cu_seqlen_kv_buf.GetDeviceBuffer() + : nullptr; + } + }; + + // Run main SageAttention forward kernel + sageattn_fwd_traits sageattn_traits; + init_traits(sageattn_traits); + + sageattn_fwd_args sageattn_args; + init_args(sageattn_args); + + const float ave_time = sageattn_fwd(sageattn_traits, sageattn_args, stream_config); + if(ave_time < 0.0f) + { + std::cout << ", not supported yet" << std::flush << std::endl; + return fwd_result::no_instance; + } + const float tflops = static_cast(flop) / 1.E9 / ave_time; + const float gb_per_sec = num_byte / 1.E6 / ave_time; + if(stream_config.time_kernel_) + { + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) + << gb_per_sec << " GB/s" << std::flush; + } + + bool pass = true; + if(do_validation == 0) + { + std::cout << std::flush << std::endl; + } + else + { + o_buf.FromDevice(o_host.data()); + + constexpr bool supports_qscale = + std::is_same_v || + std::is_same_v || + std::is_same_v; + + float scale_s_host = scale_s; + float scale_p_host = 1.0f; + float 
scale_o_host = 1.0f; + + if(qscale.type == quant_scale_enum::pertensor) + { + scale_s_host = scale_s * q_descale_host(0) * k_descale_host(0); + scale_p_host = ck_tile::type_convert(ck_tile::numeric::max()); + scale_o_host = v_descale_host(0) / scale_p_host; + } + + auto p_compute_element_func = [&]() { + if constexpr(supports_qscale) + return ck_tile::scales{scale_p_host}; + else + return ck_tile::identity{}; + }(); + + auto oacc_element_func = [&]() { + if constexpr(std::is_same_v && supports_qscale) + return ck_tile::make_composes(ck_tile::saturates{}, + ck_tile::scales{scale_o_host}); + else if constexpr(supports_qscale) + return ck_tile::scales{scale_o_host}; + else + return ck_tile::identity{}; + }(); + + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + if(mode == mode_enum::batch) + { + if(!cuq_cum.empty()) + { + real_seqlen_q = cuq_cum[wb + 1] - cuq_cum[wb]; + } + if(!cukv_cum.empty()) + { + real_seqlen_k = cukv_cum[wb + 1] - cukv_cum[wb]; + } + } + + // adjust matrix index according to the mode + const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t cache_b_idx = b_idx; + // Use physical offset if padding info is valid (not -1) and buffers are available + const ck_tile::index_t query_offset = + (mode == mode_enum::batch + ? 0 + : ((seqstart_q_with_padding_host.empty() || seqlen_qpads[0] < 0) + ? seqstart_q_host[wb] + : seqstart_q_with_padding_host[wb])); + const ck_tile::index_t key_offset = + (mode == mode_enum::batch + ? 0 + : ((seqstart_k_with_padding_host.empty() || seqlen_kpads[0] < 0) + ? 
seqstart_k_host[wb] + : seqstart_k_with_padding_host[wb])); + + ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); + ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); + ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); + ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); + + ck_tile::HostTensor s_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + + ck_tile::index_t nr = nhead / nhead_k; + + // clang-format off + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); }); + // clang-format on + + { + // clang-format off + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); + // clang-format on + } + + { + if(is_v_rowmajor) + { + // clang-format off + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); }); + // clang-format on + } + else + { + // clang-format off + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); }); + // clang-format on + } + } + + // reference + // For packed types (pk_int4_t), unpack to float for host reference GEMM + auto unpack_to_float = 
[](const auto& packed_tensor) { + auto dims = packed_tensor.mDesc.get_lengths(); + ck_tile::HostTensor unpacked({static_cast(dims[0]), + static_cast(dims[1]), + static_cast(dims[2])}); + unpacked.ForEach([&](auto& self, auto idx) { + auto packed = packed_tensor(idx[0], idx[1], idx[2]); + auto fp32x2 = ck_tile::pk_int4_t_to_fp32x2_t(packed); + self(idx) = (idx[2] % 2 == 0) ? fp32x2[0] : fp32x2[1]; + }); + return unpacked; + }; + + if(qscale_uses_bwp) + { + const ck_tile::index_t q_offset = + (mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb]; + const ck_tile::index_t k_offset = + (mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb]; + if constexpr(ck_tile::is_packed_type_v) + { + auto q_f32 = unpack_to_float(q_host_ref); + auto k_f32 = unpack_to_float(k_host_ref); + ck_tile::reference_batched_quant_gemm( + q_f32, + k_f32, + s_host_ref, + ck_tile::idx_identity{}, + ck_tile::idx_identity{}, + [&](auto idx, auto value) { + return value * scale_s * + q_descale_host(b_idx, + std::get<0>(idx), + q_offset + + std::get<1>(idx) / block_scale_size_q_) * + k_descale_host(b_idx, + std::get<0>(idx) / nr, + k_offset + + std::get<2>(idx) / block_scale_size_k_); + }); + } + else + { + ck_tile::reference_batched_quant_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::idx_identity{}, + ck_tile::idx_identity{}, + [&](auto idx, auto value) { + return value * scale_s * + q_descale_host(b_idx, + std::get<0>(idx), + q_offset + + std::get<1>(idx) / block_scale_size_q_) * + k_descale_host(b_idx, + std::get<0>(idx) / nr, + k_offset + + std::get<2>(idx) / block_scale_size_k_); + }); + } + } + else + { + if constexpr(ck_tile::is_packed_type_v) + { + auto q_f32 = unpack_to_float(q_host_ref); + auto k_f32 = unpack_to_float(k_host_ref); + ck_tile:: + reference_batched_gemm( + q_f32, + k_f32, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale_s_host)); + } + else + { + ck_tile::reference_batched_gemm( + q_host_ref, + k_host_ref, + 
s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale_s_host)); + } + } + + if(mask.type == mask_enum::no_mask) + { + ck_tile::reference_batched_masking( + s_host_ref, SageAttnMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { + // Match device: kernel sets is_top_left from (mask_type == MASK_FROM_TOP_LEFT); + // window_generic maps to MASK_GENERIC, so is_top_left is false (not the default). + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, 0, real_seqlen_q, real_seqlen_k, false)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window< + SageAttnMasks::CausalMask>(mask.left, + mask.right, + 0, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window< + SageAttnMasks::GenericMask>(mask.left, + mask.right, + 0, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + } + const ck_tile::HostTensor masked_s_host_ref = s_host_ref; + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, p_compute_element_func); + + if(qscale_uses_bwp) + { + // BLOCKSCALE, PERWARP, and PERTHREAD V all use per-channel scale (col-major) + // v_descale shape: [batch, nhead_k, hdim_v] + // Access by channel index: std::get<1>(idx) is the hdim dimension + ck_tile:: + reference_batched_quant_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::idx_identity{}, + [&](auto idx, auto value) { + return ck_tile::type_convert(value) * + v_descale_host(wb, + std::get<0>(idx) / nr, + std::get<1>(idx)); // channel index + }, + ck_tile::idx_identity{}); + } + else + { + 
ck_tile::reference_batched_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + oacc_element_func); + } + + ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); + // clang-format off + // permute + if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); + else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); + // clang-format on + + auto [rtol, atol] = get_elimit(init_method); + bool cur_pass = ck_tile::check_err(o_host_result, + o_host_ref, + std::string("OUT Error: Incorrect results!"), + rtol, + atol); + pass &= cur_pass; + if(!cur_pass) + { + std::cerr << "OUT mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q (logical): " << seqstart_q_host << std::endl + << "\tseqstart_q (physical): " << seqstart_q_with_padding_host + << std::endl + << "\tseqstart_k (logical): " << seqstart_k_host << std::endl + << "\tseqstart_k (physical): " << seqstart_k_with_padding_host + << std::endl + << "\tquery_offset used: " << query_offset << std::endl + << "\tkey_offset used: " << key_offset << std::endl; + + break; + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + if(json) + { + dump_sageattn_fwd_json_results( + *json, + data_type, + mode == mode_enum::batch ? "batch" : "group", + io_layout(i_perm, o_perm), + batch, + nhead, + nhead_k, + seqlen_qs[0], + seqlen_ks[0], + seqlen_kpads[0], + hdim_q, + hdim_v, + scale_s, + 0.0f, // p_drop (dropout disabled for sageattention) + false, // lse (always disabled for sageattention) + [&qscale]() { + std::ostringstream ss; + qscale.serialize(ss); + return ss.str(); + }(), + "no_bias", + is_v_rowmajor ? "r" : "c", + pass, + ave_time, + tflops, + gb_per_sec); + } + + return pass ? 
fwd_result::success : fwd_result::failure; +} diff --git a/example/ck_tile/49_sageattention/script/smoke_test_sageattn_fwd.sh b/example/ck_tile/49_sageattention/script/smoke_test_sageattn_fwd.sh new file mode 100755 index 0000000000..ae0e120c05 --- /dev/null +++ b/example/ck_tile/49_sageattention/script/smoke_test_sageattn_fwd.sh @@ -0,0 +1,162 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +# +# SageAttention forward smoke tests - structure mirrors +# example/ck_tile/01_fmha/script/smoke_test_fwd.sh +# +# Run from the ComposableKernel *build* directory (after ninja), same as FMHA: +# cd build && ninja tile_example_sageattn_fwd +# bash ../example/ck_tile/49_sageattention/script/smoke_test_sageattn_fwd.sh +# +# Optional: VERBOSE=1 enables bash -x. CURR_FAILS_FILE / KNOWN_FAILS_FILE override fail logs. + +set -euo pipefail + +SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +EXE_NAME=tile_example_sageattn_fwd +EXE="$(find . -name "$EXE_NAME" -type f 2>/dev/null | head -n 1)" +KNAME=1 +GPU_arch=${GPU_arch:-} +if [ -z "$GPU_arch" ]; then + GPU_arch=$(rocminfo 2>/dev/null | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}' || echo "unknown") +fi + +export CK_WARMUP=0 +export CK_REPEAT=1 + +CURR_FAILS_FILE=${CURR_FAILS_FILE:-"sageattn_fwd_fails_${GPU_arch}.txt"} +rm -f "$CURR_FAILS_FILE" +touch "$CURR_FAILS_FILE" +KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/sageattn_fwd_known_fails_${GPU_arch}.txt"} + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' + +if [ -z "${EXE:-}" ] || [ ! -x "$EXE" ]; then + echo "ERROR: $EXE_NAME not found under cwd ($(pwd)). Build with: ninja $EXE_NAME" >&2 + exit 1 +fi + +run_exe() { + set +e + $EXE "$@" + local ret=$? + if [ $ret -ne 0 ]; then + echo "$EXE_NAME $*" >>"$CURR_FAILS_FILE" + fi + set -e +} + +# Core FP8xBF16 cases aligned with FMHA smoke_test_fwd.sh (lines 80-87): batch/group shapes, +# masks, GQA, short seqlen, k-only pad. 
Sweeps blockscale (2) vs per-warp (3) and layouts.
run_fp8bf16_smoke() {
    local qs pm
    for qs in 2 3; do
        for pm in 0 1; do
            # Flags shared by every case in this sweep (quant granularity and
            # layout come from the loop variables).
            local base=(-prec=fp8bf16 -init=3 -qscale=$qs -iperm=$pm -operm=$pm -vlayout=r -kname=$KNAME $COMMON_ARGS)
            run_exe "${base[@]}" -mode=0 -b=2 -h=2 -h_k=1 -d=128 -d_v=128 -s=55 -s_k=256 -mask=1
            run_exe "${base[@]}" -mode=0 -b=1 -h=3 -d=128 -s=100 -s_k=51 -mask=0
            run_exe "${base[@]}" -mode=0 -b=2 -h=1 -d=128 -d_v=128 -s=99 -s_k=256 -mask=1
            run_exe "${base[@]}" -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1024 -s_k=256 -mask=2
            run_exe "${base[@]}" -mode=0 -b=2 -h=1 -d=128 -d_v=128 -s=3 -s_k=99 -mask=2
            run_exe "${base[@]}" -mode=0 -b=3 -h=2 -h_k=1 -d=128 -s=200 -s_k=520 -mask=t:128,30
            run_exe "${base[@]}" -mode=0 -b=2 -h=1 -d=128 -s=99 -s_k=32 -mask=b:4,35
            run_exe "${base[@]}" -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=33 -s_k=0 -mask=2
            run_exe "${base[@]}" -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -mask=2
        done
    done
}

# Extra FP8: explicit causal string, xformer window, per-tensor / per-thread quant, V col-major.
run_fp8bf16_extras() {
    run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -mode=0 -b=4 -h=8 -d=128 -s=1024 -s_k=1024 -mask=t:-1,0
    run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=1 -operm=1 -vlayout=c -kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=4 -d=128 -s=256 -s_k=256 -mask=t
    run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=256 -s_k=256 -mask=xt:64
    run_exe -prec=fp8bf16 -init=3 -qscale=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=128 -s_k=128 -mask=0
    run_exe -prec=fp8bf16 -init=3 -qscale=4 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=64 -s_k=64 -mask=0
}

# Group mode + physical padding (same intent as FMHA run_padding_smoke_tests, Sage-only flags).
run_group_and_padding_smoke() {
    run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -mode=1 -b=3 -h=2 -h_k=1 -d=128 -s=50,60,40 -s_k=128,256,192 -mask=1
    # group + PERTHREAD: block_scale_seqstart_* must be allocated (same as bs/pw)
    run_exe -prec=fp8bf16 -init=3 -qscale=4 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -mode=1 -b=3 -h=2 -h_k=1 -d=128 -s=50,60,40 -s_k=128,256,192 -mask=1
    run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -mode=1 -b=4 -h=8 -h_k=8 -d=128 -s=1024,768,512,256 -s_k=1024,768,512,256 -mask=0 -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
    run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -mode=0 -b=4 -h=8 -d=128 -s=1024 -s_k=1024 -mask=0 -q_eff_lens=960,512,384,256 -kv_eff_lens=960,512,384,256
}

# BF16 (no quant): pipeline sanity only; not a shipped Sage mode (see example --help prec).
+run_bf16_pipeline_smoke() { + run_exe -prec=bf16 -init=1 -qscale=n -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \ + $COMMON_ARGS -mode=0 -b=2 -h=2 -d=128 -s=128 -s_k=128 -mask=1 + run_exe -prec=bf16 -init=1 -qscale=n -iperm=1 -operm=1 -vlayout=r -kname=$KNAME \ + $COMMON_ARGS -mode=0 -b=1 -h=4 -h_k=1 -d=128 -s=256 -s_k=128 -mask=t:32,32 +} + +# int8 / int4 x fp8xbf16 (hdim divisible by 8 for int4) +run_int_quant_smoke() { + run_exe -prec=i8fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \ + $COMMON_ARGS -mode=0 -b=2 -h=2 -d=128 -s=128 -s_k=128 -mask=1 + run_exe -prec=i4fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \ + $COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=128 -s_k=128 -mask=t +} + +if [ "${VERBOSE:-0}" = 1 ]; then + set -x +fi + +run_fp8bf16_smoke +run_fp8bf16_extras +run_group_and_padding_smoke +run_bf16_pipeline_smoke +run_int_quant_smoke + +set +x + +new_fails_count=0 +known_fails_count=0 +if [ -f "$KNOWN_FAILS_FILE" ]; then + echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):" + while IFS= read -r line; do + if grep -Fxq "$line" "$KNOWN_FAILS_FILE"; then + echo "Known fail: $line" + known_fails_count=$((known_fails_count + 1)) + else + echo "New fail: $line" + new_fails_count=$((new_fails_count + 1)) + fi + done <"$CURR_FAILS_FILE" +else + new_fails_count=$(wc -l <"$CURR_FAILS_FILE") + echo "No known fails file, all fails ($new_fails_count) are new:" + if [ "$new_fails_count" -gt 0 ]; then + cat "$CURR_FAILS_FILE" + fi +fi +echo "New fails count: $new_fails_count; Known fails count: $known_fails_count" +exit $((new_fails_count != 0)) diff --git a/example/ck_tile/49_sageattention/utils.hpp b/example/ck_tile/49_sageattention/utils.hpp new file mode 100644 index 0000000000..27c97f8383 --- /dev/null +++ b/example/ck_tile/49_sageattention/utils.hpp @@ -0,0 +1,254 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core/container/span.hpp" + +enum class mode_enum +{ + batch = 0, + group +}; + +inline std::ostream& operator<<(std::ostream& stream, mode_enum mode) +{ + return stream << (mode == mode_enum::batch ? "batch" : "group"); +} + +template +inline std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} + +inline std::vector to_seqstarts(ck_tile::span seqlens) +{ + std::vector seqstarts = {0}; + for(int32_t seqlen : seqlens) + { + seqstarts.push_back(seqstarts.back() + seqlen); + } + assert(seqstarts.size() == seqlens.size() + 1); + return seqstarts; +} + +template +std::vector generate_seqlens(mode_enum mode, + unsigned count, + int32_t seqlen_avg, + int32_t seqlen_min, // if not negative, clamp min + int32_t seqlen_max, // if not negative, clamp max + RandomEngine& random_engine) +{ + assert(0 < count); + + seqlen_min = (0 < seqlen_min ? seqlen_min : 1); + seqlen_max = (0 < seqlen_max ? 
seqlen_max : std::numeric_limits::max()); + assert(seqlen_min <= seqlen_max); + + std::vector seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max)); + + if(mode == mode_enum::group && 1 < count) + { + using size_type = std::vector::size_type; + + std::uniform_int_distribution idx_dist(0, count - 1); + auto next_idx = std::bind(idx_dist, std::ref(random_engine)); + + std::uniform_int_distribution step_dist(1, count - 1); + auto next_step = std::bind(step_dist, std::ref(random_engine)); + + for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat) + { + const size_type to_decrease = next_idx(); + // make sure each elements of seqlens is in range [seqlen_min, seqlen_max] + if(seqlens[to_decrease] == seqlen_min) + { + continue; + } + + const size_type to_increase = (to_decrease + next_step()) % count; + + if(seqlens[to_increase] >= seqlen_max) + { + continue; + } + + --seqlens[to_decrease]; + ++seqlens[to_increase]; + } + } + + return seqlens; +} + +// return random integer generated uniformly in range [low, high] +template +auto randint(Int low, + Int high, + RandomEngine& random_engine) -> std::enable_if_t, Int> +{ + std::uniform_int_distribution dist(low, high); + return dist(random_engine); +} + +// return random integers generated uniformly in range [low, high] +template +auto randints(ForwardIterator first, + ForwardIterator last, + Int low, + Int high, + RandomEngine& random_engine) -> std::enable_if_t> +{ + std::uniform_int_distribution dist(low, high); + + std::generate(first, last, [&] { return dist(random_engine); }); +} + +/* + * generate missing values in *_val randomly when the number of values is smaller than batch + * example (assume batch=3) + * q_val=1,2,3 k_val=4,5,6 -> OK + * q_val=1,2,3 -> OK, k same as q + * q_val=1,2 -> OK, q will rand remaining 1 element, k same as q + * q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element + * q_val=1,2,3,4 -> OK, but ignore exceed one + * + * q_val=1,2 k_val=4,5,6 -> not OK, k must 
have same splits with q + * q_val=1,2 k_val=4 -> not OK, k must have same splits with q + */ +template +std::tuple, + std::vector, + std::vector, + std::vector> +generate_missing_seqlens(mode_enum mode, + ck_tile::index_t batch, + const std::vector& q_val, + const std::vector& k_val, + const std::vector& q_pad_val, + const std::vector& k_pad_val, + ck_tile::index_t seqlen_k_min, + bool need_append_kvcache, + RandomEngine& random_engine) +{ + if(mode == mode_enum::batch) + { + ck_tile::index_t q = q_val[0]; + ck_tile::index_t k = k_val[0]; + + auto s_q = std::vector(batch, q); + auto s_k = [&] { + const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k); + std::vector seqlen_ks(batch, seqlen_k_max); + + if(1 < batch && need_append_kvcache) + { + // to keep the original s_k value, we always use seqlen_k_max in first batch + randints(std::next(seqlen_ks.begin()), + seqlen_ks.end(), + seqlen_k_min, + seqlen_k_max, + random_engine); + return seqlen_ks; + } + + return seqlen_ks; + }(); + auto s_kpad = std::vector(batch, -1); // TODO: batch not support k_padding + auto s_qpad = std::vector(batch, -1); + // s_k should be greater than or equal to seqlen_k_min if provided + if(s_k.back() < seqlen_k_min) + { + std::ostringstream msg; + msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back() + << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; + throw std::runtime_error(msg.str()); + } + + return std::make_tuple(s_q, s_k, s_qpad, s_kpad); + } + else + { + std::vector s_q; + std::vector s_k; + std::vector s_kpad; + std::vector s_qpad; + ck_tile::index_t idx = 0; + for(; idx < std::min(static_cast(q_val.size()), batch); ++idx) + { + ck_tile::index_t q = q_val[idx]; + ck_tile::index_t k = + k_val[std::min(idx, static_cast(k_val.size()) - 1)]; + ck_tile::index_t kp = + k_pad_val.empty() + ? -1 + : k_pad_val[std::min(idx, static_cast(k_pad_val.size()) - 1)]; + + ck_tile::index_t qp = + q_pad_val.empty() + ? 
-1 + : q_pad_val[std::min(idx, static_cast(q_pad_val.size()) - 1)]; + + s_q.push_back(q); + s_k.push_back(k < 0 ? q : k); + s_kpad.push_back(kp); + s_qpad.push_back(qp); + + // s_k should be greater than or equal to seqlen_k_min + if(s_k.back() < seqlen_k_min) + { + std::ostringstream msg; + msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back() + << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; + throw std::runtime_error(msg.str()); + } + } + if(idx < batch) + { + auto rem_q = + generate_seqlens(mode, batch - idx, s_q.back(), 1, s_q.back(), random_engine); + auto rem_k = generate_seqlens( + mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), random_engine); + + s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); + s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); + s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back()); + s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back()); + } + return std::make_tuple(s_q, s_k, s_qpad, s_kpad); + } +} + +template +std::enable_if_t> iota_shuffle(RandomAccessIterator first, + RandomAccessIterator last, + Int value, + RandomEngine& random_engine) +{ + std::iota(first, last, value); + std::shuffle(first, last, random_engine); +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index dda9156992..5b9b4d9614 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -31,6 +31,7 @@ add_subdirectory(38_block_scale_gemm) add_subdirectory(40_streamk_gemm) add_subdirectory(41_batched_contraction) add_subdirectory(42_mx_gemm) +add_subdirectory(49_sageattention) add_subdirectory(50_sparse_attn) add_subdirectory(51_tile_distr_enc_reg_map) if(BUILD_CK_TILE_CSHUFFLE_LDS_BENCHMARKS) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index c2ddaa2730..329703614e 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -530,4 +530,10 @@ using 
WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed = WarpGemmImpl>>; +template +using WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution = + WarpGemmImpl, + 2, + swizzle_factor>>; } // namespace ck_tile diff --git a/include/ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp b/include/ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp new file mode 100644 index 0000000000..44be382379 --- /dev/null +++ b/include/ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockSageAttentionQuantScaleEnum +{ + NO_SCALE = 0, + PERTENSOR = 1, + BLOCKSCALE = 2, + PERWARP = 3, + PERTHREAD = 4, +}; + +template +struct BlockSageAttentionQuantScaleEnumToStr; + +template <> +struct BlockSageAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = ""; +}; +template <> +struct BlockSageAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = "pertensor"; +}; +template <> +struct BlockSageAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = "blockscale"; +}; +template <> +struct BlockSageAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = "perwarp"; +}; +template <> +struct BlockSageAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = "perthread"; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sageattention/kernel/sageattn_fwd_kernel.hpp b/include/ck_tile/ops/sageattention/kernel/sageattn_fwd_kernel.hpp new file mode 100644 index 0000000000..48dec0e796 --- /dev/null +++ b/include/ck_tile/ops/sageattention/kernel/sageattn_fwd_kernel.hpp @@ -0,0 +1,1026 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/ops/fmha/block/block_position_encoding.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" + +#include +#include +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// P[seqlen_q, seqlen_k] = Softmax(S'[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct SageAttnFwdKernel +{ + using SageAttnPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = SageAttnPipeline::kBlockSize; + + static constexpr ck_tile::index_t kBlockPerCu = SageAttnPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = SageAttnPipeline::Problem::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = SageAttnPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = SageAttnPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = SageAttnPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = SageAttnPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = SageAttnPipeline::kPadHeadDimV; + // logits_soft_cap is always disabled + static constexpr auto QScaleEnum = SageAttnPipeline::Problem::QScaleEnum; + static constexpr bool kSkipMinSeqlenQ = 
SageAttnPipeline::Problem::kSkipMinSeqlenQ; + + using AttentionVariant = ck_tile::remove_cvref_t; + using AttnMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = AttnMask::IsMasking; + + static constexpr bool kUseAsyncCopy = SageAttnPipeline::Policy::AsyncCopy; + + // Distinct empty bases (I = 0 mask slot, 1 qscale slot, 2 min_seqlen_q slot) avoid duplicate + // base-class issues under multiple inheritance. + template + struct SageAttnFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct SageAttnFwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + float scale_s; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + }; + + struct SageAttnFwdMaskKargs + { + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct SageAttnFwdCommonQScaleKargs + { + const void* q_descale_ptr = nullptr; + const void* k_descale_ptr = nullptr; + const void* v_descale_ptr = nullptr; + }; + + struct SageAttnFwdCommonBlockScaleKargs : public SageAttnFwdCommonQScaleKargs + { + ck_tile::index_t nhead_stride_q_descale; + ck_tile::index_t nhead_stride_k_descale; + ck_tile::index_t nhead_stride_v_descale; + + /// Host must match TileSageAttnTraits / Problem; validated in MakeKargs (device uses + /// 
Problem::kBlockScaleSizeQ/K). + ck_tile::index_t block_scale_size_q = 0; + ck_tile::index_t block_scale_size_k = 0; + }; + + struct SageAttnFwdBatchBlockScaleKargs : public SageAttnFwdCommonBlockScaleKargs + { + ck_tile::index_t batch_stride_q_descale; + ck_tile::index_t batch_stride_k_descale; + ck_tile::index_t batch_stride_v_descale; + }; + + struct SageAttnFwdGroupBlockScaleKargs : public SageAttnFwdCommonBlockScaleKargs + { + const int32_t* block_scale_seqstart_q_ptr = nullptr; + const int32_t* block_scale_seqstart_k_ptr = nullptr; + ck_tile::index_t batch_stride_v_descale; + }; + + struct SageAttnFwdSkipMinSeqlenQKargs + { + ck_tile::index_t min_seqlen_q = 0; + }; + + struct SageAttnFwdBatchModeKargs + : SageAttnFwdCommonKargs, + std::conditional_t>, + std::conditional_t< + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTENSOR, + SageAttnFwdCommonQScaleKargs, + std::conditional_t>> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + + // Optional cumulative sequence length pointers for batch mode + // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. 
+ const int32_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD + const int32_t* cu_seqlen_k_ptr = nullptr; // cumulative, length without PAD + }; + + struct SageAttnFwdGroupModeKargs + : SageAttnFwdCommonKargs, + std::conditional_t>, + std::conditional_t< + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTENSOR, + SageAttnFwdCommonQScaleKargs, + std::conditional_t>>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_q_ptr; + const int32_t* seqlen_k_ptr; + + // Optional per-sequence and cumulative logical (excluding padding) sequence length arrays + const int32_t* cu_seqlen_q_ptr = nullptr; + const int32_t* cu_seqlen_k_ptr = nullptr; + }; + + using Kargs = + std::conditional_t; + + struct BlockIndices + { + ck_tile::index_t batch_idx; + ck_tile::index_t qo_head_idx; + ck_tile::index_t kv_head_idx; + }; + + template + CK_TILE_HOST static std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, + ck_tile::index_t 
block_scale_size_q, + ck_tile::index_t block_scale_size_k, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + static_cast(scale_s * ck_tile::log2e_v<>), + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // mask or SageAttnFwdEmptyKargs<0> + {}, // qscale or SageAttnFwdEmptyKargs<1> + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o}; + + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTENSOR) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + } + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE || + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP || + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD) + { + using PipelineProblem = typename SageAttnPipeline::Problem; + + if(block_scale_size_q != PipelineProblem::kBlockScaleSizeQ || + block_scale_size_k != PipelineProblem::kBlockScaleSizeK) + { + throw std::runtime_error( + "sageattn_fwd MakeKargs: block_scale_size_q/k must match kernel " + "TileSageAttnTraits (Problem::kBlockScaleSizeQ/K)"); + } + + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.batch_stride_q_descale = batch_stride_q_descale; + 
kargs.batch_stride_k_descale = batch_stride_k_descale; + kargs.batch_stride_v_descale = batch_stride_v_descale; + + kargs.block_scale_size_q = block_scale_size_q; + kargs.block_scale_size_k = block_scale_size_k; + } + // logits_soft_cap is always disabled + + kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); + return kargs; + } + + template + CK_TILE_HOST static std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_q_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, + ck_tile::index_t batch_stride_v_descale, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_k, + const void* block_scale_seqstart_q_ptr, + const void* block_scale_seqstart_k_ptr, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + ck_tile::index_t min_seqlen_q, + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + static_cast(scale_s * ck_tile::log2e_v<>), + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + 
nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // mask or SageAttnFwdEmptyKargs<0> + {}, // qscale or SageAttnFwdEmptyKargs<1> + {}, // min_seqlen_q or SageAttnFwdEmptyKargs<2> + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_q_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTENSOR) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + } + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE || + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP || + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD) + { + using PipelineProblem = typename SageAttnPipeline::Problem; + + if(block_scale_size_q != PipelineProblem::kBlockScaleSizeQ || + block_scale_size_k != PipelineProblem::kBlockScaleSizeK) + { + throw std::runtime_error( + "sageattn_fwd MakeKargs: block_scale_size_q/k must match kernel " + "TileSageAttnTraits (Problem::kBlockScaleSizeQ/K)"); + } + + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.batch_stride_v_descale = batch_stride_v_descale; + + kargs.block_scale_size_q = block_scale_size_q; + kargs.block_scale_size_k = block_scale_size_k; + + kargs.block_scale_seqstart_q_ptr = + reinterpret_cast(block_scale_seqstart_q_ptr); + kargs.block_scale_seqstart_k_ptr = + reinterpret_cast(block_scale_seqstart_k_ptr); + } + // logits_soft_cap is always disabled + if constexpr(kSkipMinSeqlenQ) + { + kargs.min_seqlen_q = min_seqlen_q; + } 
+ + kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); + return kargs; + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_, + bool has_padded_seqlen_k = false) + { + // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr) + if(has_padded_seqlen_k) + { + // TODO: this may need tuning + return dim3(nhead_, + batch_size_, + ck_tile::integer_divide_ceil(seqlen_q_, SageAttnPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, SageAttnPipeline::kN1)); + } + else + { + // TODO: this may need tuning + return dim3(nhead_, + ck_tile::integer_divide_ceil(seqlen_q_, SageAttnPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, SageAttnPipeline::kN1), + batch_size_); + } + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + bool has_padded_seqlen_k = false; + + if constexpr(kIsGroupMode) + has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); + + if(has_padded_seqlen_k) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, SageAttnPipeline::kN1); + + const index_t i_block = blockIdx.z; + const index_t i_nhead = blockIdx.x; + const index_t i_batch = blockIdx.y; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + else + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + 
ck_tile::integer_divide_ceil(kargs.hdim_v, SageAttnPipeline::kN1); + + const index_t i_block = blockIdx.y; // blockIdx.x + const index_t i_nhead = blockIdx.x; // blockIdx.y + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + } + + CK_TILE_HOST static dim3 BlockSize() + { + if(is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return dim3(kBlockSize); + } + } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(SageAttnPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * SageAttnPipeline::kM0); + const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * SageAttnPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_o = 0; + long_index_t batch_offset_q_descale = 0; + long_index_t batch_offset_k_descale = 0; + long_index_t batch_offset_v_descale = 0; + + if constexpr(kIsGroupMode) + { + // Use seqstart_q_ptr and seqstart_k_ptr for physical starts + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + // DRAM base offsets use physical starts + batch_offset_q = 
query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + batch_offset_v = key_start; + } + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE || + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP || + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD) + { + // BLOCKSCALE, PERWARP, and PERTHREAD all use block_scale_seqstart in group mode + // They differ only in block size: BLOCKSCALE (Q:128, K:128), PERWARP (Q:32, K:64), + // PERTHREAD (Q:4, K:16); see TileSageAttnTraits::kBlockScaleSizeQ/K. + const long_index_t bquery_start = kargs.block_scale_seqstart_q_ptr[i_batch]; + const long_index_t bkey_start = kargs.block_scale_seqstart_k_ptr[i_batch]; + batch_offset_q_descale = bquery_start; + batch_offset_k_descale = bkey_start; + // BLOCKSCALE, PERWARP, and PERTHREAD V all use per-channel scale: batch_stride = + // nhead_k * hdim_v + batch_offset_v_descale = + static_cast(i_batch) * kargs.batch_stride_v_descale; + } + batch_offset_o = query_start * kargs.stride_o; + + // real logical lengths (exclude PAD) + // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr + if(kargs.seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch]; + } + else if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + } + + if constexpr(kSkipMinSeqlenQ) + { + if(kargs.seqlen_q <= kargs.min_seqlen_q) + { + return; + } + } + + // terminate unnecessary blocks earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k 
= + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE || + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP || + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD) + { + batch_offset_q_descale = + static_cast(i_batch) * kargs.batch_stride_q_descale; + batch_offset_k_descale = + static_cast(i_batch) * kargs.batch_stride_k_descale; + batch_offset_v_descale = + static_cast(i_batch) * kargs.batch_stride_v_descale; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + + // If cumulative seqlen pointers are provided, override per-batch effective lengths + if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = + reinterpret_cast(kargs.q_ptr) + + (static_cast(i_nhead) * kargs.nhead_stride_q + batch_offset_q) / + ck_tile::numeric_traits::PackedSize; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + (static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k) / + ck_tile::numeric_traits::PackedSize; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + (static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v) / + ck_tile::numeric_traits::PackedSize; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + 
static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(SageAttnPipeline::kQLoadOnce) + { + return pad_tensor_view(q_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen_k), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? 
kPadHeadDimV : false; + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto q_dram_window = + make_tile_window(q_dram, + [&]() { + if constexpr(SageAttnPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, + number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, + make_tuple(number{}, number{}), + {0, 0}); + + auto v_dram_window = make_tile_window( + v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables are not supported. + /// Remove following copy capture of the 'i_nhead' if in C++20 + + AttnMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + 0, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return AttnMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = EmptyPositionEncoding{}; + + AttentionVariant variant; + const auto variant_params = [&] { + const float scale_s = [&] { + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTENSOR) + { + float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); + float k_descale = *(reinterpret_cast(kargs.k_descale_ptr)); + + return kargs.scale_s * q_descale * k_descale; + } + else + { + return kargs.scale_s; + } + }(); + + // logits_soft_cap is always disabled, use standard attention params + return ck_tile::StandardAttentionParams{mask, scale_s}; + }(); + + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + auto o_acc_tile = [&]() { + using PipelineProblem = typename SageAttnPipeline::Problem; + + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTENSOR) + { + // TODO - move global load of descale to pipeline + float 
v_descale = *(reinterpret_cast(kargs.v_descale_ptr)); + + float scale_p = ck_tile::type_convert(ck_tile::numeric::max()); + float scale_o = v_descale / scale_p; + + auto o_acc_element_func = [&]() { + if constexpr(std::is_same_v) + return make_composes( + ck_tile::saturates{}, + ck_tile::scales>{scale_o}); + else + return ck_tile::scales>{scale_o}; + }(); + return SageAttnPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + identity{}, // s_acc_element_func + scales>{scale_p}, // p_compute_element_func + o_acc_element_func, // o_acc_element_func + mask, + position_encoding, + variant_params.sm_scale, + variant, + variant_params, + block_indices, + smem_ptr); + } + else if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE) + { + const float* q_descale_ptr = + reinterpret_cast(kargs.q_descale_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q_descale + + batch_offset_q_descale; + const float* k_descale_ptr = + reinterpret_cast(kargs.k_descale_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * + kargs.nhead_stride_k_descale + + batch_offset_k_descale; + const float* v_descale_ptr = + reinterpret_cast(kargs.v_descale_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * + kargs.nhead_stride_v_descale + + batch_offset_v_descale; + + // BLOCKSCALE: one q_descale per tile (kBlockScaleSizeQ from traits, typically 128) + const index_t idx = i_m0 / PipelineProblem::kBlockScaleSizeQ; + float q_descale = q_descale_ptr[idx]; + + return SageAttnPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + scales(q_descale), // s_acc_element_func + identity{}, // p_compute_element_func - No scaling (done in exp2) + identity{}, // o_acc_element_func - No dequant (canceled by rowsum) + mask, + position_encoding, + kargs.scale_s, + variant, + 
variant_params, + block_indices, + smem_ptr, + nullptr, + k_descale_ptr, + v_descale_ptr); + } + else if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP || + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD) + { + using SageShape = typename SageAttnPipeline::BlockSageAttnShape; + constexpr index_t kWarpSize = get_warp_size(); + constexpr index_t kGemm0MPerWarp = SageShape::Gemm0WarpTile::at(number<0>{}); + constexpr index_t kNumWarps = SageShape::NumWarps; + + static_assert(kWarpSize == 64, "kWarpSize must be 64"); + static_assert(SageAttnPipeline::kM0 == kGemm0MPerWarp * kNumWarps, + "PERWARP/PERTHREAD q_descale: kM0 must equal " + "Gemm0 MPerWarp * NumWarps"); + static_assert(kWarpSize % kGemm0MPerWarp == 0, + "PERWARP/PERTHREAD: warp_size must be divisible by Gemm0 MPerWarp"); + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP) + { + static_assert(kGemm0MPerWarp == PipelineProblem::kBlockScaleSizeQ, + "PERWARP: Gemm0 MPerWarp must match kBlockScaleSizeQ " + "(one q scale per warp with shared q_descale scalar)"); + } + static_assert(SageShape::Gemm0WarpTile::at(number<0>{}) == 32 && + SageShape::Gemm0WarpTile::at(number<1>{}) == 32, + "PERWARP/PERTHREAD q_descale assumes Gemm0 warp tile MxN is 32x32"); + + const float* q_descale_ptr = + reinterpret_cast(kargs.q_descale_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q_descale + + batch_offset_q_descale; + const float* k_descale_ptr = + reinterpret_cast(kargs.k_descale_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * + kargs.nhead_stride_k_descale + + batch_offset_k_descale; + const float* v_descale_ptr = + reinterpret_cast(kargs.v_descale_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * + kargs.nhead_stride_v_descale + + batch_offset_v_descale; + + // Q row from tile origin i_m0 + wave M strip + lane; clamp q_scale_idx to the + // last scale block for this seqlen_q (e.g. 
seqlen_q=129, S=32: ceil(129/32)=5 + // blocks, indices 0..4; row 128 -> 128/32=4; padding -> min(raw_idx, max_idx)). + constexpr index_t kBlockSq = PipelineProblem::kBlockScaleSizeQ; + const index_t wave_id = __builtin_amdgcn_readfirstlane(threadIdx.x / kWarpSize); + const index_t q_row_raw = + i_m0 + wave_id * kGemm0MPerWarp + threadIdx.x % kGemm0MPerWarp; + const index_t q_scale_idx_raw = ck_tile::integer_divide_floor(q_row_raw, kBlockSq); + const index_t max_q_scale_idx = + kargs.seqlen_q > 0 ? ck_tile::integer_divide_ceil(kargs.seqlen_q, kBlockSq) - 1 + : 0; + const index_t q_scale_idx = + q_scale_idx_raw < max_q_scale_idx ? q_scale_idx_raw : max_q_scale_idx; + const float q_descale = q_descale_ptr[q_scale_idx]; + + return SageAttnPipeline{}(q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + identity{}, // s_acc_element_func (K/V scales in pipeline) + identity{}, // p_compute_element_func + identity{}, // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + nullptr, + k_descale_ptr, + v_descale_ptr, + q_descale); + } + else + { + return SageAttnPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + mask, + position_encoding, + variant_params.sm_scale, + variant, + variant_params, + block_indices, + smem_ptr); + } + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = make_tile_window( + o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); + } +}; + +} // namespace ck_tile diff --git 
a/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_enum.hpp b/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_enum.hpp new file mode 100644 index 0000000000..4cf54cabb4 --- /dev/null +++ b/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_enum.hpp @@ -0,0 +1,29 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockSageAttnPipelineEnum +{ + QRKSVS = 0, + QRKSVS_ASYNC, +}; + +template +struct BlockSageAttnPipelineEnumToStr; + +template <> +struct BlockSageAttnPipelineEnumToStr +{ + static constexpr const char* name = "qr"; +}; +template <> +struct BlockSageAttnPipelineEnumToStr +{ + static constexpr const char* name = "qr_async"; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_problem.hpp b/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_problem.hpp new file mode 100644 index 0000000000..67d70f501f --- /dev/null +++ b/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_problem.hpp @@ -0,0 +1,60 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
+#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
+
+namespace ck_tile {
+
+// Compile-time "problem" descriptor consumed by the SageAttention pipelines:
+// bundles the element types of every operand, the block tile shape, the
+// attention variant / mask types, and the trait flags into a single type.
+template 
+struct BlockSageAttnPipelineProblem
+{
+    using QDataType = remove_cvref_t;
+    using KDataType = remove_cvref_t;
+    using VDataType = remove_cvref_t;
+    // accumulator type of the Q*K^T gemm
+    using SaccDataType = remove_cvref_t;
+    // compute type used for the softmax max/sum bookkeeping
+    using SMPLComputeDataType = remove_cvref_t;
+    using PDataType = remove_cvref_t;
+    // accumulator / output element types of the P*V gemm
+    using OaccDataType = remove_cvref_t;
+    using ODataType = remove_cvref_t;
+    using BlockSageAttnShape = remove_cvref_t;
+    using AttentionVariant = remove_cvref_t;
+    using AttnMask = remove_cvref_t;
+    using Traits = remove_cvref_t;
+
+    static constexpr index_t kNumGemm0Warps = BlockSageAttnShape::NumGemm0Warps;
+    static constexpr index_t kNumGemm1Warps = BlockSageAttnShape::NumGemm1Warps;
+    // one workgroup = NumWarps wavefronts
+    static constexpr index_t kBlockSize = BlockSageAttnShape::NumWarps * get_warp_size();
+
+    static constexpr bool kIsGroupMode = kIsGroupMode_;
+
+    // attributes from traits
+    static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
+    static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
+    static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
+    static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
+    static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
+    // quantization scale granularity (presumably per-tensor / per-block /
+    // per-warp / per-thread as in the example's -qscale flag — confirm)
+    static constexpr auto QScaleEnum = Traits::QScaleEnum;
+    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
+
+    /// Must match host scale tensor layout (same values as TileSageAttnTraits for Sage kernels).
+    static constexpr index_t kBlockScaleSizeQ = Traits::kBlockScaleSizeQ;
+    static constexpr index_t kBlockScaleSizeK = Traits::kBlockScaleSizeK;
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs.hpp
new file mode 100644
index 0000000000..75eaf22295
--- /dev/null
+++ b/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs.hpp
@@ -0,0 +1,861 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
+#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
+#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_default_policy.hpp"
+#include "ck_tile/ops/reduce/block/block_reduce.hpp"
+
+namespace ck_tile {
+
+// This pipeline is qkv all located in LDS
+template 
+struct BlockSageAttentionPipelineQRKSVS
+{
+    using Problem = remove_cvref_t;
+    using Policy = remove_cvref_t;
+    using QDataType = remove_cvref_t;
+    using QGemmDataType = SageAttnQKGemmQDataType;
+    using KDataType = remove_cvref_t;
+    using KLdsDataType = SageAttnQKGemmKDataType;
+    using VDataType = remove_cvref_t;
+    using SaccDataType = remove_cvref_t;
+    using SMPLComputeDataType = remove_cvref_t;
+    using PDataType = remove_cvref_t;
+    // fp16/bf16 example configs use P=V=fp16/bf16 (qscale=no). Quantized Sage paths use fp8 P/V;
+    // FP8 softmax shift, v_descale, and PV-gemm LDS layout assume fp8_t for those cases.
+    static_assert(std::is_same_v,
+                  "SageAttention pipeline requires PDataType == VDataType for the PV gemm");
+    static_assert(std::is_same_v || std::is_same_v ||
+                      std::is_same_v,
+                  "SageAttention pipeline requires PDataType to be fp16_t, bf16_t or fp8_t");
+    static_assert(std::is_same_v || std::is_same_v ||
+                      std::is_same_v,
+                  "SageAttention pipeline requires VDataType to be fp16_t, bf16_t or fp8_t");
+    using OaccDataType = remove_cvref_t;
+    using ODataType = remove_cvref_t;
+    using AttentionVariant = remove_cvref_t;
+    using AttnMask = remove_cvref_t;
+
+    using BlockSageAttnShape = remove_cvref_t;
+    using VLayout = remove_cvref_t;
+    static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
+    static_assert(kQLoadOnce == Policy::QLoadOnce);
+
+    static constexpr index_t kBlockSize = Problem::kBlockSize;
+
+    // block tile sizes: M0/N0/K0 for the Q*K^T gemm, N1/K1 for the P*V gemm
+    static constexpr index_t kM0 = BlockSageAttnShape::kM0;
+    static constexpr index_t kN0 = BlockSageAttnShape::kN0;
+    static constexpr index_t kK0 = BlockSageAttnShape::kK0;
+    static constexpr index_t kN1 = BlockSageAttnShape::kN1;
+    static constexpr index_t kK1 = BlockSageAttnShape::kK1;
+    static constexpr index_t kQKHeaddim = BlockSageAttnShape::kQKHeaddim;
+    static constexpr index_t kSubQKHeaddim = BlockSageAttnShape::kSubQKHeaddim;
+
+    static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
+
+    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
+    static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
+    static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
+    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
+    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
+
+    // instruction-group masks for __builtin_amdgcn_sched_group_barrier
+    static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
+    static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
+
+    // FP8 softmax shift constants to map softmax output into representable FP8 range
+    // OCP E4M3 FP8: max exponent = 8, max value 448 (2^8 * 1.75)
+    // Use shift=8.0
so exp2(s - m - 8) maps softmax to [0, 2^8] range + // FNUZ E4M3 FP8: max exponent = 7, max value ~120 (2^7 * 1.875) + // Use shift=7.0 so exp2(s - m - 7) maps softmax to [0, 2^7] range + static constexpr float OCP_FP8_SHIFT = 8.0f; + static constexpr float FNUZ_FP8_SHIFT = 7.0f; + static constexpr auto QScaleEnum = Problem::QScaleEnum; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kQKHeaddim <= 32) + { + return 2; + } + else if constexpr(kQKHeaddim <= 64) + { + return 3; + } + else if constexpr(kQKHeaddim <= 128) + { + return 2; + } + else if constexpr(kQKHeaddim <= 256) + { + return 1; + } + else + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const 
SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + AttnMask mask, + PositionEncoding /*position_encoding*/, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + [[maybe_unused]] const float* q_descale_ptr = nullptr, + const float* k_descale_ptr = nullptr, + const float* v_descale_ptr = nullptr, + [[maybe_unused]] float q_descale_value = 1.0f) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // K tile in LDS + KLdsDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window_reg = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + + auto q = 
load_tile(q_dram_window_reg);
+
+        using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
+
+        // reduction function for softmax
+        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
+        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
+
+        // infer Sacc, S, P, M, L, Oacc type
+        using SBlockTileType =
+            std::conditional_t,
+                          SaccBlockTileType,
+                          decltype(cast_tile(SaccBlockTileType{}))>;
+
+        // per-row running max (M) and running sum (L) tiles for online softmax
+        using MLBlockTileType = decltype(block_tile_reduce(
+            SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
+
+        using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
+
+        // init Oacc, M, L
+        auto o_acc = OaccBlockTileType{};
+        auto m = MLBlockTileType{};
+        auto l = MLBlockTileType{};
+
+        clear_tile(o_acc);
+        {
+            set_tile(m, -numeric::infinity());
+            clear_tile(l);
+        }
+        const auto q_origin = q_dram_block_window_tmp.get_window_origin();
+
+        // ask the mask which [start, end) K-column tile range this Q row-block touches
+        const auto tile_range_result = [&mask, &q_origin]() {
+            auto [start, end] =
+                mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{});
+            return ck_tile::make_tuple(start, end);
+        }();
+        const auto seqlen_k_start = tile_range_result.get(ck_tile::number<0>{});
+        const auto seqlen_k_end = tile_range_result.get(ck_tile::number<1>{});
+        // clamp the first K/V tile origin to a non-negative offset
+        const auto kv_load_start = seqlen_k_start > 0 ? seqlen_k_start : 0;
+
+        // number of kN0-wide K/V tiles this workgroup iterates over
+        const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
+
+        // check early exit if no work to do
+        if constexpr(AttnMask::IsMasking || kPadSeqLenK)
+        {
+            if(num_total_loop <= 0)
+            {
+                // Note: o_acc is already cleared above, return it as-is.
+                // Note: q was loaded but never fenced; it is unused on this path, ignore it.
+ return o_acc; + } + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {kv_load_start, 0}); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, kv_load_start}, + Policy::template MakeVDramTileDistribution()); + + auto q_tile = [&]() { + if constexpr(std::is_same_v) + return tile_elementwise_in(q_element_func, q); + else + { + auto q_tile_tmp = make_static_distributed_tensor( + Policy::template MakeQRegTileDistribution()); + constexpr index_t kPackedSize = numeric_traits::PackedSize; + constexpr index_t kUnaryOpSize = 8; + static_assert(std::is_same_v); + static_assert(kPackedSize == 2); + static_assert(decltype(q_tile_tmp)::get_thread_buffer_size() == + decltype(q)::get_thread_buffer_size() * kPackedSize); + static_assert(decltype(q_tile_tmp)::get_thread_buffer_size() % kUnaryOpSize == 0); + + using RawQType = typename QDataType::type; + using SrcVectorType = ext_vector_t; + using DstVectorType = ext_vector_t; + constexpr index_t kVecSize = + decltype(q_tile_tmp)::get_thread_buffer_size() / kUnaryOpSize; + static_assert(decltype(q)::get_thread_buffer_size() == + kVecSize * (kUnaryOpSize / kPackedSize)); + + const element_wise::PassThroughPack8 pass_through_pack8{}; + static_for<0, kVecSize, 1>{}([&](auto i) { + pass_through_pack8( + q_tile_tmp.get_thread_buffer().template get_as()(i), + q.get_thread_buffer().template get_as()[i]); + }); + return q_tile_tmp; + } + }(); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + // Use compile-time conditional for group barrier sequence + // (No runtime lambda selection) + auto schedule_gemm0 = [] { + using BlockGemm0 = remove_cvref_t; + constexpr auto WarpGemmConfig = + BlockGemm0::Policy::template GetWarpGemmMWarpNWarp(); + using 
WarpGemm0 = remove_cvref_t())>; + constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>(); + constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>(); + constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM; + constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN; + constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK; + constexpr index_t NumMfmaInsts = (kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) * + (kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp); + if constexpr(get_warp_size() == 64 && kQKHeaddim == 256) + { + static_assert(NumMfmaInsts % 8 == 0); + static_for<0, NumMfmaInsts / 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(DS_READ, 2, 0); // DS read + __builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(MFMA, 4, 0); // MFMA + }); + } + }; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); + constexpr index_t kGemm0MPerWarp = BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}); + static_assert(get_warp_size() % kGemm0MPerWarp == 0); + constexpr index_t kWarpSz = get_warp_size(); + // sub_warp_idx is 0 or 1, indicating which half of the warp (used for PERTHREAD K-scale + // indexing) + index_t sub_warp_idx = (threadIdx.x % kWarpSz) / kGemm0MPerWarp; + // main loop + do + { + float k_descale = 1.0f; + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE) + { + const index_t kv_idx = + (seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK; + k_descale = k_descale_ptr[kv_idx]; + } + constexpr index_t kNumKScalesPW = + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP + ? kN0 / Problem::kBlockScaleSizeK + : 1; + constexpr index_t kNumKScalesPT = + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD + ? 
kN0 / Problem::kBlockScaleSizeK / 2 + : 1; + float k_scales_perwarp[kNumKScalesPW > 0 ? kNumKScalesPW : 1] = {}; + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP) + { + const index_t kv_idx = + (seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK; +#pragma unroll + for(index_t i = 0; i < kNumKScalesPW; i++) + k_scales_perwarp[i] = k_descale_ptr[kv_idx + i]; + } + float k_scales_reg[kNumKScalesPT > 0 ? kNumKScalesPT : 1] = {}; + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD) + { + const index_t k_global_start = seqlen_k_start + i_total_loops * kN0; + const index_t k_scale_start_idx = k_global_start / Problem::kBlockScaleSizeK; +#pragma unroll + for(index_t i = 0; i < kNumKScalesPT; i++) + k_scales_reg[i] = k_descale_ptr[k_scale_start_idx + 2 * i + sub_warp_idx]; + } + + // STAGE 1, QK gemm + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + auto s_acc_gemm = SaccBlockTileType{}; + const auto store_k_block_tile_to_lds = [&](const auto& k_block_tile_) { + if constexpr(std::is_same_v) + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile_)); + else + { + auto k_block_tile_tmp = make_static_distributed_tensor( + k_dram_window.get_tile_distribution()); + using KBlockTileType = remove_cvref_t; + constexpr index_t kPackedSize = numeric_traits::PackedSize; + constexpr index_t kUnaryOpSize = 8; + static_assert(std::is_same_v); + static_assert(kPackedSize == 2); + static_assert(decltype(k_block_tile_tmp)::get_thread_buffer_size() == + KBlockTileType::get_thread_buffer_size() * kPackedSize); + static_assert( + decltype(k_block_tile_tmp)::get_thread_buffer_size() % kUnaryOpSize == 0); + + using RawKType = typename KDataType::type; + using SrcVectorType = ext_vector_t; + using DstVectorType 
= ext_vector_t; + constexpr index_t kVecSize = + decltype(k_block_tile_tmp)::get_thread_buffer_size() / kUnaryOpSize; + static_assert(KBlockTileType::get_thread_buffer_size() == + kVecSize * (kUnaryOpSize / kPackedSize)); + + const element_wise::PassThroughPack8 pass_through_pack8{}; + static_for<0, kVecSize, 1>{}([&](auto i) { + pass_through_pack8( + k_block_tile_tmp.get_thread_buffer().template get_as()( + i), + k_block_tile_.get_thread_buffer().template get_as()[i]); + }); + store_tile(k_lds_window, k_block_tile_tmp); + } + }; + auto k_block_tile = load_tile(k_dram_window); + { + move_tile_window(k_dram_window, {0, kK0}); + clear_tile(s_acc_gemm); // initialize C + store_k_block_tile_to_lds(k_block_tile); + k_block_tile = load_tile(k_dram_window); + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(s_acc_gemm, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_lds_window); + schedule_gemm0(); + block_sync_lds(); + move_tile_window(k_dram_window, {0, kK0}); + + store_k_block_tile_to_lds(k_block_tile); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc_gemm, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{}), + k_lds_window); + schedule_gemm0(); + block_sync_lds(); + + store_k_block_tile_to_lds(k_block_tile); + block_sync_lds(); + + gemm_0(s_acc_gemm, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_lds_window); + schedule_gemm0(); + } + + // Convert GEMM output to SaccDataType for softmax (if needed) + auto s_acc = [&]() { + using GemmDataType = typename decltype(s_acc_gemm)::DataType; + if constexpr(std::is_same_v) + { + return s_acc_gemm; // No conversion needed (e.g., float -> float) + } + else + { + return cast_tile(s_acc_gemm); // Convert (e.g., 
int32 -> float) + } + }(); + + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD) + { + // PERTHREAD: kBlockScaleSizeK=16 + // The s_acc tile distribution is determined by + // WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution, which guarantees + // each thread processes exactly 16 consecutive elements in the K dimension. This + // distribution is inherent to the MFMA 32x32x16 instruction with kKIter=2 and + // TransposedC layout. Therefore, col_offset >> 4 correctly maps thread-local + // elements to K scale indices. + static_assert(Problem::kBlockScaleSizeK == 16, + "PERTHREAD: kBlockScaleSizeK must be 16"); + + // Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB + + // TransposedC This ensures the distribution has 16 consecutive K elements per + // thread + using BlockGemm0 = remove_cvref_t; + constexpr auto WarpGemmCfg = + BlockGemm0::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm0Type = remove_cvref_t())>; + using ExpectedWarpGemmI8 = + WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>; + using ExpectedWarpGemmFp8 = + WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>; + static_assert( + std::is_same_v || + std::is_same_v, + "PERTHREAD requires " + "WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for " + "16 consecutive K elements"); + + constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans(); + float combined_scales_reg[kNumKScalesPT] = {}; +#pragma unroll + for(index_t i = 0; i < kNumKScalesPT; i++) + combined_scales_reg[i] = q_descale_value * k_scales_reg[i]; + sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) { + index_t col_offset = 0; + sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // col_offset counts columns in distributed view + // Divide by 16 (>>4) to map to K scale groups (kBlockScaleSizeK=16) + const index_t scale_idx = col_offset >> 4; + 
s_acc(i_j_idx) *= combined_scales_reg[scale_idx]; + col_offset++; + }); + }); + } + else if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP) + { + // PERWARP: kBlockScaleSizeK=64, i.e., 64 global K elements share one scale + // Distribution: thread_i and thread_(i+32) interleave to cover K dimension + // In each thread's view, every 32 idx1 steps correspond to 64 global K elements + + // Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB + + // TransposedC This ensures each thread has 16 consecutive elements, and warp-level + // grouping is correct + using BlockGemm0 = remove_cvref_t; + constexpr auto WarpGemmCfg = + BlockGemm0::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm0Type = remove_cvref_t())>; + using ExpectedWarpGemmI8 = + WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>; + using ExpectedWarpGemmFp8 = + WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>; + static_assert( + std::is_same_v || + std::is_same_v, + "PERWARP requires " + "WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for " + "correct K element grouping"); + + constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans(); + float combined_scales_reg[kNumKScalesPW] = {}; +#pragma unroll + for(index_t i = 0; i < kNumKScalesPW; i++) + combined_scales_reg[i] = q_descale_value * k_scales_perwarp[i]; + sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) { + index_t col_offset = 0; + sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // col_offset counts columns in distributed view + // When N0=64: each thread has 32 elements; when N0=128: each thread has 64 + // elements Divide by 32 (>>5) to map to K scale groups + // (kBlockScaleSizeK=64) + const index_t scale_idx = col_offset >> 5; + s_acc(i_j_idx) *= combined_scales_reg[scale_idx]; + col_offset++; + }); + }); + } + else + { + // dequant: combine q_descale (in 
s_acc_element_func) with k_descale + auto s_acc_element_func_ = [&]() { + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE) + { + return s_acc_element_func * k_descale; + } + else + return s_acc_element_func; + }(); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); + } + // STAGE 2, scale_s, mask, softmax + if constexpr(kPadSeqLenK || AttnMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + auto apply_mask = [&](auto&& mask_func) { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !mask_func(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + }; + + apply_mask([&](auto&&... args) { + return variant.LogitsMask(std::forward(args)...); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + if constexpr(AttnMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? 
type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + // For BLOCKSCALE: precompute (m - shift) once per row + // exp2(s - m + shift) = exp2(s - (m - shift)); pertensor path uses scale_s on s,m + auto validated_m = get_validated_m(m[i_idx]); + auto row_max = scale_s * validated_m; + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE || + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP || + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD) + { +#if CK_TILE_USE_OCP_FP8 + validated_m -= OCP_FP8_SHIFT; // OCP FP8 softmax shift + row_max -= OCP_FP8_SHIFT; // for else branch +#else + validated_m -= FNUZ_FP8_SHIFT; + row_max -= FNUZ_FP8_SHIFT; +#endif + } + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + const auto m_new = get_validated_m(m[i_idx]); + auto row_max = scale_s * m_new; + const auto tmp = exp2(scale_s * m_old[i_idx] - row_max); + // Update l and rescale o_acc + l(i_idx) = tmp * l(i_idx) + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + 
shuffle_tile(v_shuffle_tmp, v_prefetch); + store_tile( + v_lds_window, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch + } + move_tile_window(v_dram_window, {0, kK1}); + + const auto p = + cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + + // STAGE 3, KV gemm + // For BLOCKSCALE, PERWARP, and PERTHREAD modes, accumulate directly to o_acc + // Apply per-channel v_descale after the loop (before normalization) + + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v); + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v)); // store next v + } + move_tile_window(v_dram_window, {0, kK1}); + }); + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + } + + } while(++i_total_loops < num_total_loop); + + // Apply per-channel v_descale for BLOCKSCALE, PERWARP, and PERTHREAD modes (after loop, + // before normalization) + if constexpr(Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE || + Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP || + Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD) + { + // Ensure all V LDS reads from the 
last gemm_1 complete before reusing K/V LDS space + block_sync_lds(); + + // V is col-major, each column (channel) has its own scale + // o_acc shape: [M0, N1] where N1 is hdim_v + // v_descale_ptr points to per-channel scales [hdim_v] + // Load v_descale to LDS for better memory access pattern + // Reuse K/V LDS space (they're no longer needed) + auto v_descale_lds = reinterpret_cast(smem_ptr); + + // Cooperatively load v_descale to LDS + const index_t num_threads = kBlockSize; + for(index_t i = threadIdx.x; i < kN1; i += num_threads) + { + v_descale_lds[i] = v_descale_ptr[i]; + } + block_sync_lds(); + + constexpr auto o_tmp_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_tmp_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(o_tmp_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // Get the global tile index for the N1 (channel) dimension + const auto tile_idx = get_x_indices_from_distributed_indices( + o_acc.get_tile_distribution(), i_j_idx); + const index_t channel_idx = tile_idx.at(number<1>{}); + const float v_scale = v_descale_lds[channel_idx]; + o_acc(i_j_idx) *= v_scale; + }); + }); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + // When masking, the denominator can be zero; guard the normalization + // so we do not divide by zero after a fully masked row. + if constexpr(AttnMask::IsMasking) + { + return l[i_idx] == 0.f ? 
0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + AttnMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + [[maybe_unused]] const float* q_descale_ptr = nullptr, + const float* k_descale_ptr = nullptr, + const float* v_descale_ptr = nullptr, + [[maybe_unused]] float q_descale_value = 1.0f) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + position_encoding, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + q_descale_ptr, + k_descale_ptr, + v_descale_ptr, + q_descale_value); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp new file mode 100644 index 0000000000..a64cc85643 --- /dev/null +++ b/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp @@ -0,0 +1,873 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp" +#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) +template +struct BlockSageAttentionPipelineQRKSVSAsync +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + // fp16/bf16 example configs use P=V=fp16/bf16 (qscale=no). Quantized Sage paths use fp8 P/V; + // FP8 softmax shift, v_descale, and PV-gemm LDS layout assume fp8_t for those cases. 
+ static_assert(std::is_same_v, + "SageAttention pipeline requires PDataType == VDataType for the PV gemm"); + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v, + "SageAttention pipeline requires PDataType = fp8_t"); + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v, + "SageAttention pipeline requires VDataType = fp8_t"); + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; + using AttnMask = remove_cvref_t; + + using BlockSageAttnShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockSageAttnShape::kM0; + static constexpr index_t kN0 = BlockSageAttnShape::kN0; + static constexpr index_t kK0 = BlockSageAttnShape::kK0; + static constexpr index_t kN1 = BlockSageAttnShape::kN1; + static constexpr index_t kK1 = BlockSageAttnShape::kK1; + static constexpr index_t kQKHeaddim = BlockSageAttnShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockSageAttnShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) + // only need special care about seq_k padding (oob need set -INF of p instead of zero) + static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + Problem::kPadHeadDimV == true); + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr auto 
QScaleEnum = Problem::QScaleEnum; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); + + // FP8 softmax shift constants to map softmax output into representable FP8 range + // OCP E4M3 FP8: max exponent = 8, max value ~240 (2^8 * 1.875) + // Use shift=8.0 so exp2(s - m - 8) maps softmax to [0, 2^8] range + // FNUZ E4M3 FP8: max exponent = 7, max value ~120 (2^7 * 1.875) + // Use shift=7.0 so exp2(s - m - 7) maps softmax to [0, 2^7] range + static constexpr float OCP_FP8_SHIFT = 8.0f; + static constexpr float FNUZ_FP8_SHIFT = 7.0f; + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kQKHeaddim <= 32) + { + return 2; + } + else if constexpr(kQKHeaddim <= 64) + { + return 3; + } + else if constexpr(kQKHeaddim <= 128) + { + return 2; + } + else if constexpr(kQKHeaddim <= 192) + { + return 2; + } + else if constexpr(kQKHeaddim <= 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_async"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& 
/*k_element_func*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + AttnMask mask, + PositionEncoding /*position_encoding*/, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + [[maybe_unused]] const float* q_descale_ptr = nullptr, + const float* k_descale_ptr = nullptr, + const float* v_descale_ptr = nullptr, + [[maybe_unused]] float q_descale_value = 1.0f) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + + // K tile in LDS + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).get_lengths(), + {0, 0, 0}); + }, + number{}); + + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template MakeKLdsLoadBlockDescriptor().get_lengths(), + {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, 
Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + q_dram_window.init_raw(); + + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = decltype(load_tile(q_dram_window)){}; + // TODO: start from rocm-6.2, compiler will have problem if manually set clear of q. + // however, q would be cleared in the constructor of static distributed tensor + // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch + load_tile_raw(q, q_dram_window); + __builtin_amdgcn_sched_barrier(0); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = + std::conditional_t, + SaccBlockTileType, + decltype(cast_tile(SaccBlockTileType{}))>; + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } + __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.get_window_origin(); + const auto tile_range_result = [&mask, &q_origin]() { + auto [start, end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + return 
ck_tile::make_tuple(start, end); + }(); + const auto seqlen_k_start = tile_range_result.get(ck_tile::number<0>{}); + const auto seqlen_k_end = tile_range_result.get(ck_tile::number<1>{}); + const auto kv_load_start = seqlen_k_start > 0 ? seqlen_k_start : 0; + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if no work to do + if constexpr(AttnMask::IsMasking || kPadSeqLenK) + { + if(num_total_loop <= 0) + { + buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) + // otherwise will have compute error(maybe compiler bug?) + + // Note: here occ are all cleard, return it + return o_acc; + } + __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {kv_load_start, 0}); + + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + k_dram_window.init_raw(); + constexpr auto k_oob_ck = bool_constant{}; + constexpr auto k_pre_np = bool_constant{}; + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, kv_load_start}, + Policy::template MakeVDramTileDistribution()); + + // prefetch K tile + async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + + buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer()); + (void)q_element_func; // ??? 
rocm-6.x if use q element func will have scratch on hdim=64/32 + // auto q_tile = q; // tile_elementwise_in(q_element_func, q); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + constexpr index_t kGemm0MPerWarp = BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}); + static_assert(kGemm0MPerWarp == 32); + constexpr index_t kWarpSz = get_warp_size(); + // sub_warp_idx is 0 or 1, indicating which half of the warp (used for PERTHREAD K-scale + // indexing) + index_t sub_warp_idx = (threadIdx.x % kWarpSz) / kGemm0MPerWarp; + // main loop + do + { + float k_descale = 1.0f; + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE) + { + const index_t kv_idx = + (seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK; + k_descale = k_descale_ptr[kv_idx]; + } + constexpr index_t kNumKScalesPW = + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP + ? kN0 / Problem::kBlockScaleSizeK + : 1; + constexpr index_t kNumKScalesPT = + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD + ? kN0 / Problem::kBlockScaleSizeK / 2 + : 1; + float k_scales_perwarp[kNumKScalesPW > 0 ? kNumKScalesPW : 1] = {}; + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP) + { + const index_t kv_idx = + (seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK; +#pragma unroll + for(index_t i = 0; i < kNumKScalesPW; i++) + k_scales_perwarp[i] = k_descale_ptr[kv_idx + i]; + } + float k_scales_reg[kNumKScalesPT > 0 ? 
kNumKScalesPT : 1] = {}; + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD) + { + const index_t k_global_start = seqlen_k_start + i_total_loops * kN0; + const index_t k_scale_start_idx = k_global_start / Problem::kBlockScaleSizeK; +#pragma unroll + for(index_t i = 0; i < kNumKScalesPT; i++) + k_scales_reg[i] = k_descale_ptr[k_scale_start_idx + 2 * i + sub_warp_idx]; + } + + // STAGE 1, QK gemm + auto s_acc_gemm = SaccBlockTileType{}; + clear_tile(s_acc_gemm); // initialize C + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + async_load_tile_raw(k_lds_store(number{})>{}), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.get_num_of_access()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + gemm_0(s_acc_gemm, + get_slice_tile( + q, sequence<0, i_k0 * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); + }); + } + + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); + + async_load_fence(); + __builtin_amdgcn_s_barrier(); + + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + __builtin_amdgcn_sched_barrier(0); + { // tail + gemm_0( + s_acc_gemm, + get_slice_tile( + q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); + } + __builtin_amdgcn_sched_barrier(1); + + // Convert GEMM output to SaccDataType for softmax (if needed) + auto s_acc = [&]() { + using GemmDataType = typename decltype(s_acc_gemm)::DataType; + if constexpr(std::is_same_v) + { + return s_acc_gemm; // No conversion needed 
(e.g., float -> float) + } + else + { + return cast_tile(s_acc_gemm); // Convert (e.g., int32 -> float) + } + }(); + + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD) + { + // PERTHREAD: kBlockScaleSizeK=16 + // The s_acc tile distribution is determined by + // WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution, which guarantees + // each thread processes exactly 16 consecutive elements in the K dimension. This + // distribution is inherent to the MFMA 32x32x16 instruction with kKIter=2 and + // TransposedC layout. Therefore, col_offset >> 4 correctly maps thread-local + // elements to K scale indices. + static_assert(Problem::kBlockScaleSizeK == 16, + "PERTHREAD: kBlockScaleSizeK must be 16"); + + // Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB + + // TransposedC This ensures the distribution has 16 consecutive K elements per + // thread + using BlockGemm0 = remove_cvref_t; + constexpr auto WarpGemmCfg = + BlockGemm0::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm0Type = remove_cvref_t())>; + using ExpectedWarpGemmI8 = + WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>; + using ExpectedWarpGemmFp8 = + WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>; + static_assert( + std::is_same_v || + std::is_same_v, + "PERTHREAD requires " + "WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for " + "16 consecutive K elements"); + + constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans(); + float combined_scales_reg[kNumKScalesPT] = {}; +#pragma unroll + for(index_t i = 0; i < kNumKScalesPT; i++) + combined_scales_reg[i] = q_descale_value * k_scales_reg[i]; + sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) { + index_t col_offset = 0; + sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // col_offset counts columns in distributed view + // Divide by 16 (>>4) to map 
to K scale groups (kBlockScaleSizeK=16) + const index_t scale_idx = col_offset >> 4; + s_acc(i_j_idx) *= combined_scales_reg[scale_idx]; + col_offset++; + }); + }); + } + else if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP) + { + // PERWARP: kBlockScaleSizeK=64, i.e., 64 global K elements share one scale + // Distribution: thread_i and thread_(i+32) interleave to cover K dimension + // In each thread's view, every 32 idx1 steps correspond to 64 global K elements + + // Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB + + // TransposedC This ensures each thread has 16 consecutive elements, and warp-level + // grouping is correct + using BlockGemm0 = remove_cvref_t; + constexpr auto WarpGemmCfg = + BlockGemm0::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm0Type = remove_cvref_t())>; + using ExpectedWarpGemmI8 = + WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>; + using ExpectedWarpGemmFp8 = + WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>; + static_assert( + std::is_same_v || + std::is_same_v, + "PERWARP requires " + "WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for " + "correct K element grouping"); + + constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans(); + float combined_scales_reg[kNumKScalesPW] = {}; +#pragma unroll + for(index_t i = 0; i < kNumKScalesPW; i++) + combined_scales_reg[i] = q_descale_value * k_scales_perwarp[i]; + sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) { + index_t col_offset = 0; + sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // col_offset counts columns in distributed view + // When N0=64: each thread has 32 elements; when N0=128: each thread has 64 + // elements Divide by 32 (>>5) to map to K scale groups + // (kBlockScaleSizeK=64) + const index_t scale_idx = col_offset >> 5; + s_acc(i_j_idx) *= 
combined_scales_reg[scale_idx]; + col_offset++; + }); + }); + } + else + { + // dequant: combine q_descale (in s_acc_element_func) with k_descale + auto s_acc_element_func_ = [&]() { + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE) + { + return s_acc_element_func * k_descale; + } + else + return s_acc_element_func; + }(); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); + } + // STAGE 2, scale_s, mask, softmax + // logits_soft_cap is always disabled + if constexpr(kPadSeqLenK || AttnMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + + if(need_perpixel_check) + { + auto apply_mask = [&](auto&& mask_func) { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !mask_func(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + }; + + apply_mask([&](auto&&... 
args) { + return variant.LogitsMask(std::forward(args)...); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // Ensure gemm_0's LDS reads (K tile) from all threads are completed before V store + // Only needed when K tail and V use the same LDS buffer + if constexpr(LdsSeq.at(number{}) == LdsSeq.at(number{})) + { + __builtin_amdgcn_s_barrier(); + } + // store & prefetch next v, after the max reduction + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... 
+ v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + } + __builtin_amdgcn_sched_barrier(0); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + if constexpr(AttnMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + // For BLOCKSCALE: precompute (m - shift) once per row + // exp2(s - m + shift) = exp2(s - (m - shift)) + // else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift)) + auto validated_m = get_validated_m(m[i_idx]); + auto row_max = scale_s * validated_m; + if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE || + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP || + QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD) + { +#if CK_TILE_USE_OCP_FP8 + validated_m -= OCP_FP8_SHIFT; // OCP FP8 softmax shift + row_max -= OCP_FP8_SHIFT; // for else branch +#else + validated_m -= FNUZ_FP8_SHIFT; + row_max -= FNUZ_FP8_SHIFT; +#endif + } + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // logits_soft_cap is always disabled + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + const auto m_new = get_validated_m(m[i_idx]); + auto row_max = scale_s * m_new; + const auto tmp = exp2(scale_s * m_old[i_idx] - row_max); + // Update l and rescale o_acc + 
l(i_idx) = tmp * l(i_idx) + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + const auto p = [&]() { +#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN + // For fp32 to fp16, + // impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue, + // since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero. + return cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); +#else + if constexpr(std::is_same_v) + return impl::cast_tile_pkrtz_fp16_fp32( + tile_elementwise_in(p_compute_element_func, p_compute)); + else + return cast_tile( + tile_elementwise_in(p_compute_element_func, p_compute)); +#endif + }(); + + // STAGE 3, KV gemm + // For BLOCKSCALE, PERWARP, and PERTHREAD modes, accumulate directly to o_acc + // Apply per-channel v_descale after the loop (before normalization) + + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + } + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + 
store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store next v_buf + } + if constexpr(i_k1 < k1_loops - 1) + move_tile_window(v_dram_window, {0, kK1}); + }); + } + i_total_loops++; + if(i_total_loops < num_total_loop) + { + move_tile_window(k_dram_block_window, {kN0, 0}); + + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + + if constexpr(k1_loops >= 2 && + LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + } + // tail + { + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + } + + } while(i_total_loops < num_total_loop); + + // Apply per-channel v_descale for BLOCKSCALE, PERWARP, and PERTHREAD modes (after loop, + // before normalization) + if constexpr(Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE || + Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP || + Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD) + { + // Ensure all V LDS reads from the last gemm_1 complete before reusing K/V LDS space + block_sync_lds(); + + // V is col-major, each column (channel) has its own scale + // o_acc shape: [M0, N1] where N1 is hdim_v + // v_descale_ptr points to per-channel scales [hdim_v] + // Load v_descale to LDS for better memory access pattern + // Reuse K/V LDS space (they're no longer needed) + auto v_descale_lds = reinterpret_cast(smem_ptr); + + // Cooperatively load v_descale to LDS + const index_t num_threads = kBlockSize; + for(index_t i = threadIdx.x; i < kN1; i += num_threads) + { + v_descale_lds[i] = v_descale_ptr[i]; + } + block_sync_lds(); + + 
constexpr auto o_tmp_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_tmp_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(o_tmp_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // Get the global tile index for the N1 (channel) dimension + const auto tile_idx = get_x_indices_from_distributed_indices( + o_acc.get_tile_distribution(), i_j_idx); + const index_t channel_idx = tile_idx.at(number<1>{}); + const float v_scale = v_descale_lds[channel_idx]; + o_acc(i_j_idx) *= v_scale; + }); + }); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(AttnMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + AttnMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + const float* q_descale_ptr = nullptr, + const float* k_descale_ptr = nullptr, + const float* v_descale_ptr = nullptr, + [[maybe_unused]] float q_descale_value = 1.0f) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + position_encoding, + 
scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + q_descale_ptr, + k_descale_ptr, + v_descale_ptr, + q_descale_value); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async_default_policy.hpp new file mode 100644 index 0000000000..1b7a3dae79 --- /dev/null +++ b/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async_default_policy.hpp @@ -0,0 +1,18 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +using BlockSageAttentionPipelineQRKSVSAsyncDefaultPolicy = + BlockSageAttnPipelineQRKSVSCustomPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_custom_policy.hpp b/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_custom_policy.hpp new file mode 100644 index 0000000000..49e03c8c03 --- /dev/null +++ b/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_custom_policy.hpp @@ -0,0 +1,857 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST_DEVICE static constexpr index_t GetPackedSize() +{ + return numeric_traits>::PackedSize; +} + +template +CK_TILE_HOST_DEVICE static constexpr index_t GetLogicalVectorSize(index_t bytes) +{ + return (bytes / sizeof(remove_cvref_t)) * GetPackedSize(); +} + +template +using SageAttnQKGemmQDataType = + std::conditional_t>, + fp8_t, + remove_cvref_t>; + +template +using SageAttnQKGemmKDataType = + std::conditional_t>, + fp8_t, + remove_cvref_t>; + +template +struct BlockSageAttnPipelineQRCustomPolicy; + +template <> +struct BlockSageAttnPipelineQRCustomPolicy +{ + static constexpr bool QLoadOnce = true; + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() + { + return 0; + } + + // TODO: GetAlignment*() currently didn't consider if need padding or not + // so in pipeline still need check padding requirement + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + constexpr index_t MaxVectorSize = GetLogicalVectorSize(16); + + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane); + } + 
+ template + CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + return BlockGemm::template MakeABlockTileDistribution< + Problem::BlockSageAttnShape::kM0, + Problem::BlockSageAttnShape::kSubQKHeaddim>(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using QKGemmQDataType = SageAttnQKGemmQDataType; + using QKGemmKDataType = SageAttnQKGemmKDataType; + // int8 MFMA accumulates to int32, but SaccDataType is float for softmax + using GemmAccDataType = + std::conditional_t<(std::is_same_v || + std::is_same_v) && + (std::is_same_v || + std::is_same_v), + int32_t, + typename Problem::SaccDataType>; + + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockSageAttnShape::Gemm0BlockWarps, + typename Problem::BlockSageAttnShape::Gemm0WarpTile>>; + + constexpr auto warp_gemm = []() { + if constexpr(get_warp_size() == 64 && std::is_same_v && + std::is_same_v && + std::is_same_v) + { + static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32); + static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}) == 32); + static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}) == 32); + + // TODO: hard coded here. 
Otherwise, it produces incorrect results + constexpr index_t swizzle_factor = 4; + return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution< + swizzle_factor>{}; + } + else if constexpr(get_warp_size() == 64 && + (std::is_same_v || + std::is_same_v) && + (std::is_same_v || + std::is_same_v)) + { + static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32); + static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}) == 32); + static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}) == 32); + + // Use special int8 MFMA with K iteration (similar to FP8) + constexpr index_t swizzle_factor = 4; + return WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution< + swizzle_factor>{}; + } + else + { + constexpr bool SwizzleA = + Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32; + return WarpGemmDispatcher< + QKGemmQDataType, + QKGemmKDataType, + GemmAccDataType, + Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}), + Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}), + true, // TransposeC + SwizzleA>{}; + } + }(); + + using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy< + QKGemmQDataType, + QKGemmKDataType, + GemmAccDataType, + typename Problem::BlockSageAttnShape::Gemm0BlockWarps, + decltype(warp_gemm)>; + + if constexpr(1 < Problem::kNumGemm0Warps) + return BlockGemmARegBSmemCRegV2{}; + else + return BlockGemmARegBSmemCRegOneWarpV1{}; + } +}; + +// This pipeline is qkv all located in LDS +template +struct BlockSageAttnPipelineQRKSVSCustomPolicy : BlockSageAttnPipelineQRCustomPolicy +{ + static constexpr bool AsyncCopy = AsyncCopy_; + + static constexpr index_t NumPrefetchK = NumPrefetchK_; + static constexpr index_t NumPrefetchV = NumPrefetchV_; + + static constexpr index_t NumKVLdsBuffers = max(NumPrefetchK, NumPrefetchV); + + using QXPolicy = BlockSageAttnPipelineQRCustomPolicy; + + template + 
struct LdsBufferSequence
+    // Compile-time map from pipeline iteration index to the LDS buffer it uses:
+    // the first k_loops_ entries cover the gemm_0 (K) iterations, the following
+    // v_loops_ entries the gemm_1 (V) iterations. num_lds_buffers_ buffers are
+    // cycled through; Make() pins the last V iteration to buffer num_lds_buffers_-1.
+    {
+        static constexpr index_t num_lds_buffers_ = max(k_prefetches_, v_prefetches_);
+        static constexpr index_t ceil_ = ((v_loops_ - 1) / num_lds_buffers_) * num_lds_buffers_;
+
+        // for qr_ks_vs_async, the Lds buffer assigned to last gemm_1 iteration of V should not
+        // overlap with the Lds buffers used by first two gemm_0 iterations of K
+        static constexpr auto Make()
+        {
+            // ensure v_loop_-1 is assigned to num_lds_buffers-1
+            return transform_sequences(
+                [&](auto i) {
+                    if(i < k_loops_)
+                        return i % num_lds_buffers_;
+                    else
+                        return ((num_lds_buffers_ - 1) + (i - k_loops_ + ceil_ - (v_loops_ - 1))) %
+                               num_lds_buffers_;
+                },
+                typename arithmetic_sequence_gen<0, k_loops_ + v_loops_, 1>::type{});
+        };
+
+        // NOTE(review): the template argument of remove_cvref_t (presumably
+        // decltype(Make())) was lost in this paste -- confirm against VCS.
+        using type = remove_cvref_t;
+    };
+
+    // Hand-tuned overrides of the generic mapping. The four arguments appear to be
+    // (k_prefetches, v_prefetches, k_loops, v_loops) -- each sequence below has
+    // k_loops + v_loops entries -- TODO confirm parameter order (stripped above).
+    // clang-format off
+    template<> struct
+    LdsBufferSequence<3, 3, 4, 4> { using type = sequence<1, 2, 0, 1, 0, 1, 2, 0>; };
+
+    template<> struct
+    LdsBufferSequence<3, 3, 4, 2> { using type = sequence<1, 2, 0, 1, 2, 0>; };
+
+    template<> struct
+    LdsBufferSequence<3, 3, 2, 4> { using type = sequence<1, 2, 0, 1, 2, 0>; };
+
+    template<> struct
+    LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; };
+
+    template<> struct
+    LdsBufferSequence<3, 3, 3, 4> { using type = sequence<1, 2, 0, 0, 1, 2, 0>; };
+
+    template<> struct
+    LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>;};
+    // clang-format on
+
+    // Derives the K-loop / V-loop counts from the block shape and returns the
+    // buffer-assignment sequence for one main-loop iteration of the pipeline.
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto GetLdsBufferSequence()
+    {
+        using BlockSageAttnShape = remove_cvref_t;
+
+        constexpr index_t kN0 = BlockSageAttnShape::kN0;
+        constexpr index_t kK0 = BlockSageAttnShape::kK0;
+        constexpr index_t kK1 = BlockSageAttnShape::kK1;
+        constexpr index_t kQKHeaddim = BlockSageAttnShape::kQKHeaddim;
+
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
+        constexpr index_t k1_loops = kN0 / kK1;
+
+        // NOTE(review): the explicit template arguments of LdsBufferSequence
+        // (presumably <NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>) were
+        // stripped in this paste -- recover from VCS.
+        return typename LdsBufferSequence::type{};
+    }
+
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
+    {
+        // 
TODO: this is for 3d layout + using KDataType = SageAttnQKGemmKDataType; + return GetLogicalVectorSize(16); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + using KDataType = remove_cvref_t; + if constexpr(AsyncCopy) + { +#if defined(__gfx950__) + constexpr index_t MaxLoadSizeInBytes = 4 * 4; // dwordx4 +#else + constexpr index_t MaxLoadSizeInBytes = 4; // dword +#endif + + return GetLogicalVectorSize(MaxLoadSizeInBytes); + } + else + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0; + + constexpr index_t MaxVectorSize = GetLogicalVectorSize(16); + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + + return min(MaxVectorSize, ElemPerThread); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + { + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t kMaxVecLoad = + min(total_pixels, static_cast(16 / sizeof(VDataType))); + + return kMaxVecLoad; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + using VLayout = remove_cvref_t; + using VDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t kMaxVecLoad = + min(total_pixels, static_cast(16 / sizeof(VDataType))); + + if constexpr(std::is_same_v) + { + constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); + + constexpr index_t 
kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) + ? kMaxVecLoad + : (total_pixels / kMinVecLoad); + + return kVecLoad; + } + else + { + return kMaxVecLoad; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::ODataType); + return min(MaxVectorSize, WG::WarpGemmAttribute::Impl::kCM1PerLane); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() + { + // this function assume K/V can share smem + constexpr index_t SingleKSize = [&]() { + if constexpr(!AsyncCopy) + { + return MakeKLdsBlockDescriptor().get_element_space_size(); + } + else + { + constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0; + constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = KPack; + + static_assert(WarpSize * KVector >= kKPerBlock && + WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; + constexpr index_t LaneGroups = WarpSize / LanesPerK; + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + + return NumIssues * NumWarps * (WarpSize * KVector + kPad); + } + }(); + + constexpr index_t SingleVSize = [&]() { + using VDataType = remove_cvref_t; + constexpr index_t Banks = get_n_lds_banks(); + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = 
Problem::BlockSageAttnShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + }(); + + return max(SingleKSize, SingleVSize); + } + + // TODO: this is used for non async copy desc. unify in the future + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0; + constexpr index_t kKPack = GetSmemKPackK(); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeKLdsStoreBlockDescriptor(number = number<0>{}) + { + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + KPack; // for async-copy, this pad is between warps. 
Optimize this for lds_read speed + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + WarpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(number{}, // n0 + number{}, // n1 + number{}, // n2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number()>{}, + number{}, + number<1>{}); + + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return k_lds_block_desc_issues_warps_lanes; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor() + { + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for 
global load + constexpr index_t kPad = KPack; // for async-copy, this pad is between warps + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + // constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad); + // constexpr index_t SingleVSize = + // MakeVLdsBlockDescriptor().get_element_space_size(); + constexpr index_t BufferSize = + GetSingleSmemElementSpaceSize(); // max(SingleKSize, SingleVSize); + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // num_buffers + number{}, // n0 + number{}, // n2 + number{}, // n1 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, + number{}, + number{}, + number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 1, 3, 2>{}, sequence<4, 5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() + { + using VDataType = remove_cvref_t; + constexpr index_t Banks = get_n_lds_banks(); + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1; + 
static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number()>{}, + number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple( + number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() + { + // TODO: assume Q is in register + // TODO: assume K and V share smem buffers + using KLdsDataType = SageAttnQKGemmKDataType; + constexpr index_t single_smem_size = + GetSingleSmemElementSpaceSize() * sizeof(KLdsDataType); + + return QXPolicy::template GetSmemSizeQ() + single_smem_size * NumKVLdsBuffers; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return GetSmemSizeKV(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() + { + if constexpr(!AsyncCopy) + { + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0; + + constexpr index_t MaxVectorSize = GetLogicalVectorSize(16); + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + + constexpr index_t K1 = min(MaxVectorSize, ElemPerThread); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * 
N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KVector = GetAlignmentK(); // this is for global load + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + using VLayout = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1; + + if constexpr(std::is_same_v) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t kKPack = GetSmemKPackV(); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t K2 = kKPack / K3; // TODO: 
this dimention could be outside single wave + if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible + { + static_assert(kNPerBlock % 16 == 0); + constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16; + constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t N2 = 2; + constexpr index_t N1_m = kNPack / N2; + constexpr index_t N0_m = kNPerBlock / kNPack; + constexpr index_t K1 = get_warp_size() / N1_m; + constexpr index_t K2_m = kKPerBlock / K1 / K0; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<2, 1>>, // K0, K1 N0 + tuple, sequence<1, 1>>, + sequence<1, 2, 1>, // N0 K2 N2 + sequence<0, 2, 2>>{}); + } + else if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = GetAlignmentV(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + static_assert(N0 != 0); + + 
constexpr auto dstr = make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, // N1, N2 K0 + tuple, sequence<2, 0>>, + sequence<1, 2>, // N0 K1 + sequence<0, 1>>{}); + if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock) + { + return dstr; + } + else + { + static_assert(kKPerBlock % 16 == 0); + constexpr index_t kKPerIter = kKPerBlock % 32 == 0 ? 32 : 16; + constexpr index_t K0_m = kKPerBlock / kKPerIter; + constexpr index_t K2 = 2; + constexpr index_t K1_m = kKPerIter / K2; + constexpr index_t N2_m = get_warp_size() / K1_m; + constexpr index_t N0_m = kNPerBlock / (N2_m * N1); + constexpr auto dstr_m = make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<1, 2>>, // N1, N2 K1 + tuple, sequence<2, 1>>, + sequence<2, 1, 2>, // K0 N0 K2 + sequence<0, 0, 2>>{}); + static_assert(container_reduce(dstr_m.get_lengths(), + std::multiplies{}, + 1) == kNPerBlock * kKPerBlock); + return dstr_m; + } + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegBlockDescriptor() + { + // This descriptor only used when V layout is seqlen * hdim + using VLayout = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1; + + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackV(); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible + { + static_assert(kNPerBlock % 16 == 0); + constexpr index_t kNPack = kNPerBlock % 32 
== 0 ? 32 : 16; + constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t N2 = 2; + constexpr index_t N1_m = kNPack / N2; + constexpr index_t N0_m = kNPerBlock / kNPack; + constexpr index_t K1 = get_warp_size() / N1_m; + constexpr index_t K2_m = kKPerBlock / K1 / K0; + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, // K0, K1 N0 + tuple, sequence<1, 1>>, + sequence<1, 1, 2>, // N0 K2 <-> N2 + sequence<0, 2, 2>>{}); + } + else if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockSageAttnShape::Gemm1BlockWarps, + typename Problem::BlockSageAttnShape::Gemm1WarpTile>>; + + auto warp_gemm = [&]() { + if constexpr(get_warp_size() == 64 && + std::is_same_v && + std::is_same_v && + std::is_same_v) + { + static_assert(Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<0>{}) == 32); + static_assert(Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<1>{}) == 32); + static_assert(Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<2>{}) == 32); + + return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{}; + } + else + { + return 
WarpGemmDispatcher<
+                    typename Problem::PDataType,
+                    typename Problem::VDataType,
+                    typename Problem::OaccDataType,
+                    Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<0>{}),
+                    Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<1>{}),
+                    Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<2>{}),
+                    true>{};
+            }
+        }();
+
+        // NOTE(review): template argument of remove_cvref_t (presumably
+        // decltype(warp_gemm)) was stripped in this paste.
+        using WarpGemm = remove_cvref_t;
+
+        using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
+            typename Problem::PDataType,
+            typename Problem::VDataType,
+            typename Problem::OaccDataType,
+            typename Problem::BlockSageAttnShape::Gemm1BlockWarps,
+            WarpGemm>;
+        // NOTE(review): template arguments of BlockGemmARegBSmemCRegV2 were
+        // stripped in this paste -- recover from VCS.
+        return BlockGemmARegBSmemCRegV2{};
+    }
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_default_policy.hpp
new file mode 100644
index 0000000000..de9c6979e7
--- /dev/null
+++ b/include/ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_default_policy.hpp
@@ -0,0 +1,17 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_custom_policy.hpp"
+
+namespace ck_tile {
+
+// Default instantiation of the QRKSVS pipeline policy (the async variant has its
+// own *_async_default_policy.hpp header).
+// NOTE(review): the template arguments of BlockSageAttnPipelineQRKSVSCustomPolicy
+// (AsyncCopy_ / NumPrefetchK_ / NumPrefetchV_) were stripped in this paste.
+using BlockSageAttentionPipelineQRKSVSDefaultPolicy =
+    BlockSageAttnPipelineQRKSVSCustomPolicy;
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/sageattention/pipeline/tile_sageattn_shape.hpp b/include/ck_tile/ops/sageattention/pipeline/tile_sageattn_shape.hpp
new file mode 100644
index 0000000000..1351de94ef
--- /dev/null
+++ b/include/ck_tile/ops/sageattention/pipeline/tile_sageattn_shape.hpp
@@ -0,0 +1,71 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+
+namespace ck_tile {
+
+// Rounds a head dimension up to a tile length the pipelines support:
+// 48 -> 48, 80 -> 96, 96 -> 128, 160 -> 256, 192 -> 192; powers of two map to
+// themselves; anything else fails to compile via the static_assert below.
+// NOTE(review): the template parameter list (presumably <index_t Headdim>) was
+// stripped in this paste. Also, the static_assert message omits 80 even though
+// Headdim == 80 is handled above -- consider "48, 80, 96, 160, 192".
+template
+static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length()
+{
+    if constexpr(Headdim == 48)
+        return 48;
+    else if constexpr(Headdim == 80)
+        return 96;
+    else if constexpr(Headdim == 96)
+        return 128;
+    else if constexpr(Headdim == 160)
+        return 256;
+    else if constexpr(Headdim == 192)
+        return 192;
+    else if constexpr(is_power_of_two_integer(Headdim))
+        return Headdim;
+    else
+        static_assert(Headdim == 0,
+                      "only Headdim of 48, 96, 160, 192 and power-of-two is supported");
+};
+
+// Compile-time tile-shape descriptor for the SageAttention forward pipelines:
+// block tile sizes, per-gemm warp layouts/tiles, and the V-tensor layout.
+// NOTE(review): the template parameter list (presumably BlockTile_,
+// Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_,
+// IsVLayoutRowMajor_) and the arguments of the remove_cvref_t aliases below were
+// stripped in this paste -- recover from VCS.
+template
+struct TileSageAttnShape
+{
+    using BlockTile = remove_cvref_t;
+    using Gemm0BlockWarps = remove_cvref_t;
+    using Gemm0WarpTile = remove_cvref_t;
+    using Gemm1BlockWarps = remove_cvref_t;
+    using Gemm1WarpTile = remove_cvref_t;
+
+    // warp counts are the products of the per-gemm warp layouts
+    static constexpr index_t NumGemm0Warps =
+        reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
+    static constexpr index_t NumGemm1Warps =
+        reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{});
+    static_assert(NumGemm1Warps % NumGemm0Warps == 0);
+
+    static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
+
+    static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
+    static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
+    static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
+    static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
+    static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
+    static constexpr index_t kQKHeaddim =
+        BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
+                                    // once (or repeately load Q as a whole tile)
+    static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
+
+    static constexpr index_t kSubQKHeaddim = 
ceil_to_qualified_tile_length();
+
+    // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
+    static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
+    // NOTE(review): the arguments of std::conditional_t (presumably
+    // <IsVLayoutRowMajor, row-major layout tag, column-major layout tag>) were
+    // stripped in this paste -- recover from VCS.
+    using VLayout = std::conditional_t;
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/sageattention/pipeline/tile_sageattn_traits.hpp b/include/ck_tile/ops/sageattention/pipeline/tile_sageattn_traits.hpp
new file mode 100644
index 0000000000..adeac94200
--- /dev/null
+++ b/include/ck_tile/ops/sageattention/pipeline/tile_sageattn_traits.hpp
@@ -0,0 +1,42 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
+#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
+#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
+
+namespace ck_tile {
+
+// Compile-time trait bundle for the SageAttention forward kernel: padding flags,
+// quantization-scale granularity, occupancy hint, and the derived descale group
+// sizes along seqlen for Q and K.
+// NOTE(review): the template parameter list (presumably kPadSeqLenQ_,
+// kPadSeqLenK_, kPadHeadDimQ_, kPadHeadDimV_, QScaleEnum_, kBlockPerCu_,
+// kSkipMinSeqlenQ_) was stripped in this paste -- recover from VCS.
+template
+struct TileSageAttnTraits
+{
+    static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
+    static constexpr bool kPadSeqLenK = kPadSeqLenK_;
+    static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
+    static constexpr bool kPadHeadDimV = kPadHeadDimV_;
+    static constexpr auto QScaleEnum = QScaleEnum_;
+    static constexpr index_t kBlockPerCu = kBlockPerCu_;
+    static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
+
+    /// Tokens per Q/K descale along seqlen. Fine-to-coarse: PERTHREAD, PERWARP, then 128 for Q
+    /// (BLOCKSCALE / no_scale / pertensor). K: PERWARP 64, BLOCKSCALE 128, else 128.
+    static constexpr index_t kBlockScaleSizeQ =
+        QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERTHREAD ? 4
+        : QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERWARP ? 32
+                                                                   : 128;
+    // K descale groups are coarser than Q: 16 tokens per thread, 64 per warp.
+    static constexpr index_t kBlockScaleSizeK =
+        QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERTHREAD ? 16
+        : QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERWARP ? 64
+                                                                   : 128;
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/sageattn.hpp b/include/ck_tile/ops/sageattn.hpp
new file mode 100644
index 0000000000..759e698a3d
--- /dev/null
+++ b/include/ck_tile/ops/sageattn.hpp
@@ -0,0 +1,17 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+
+// Umbrella header for the ck_tile SageAttention forward op.
+#include "ck_tile/ops/sageattention/kernel/sageattn_fwd_kernel.hpp"
+#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_enum.hpp"
+#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_problem.hpp"
+#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs.hpp"
+#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp"
+#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async_default_policy.hpp"
+#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_default_policy.hpp"
+#include "ck_tile/ops/sageattention/pipeline/tile_sageattn_shape.hpp"
+#include "ck_tile/ops/sageattention/pipeline/tile_sageattn_traits.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
+#include "ck_tile/ops/common/streamk_common.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/common/utils.hpp"