mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Add SageAttention v2 forward kernel with multi-granularity quantization (#6574)
## Summary Add a CK_TILE forward kernel implementing [SageAttention v2](https://arxiv.org/abs/2411.10958) — an attention algorithm that applies multi-granularity quantization to Q/K/V before computing attention, trading minimal accuracy loss for higher throughput on low-precision hardware. ### Quantization design | Tensor | Supported data types | Scale granularity options | |--------|---------------------|--------------------------| | Q | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (32 tokens), per-thread (4 tokens) | | K | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (64 tokens), per-thread (16 tokens) | | V | fp8 | per-channel (always) | | O | bf16 | — | Three precision combinations are supported: `fp8/bf16` (QKV fp8, O bf16), `i8/fp8/bf16` (QK int8, V fp8, O bf16), and `i4/fp8/bf16` (QK int4, V fp8, O bf16). ### Architecture support - **gfx9** (CDNA2/3, e.g. gfx90a, gfx942) — full tile set - **gfx950** (CDNA4) — restricted tile set (N-per-block capped at 64 for fp8-family dtypes) ### Implementation - Two pipeline variants: `QRKSVS` (synchronous) and `QRKSVS_ASYNC` (async copy) - Masking support: no mask, causal (top-left / bottom-right), and generic windowed - Batch and group (variable-length) modes - Head dimension: d=128, d_v=128 - Python codegen under `example/ck_tile/49_sageattention/codegen/` generates kernel instances per target/dtype/tile combination - Smoke tests included via `tile_example_sageattn_fwd` ### Test commands \`\`\`bash # fp8 QKV ./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=fp8bf16 -qscale=3 -init=3 # int8 QK, fp8 V ./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=i8fp8bf16 -qscale=3 -init=3 \`\`\` \`-qscale\` values: 1=per-tensor, 2=per-block, 3=per-warp, 4=per-thread
This commit is contained in:
95
example/ck_tile/49_sageattention/CMakeLists.txt
Normal file
95
example/ck_tile/49_sageattention/CMakeLists.txt
Normal file
@@ -0,0 +1,95 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# Currently only gfx9 arch is supported
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
if(NOT INST_TARGETS)
    message(WARNING "Skipping SageAttention compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
    return()
endif()

# ====================================================================
# SageAttention codegen - only FWD API, minimal instances
# ====================================================================
# Re-run configuration whenever any codegen script changes.
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
    ${CMAKE_CURRENT_LIST_DIR}/generate.py
    ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
)
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")

# generate.py expects a single comma-separated target list argument.
list(JOIN INST_TARGETS "," SAGEATTN_TARGETS_ARG)

# Only generate FWD API, only supported head dimension (128)
# Note: Only d=128, d_v=128 has kernel tile definitions in sageattn_fwd.py
set(SAGEATTN_FWD_CODE_GEN_COMMON_ARGS
    ${CMAKE_CURRENT_LIST_DIR}/generate.py
    --targets ${SAGEATTN_TARGETS_ARG}
    --api fwd
    --optdim 128
)

# Generate list of kernels to build (configure time, so the blob list can
# feed add_library below).
execute_process(
    COMMAND ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS}
            --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt
    RESULT_VARIABLE ret
)
if(NOT ret EQUAL 0)
    message(FATAL_ERROR "SageAttention FAILED to generate kernel list via Python.")
endif()

file(STRINGS "${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt" SAGEATTN_FWD_GEN_BLOBS)

# Generate the kernel instance files
add_custom_command(
    OUTPUT ${SAGEATTN_FWD_GEN_BLOBS}
    COMMAND ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS}
            --output_dir ${CMAKE_CURRENT_BINARY_DIR}
    DEPENDS ${CODE_GEN_SCRIPTS}
    COMMENT "Generate SageAttention FWD kernels"
    VERBATIM
)

# Compile options shared by the instance library and the example executable
# (previously duplicated in two separate lists).
set(SAGEATTN_FWD_COMMON_COMPILE_OPTIONS
    -Wno-undefined-func-template
    -Wno-float-equal
    -fgpu-flush-denormals-to-zero
)
if(CK_USE_OCP_FP8)
    list(APPEND SAGEATTN_FWD_COMMON_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()

# Build the kernel instances library
add_library(tile_sageattn_fwd_instances OBJECT EXCLUDE_FROM_ALL ${SAGEATTN_FWD_GEN_BLOBS})
target_include_directories(tile_sageattn_fwd_instances PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(tile_sageattn_fwd_instances PRIVATE ${SAGEATTN_FWD_COMMON_COMPILE_OPTIONS})
set_property(TARGET tile_sageattn_fwd_instances PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
set_property(TARGET tile_sageattn_fwd_instances PROPERTY POSITION_INDEPENDENT_CODE ON)

# ====================================================================
# SageAttention FWD Example
# ====================================================================
set(EXAMPLE_SAGEATTN_FWD "tile_example_sageattn_fwd")

message(DEBUG "adding example ${EXAMPLE_SAGEATTN_FWD}")

add_executable(${EXAMPLE_SAGEATTN_FWD} EXCLUDE_FROM_ALL example_sageattn_fwd.cpp)
target_include_directories(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

# Link with our own minimal instances library (INDEPENDENT from FMHA!)
target_link_libraries(${EXAMPLE_SAGEATTN_FWD} PRIVATE tile_sageattn_fwd_instances)

target_compile_options(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${SAGEATTN_FWD_COMMON_COMPILE_OPTIONS})
set_property(TARGET ${EXAMPLE_SAGEATTN_FWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
2
example/ck_tile/49_sageattention/codegen/__init__.py
Normal file
2
example/ck_tile/49_sageattention/codegen/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
42
example/ck_tile/49_sageattention/codegen/arch.py
Normal file
42
example/ck_tile/49_sageattention/codegen/arch.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Callable
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ArchTrait:
    """Describes one GPU architecture and the C++/filename artifacts derived
    from it. Any field left as None is derived from `name` in __post_init__.
    """

    name: str
    preprocessor_check: str = field(default=None)
    device_name_check: str = field(default=None)
    tag: str = field(default=None)
    filename_suffix: str = field(default=None)

    def __post_init__(self):
        # The dataclass is frozen, so derived defaults must be written via
        # object.__setattr__ rather than plain attribute assignment.
        derived = {
            "preprocessor_check": f"defined(__{self.name}__)",
            "device_name_check": f'device_name.compare(0, {len(self.name)}, "{self.name}") == 0',
            "tag": f"ck_tile::{self.name}_t",
            "filename_suffix": f"_{self.name}",
        }
        for attr, value in derived.items():
            if getattr(self, attr) is None:
                object.__setattr__(self, attr, value)
|
||||
|
||||
def get_factories_for_targets(
    targets: List[str], get_factory: Callable[[str], Any]
) -> List[Any]:
    """Resolve one factory per distinct architecture for the requested targets.

    Targets mapping to the same architecture collapse into a single entry
    (the factory produced for the last such target wins).
    """
    unique = dict()
    for tgt in targets:
        fac = get_factory(tgt)
        unique[fac.arch.name] = fac
    # Place more specific architectures first: longer names (e.g. "gfx950")
    # must be considered before shorter prefixes (e.g. "gfx9").
    return sorted(unique.values(), key=lambda f: len(f.arch.name), reverse=True)
4
example/ck_tile/49_sageattention/codegen/cmake_config.py
Normal file
4
example/ck_tile/49_sageattention/codegen/cmake_config.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
GEN_DIR = "" # in Cmake, have to generate files in same folder
|
||||
103
example/ck_tile/49_sageattention/codegen/cpp_symbol_map.py
Normal file
103
example/ck_tile/49_sageattention/codegen/cpp_symbol_map.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
FWD_DTYPE_MAP = {
|
||||
"fp16": "SageAttentionFwdFp16",
|
||||
"bf16": "SageAttentionFwdBf16",
|
||||
"fp8bf16": "SageAttentionFwdFp8Bf16",
|
||||
"i8fp8bf16": "SageAttentionFwdI8Fp8Bf16",
|
||||
"i4fp8bf16": "SageAttentionFwdI4Fp8Bf16",
|
||||
}
|
||||
|
||||
# Simplified masks collapse all windowed variants into one boolean template
# parameter of SimplifiedGenericAttentionMask.
_MASK_SIMPLIFIED_MAP = {
    "s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
    "s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
}

_MASK_MAP = {
    "no": "SageAttnMasks::NoMask",
    "causal": "SageAttnMasks::CausalMask",
    "generic": "SageAttnMasks::GenericMask",
}


def get_mask_map(mask_impl: str):
    """Return the mask-name -> C++ mask type map for a mask implementation.

    Args:
        mask_impl: either "generic" or "simplified".

    Raises:
        ValueError: for any other value. (Previously a bare `assert False`
        followed by an unreachable `return None`, which would silently
        return None when run under `python -O`.)
    """
    if mask_impl == "generic":
        return _MASK_MAP
    elif mask_impl == "simplified":
        return _MASK_SIMPLIFIED_MAP
    raise ValueError(f"unknown mask implementation: {mask_impl!r}")
|
||||
|
||||
def get_mask_impl(mask: str) -> str:
    """Classify a mask name: the "s_" prefix marks the simplified implementation."""
    if mask.startswith("s_"):
        return "simplified"
    return "generic"
||||
|
||||
|
||||
def get_mask_cpp_type(mask: str) -> str:
    """Map a mask name to its C++ mask type via the matching implementation table."""
    impl = get_mask_impl(mask)
    return get_mask_map(impl)[mask]
||||
|
||||
|
||||
# Runtime dispatch expressions (C++), keyed by mask name, that verify the
# user-supplied mask_type matches the compiled-in mask variant.
_MASK_CHECK_MAP = {
    "no": "t.mask_type == mask_enum::no_mask",
    "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
    "generic": "t.mask_type == mask_enum::window_generic",
}

_MASK_SIMPLIFIED_CHECK_MAP = {
    "s_no": "t.mask_type == mask_enum::no_mask",
    "s_mask": "t.mask_type != mask_enum::no_mask",
}


def get_mask_check_map(mask: str):
    """Return the mask-name -> C++ runtime check map for a mask implementation.

    Args:
        mask: the mask implementation name, "generic" or "simplified"
              (note: despite the parameter name this is an implementation,
              not a mask name — mirrors get_mask_map).

    Raises:
        ValueError: for any other value. (Previously a bare `assert False`
        followed by an unreachable `return None`, which would silently
        return None when run under `python -O`.)
    """
    if mask == "generic":
        return _MASK_CHECK_MAP
    elif mask == "simplified":
        return _MASK_SIMPLIFIED_CHECK_MAP
    raise ValueError(f"unknown mask implementation: {mask!r}")
||||
|
||||
|
||||
def get_mask_cpp_check_expr(mask: str) -> str:
    """Map a mask name to the C++ expression that validates it at dispatch time."""
    impl = get_mask_impl(mask)
    return get_mask_check_map(impl)[mask]
||||
|
||||
|
||||
QSCALE_MAP = {
|
||||
"no": "ck_tile::BlockSageAttentionQuantScaleEnum::NO_SCALE",
|
||||
"pertensor": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTENSOR",
|
||||
"blockscale": "ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE",
|
||||
"perwarp": "ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP",
|
||||
"perthread": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD",
|
||||
}
|
||||
|
||||
QSCALE_CHECK_MAP = {
|
||||
"no": "quant_scale_enum::no_scale",
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
"blockscale": "quant_scale_enum::blockscale",
|
||||
"perwarp": "quant_scale_enum::perwarp",
|
||||
"perthread": "quant_scale_enum::perthread",
|
||||
}
|
||||
|
||||
MODE_MAP = {"batch": "false", "group": "true"}
|
||||
|
||||
LAYOUT_MAP = {"row": "true", "col": "false"}
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"qr": "ck_tile::BlockSageAttentionPipelineQRKSVS",
|
||||
"qr_async": "ck_tile::BlockSageAttentionPipelineQRKSVSAsync",
|
||||
}
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS",
|
||||
"qr_async": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS_ASYNC",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
"t": "true",
|
||||
"f": "false",
|
||||
True: "true",
|
||||
False: "false",
|
||||
}
|
||||
2
example/ck_tile/49_sageattention/codegen/ops/__init__.py
Normal file
2
example/ck_tile/49_sageattention/codegen/ops/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
992
example/ck_tile/49_sageattention/codegen/ops/sageattn_fwd.py
Normal file
992
example/ck_tile/49_sageattention/codegen/ops/sageattn_fwd.py
Normal file
@@ -0,0 +1,992 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
import copy
|
||||
import fnmatch
|
||||
import itertools
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable, ClassVar, Iterable, List, Optional, Tuple
|
||||
|
||||
from codegen.arch import ArchTrait, get_factories_for_targets
|
||||
from codegen.cmake_config import GEN_DIR
|
||||
from codegen.cpp_symbol_map import (
|
||||
LAYOUT_MAP,
|
||||
BOOL_MAP,
|
||||
PIPELINE_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
MODE_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
get_mask_map,
|
||||
get_mask_cpp_type,
|
||||
get_mask_cpp_check_expr,
|
||||
QSCALE_CHECK_MAP,
|
||||
QSCALE_MAP,
|
||||
)
|
||||
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8bf16": 8,
|
||||
"i8fp8bf16": 8,
|
||||
"i4fp8bf16": 4,
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32: 32,
|
||||
48: 48,
|
||||
64: 64,
|
||||
80: 96,
|
||||
96: 128,
|
||||
128: 128,
|
||||
192: 192,
|
||||
256: 256,
|
||||
}
|
||||
|
||||
SAGEATTN_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp"
|
||||
#include "sageattn_fwd.hpp"
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_KERNEL_BODY_TEMPLATE = """
|
||||
#include <iostream>
|
||||
|
||||
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
|
||||
|
||||
using sageattn_dtype = {F_dtype};
|
||||
|
||||
using sageattn_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using sageattn_shape = ck_tile::TileSageAttnShape<sageattn_block_tile,
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using sageattn_traits = ck_tile::TileSageAttnTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_qscale},
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
|
||||
using sageattn_variant = ck_tile::ComposedAttention<false * ck_tile::LOGITS_SOFT_CAP, true>;
|
||||
|
||||
using sageattn_mask_type = {F_mask};
|
||||
|
||||
using sageattn_pipeline_problem = ck_tile::BlockSageAttnPipelineProblem<
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::QDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::KDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::VDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::SaccDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::SMPLComputeDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::PDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::OaccDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::ODataType,
|
||||
sageattn_shape,
|
||||
{F_mode},
|
||||
sageattn_variant,
|
||||
sageattn_mask_type,
|
||||
sageattn_traits>;
|
||||
|
||||
using sageattn_pipeline = {F_pipeline}<
|
||||
sageattn_pipeline_problem>;
|
||||
|
||||
using sageattn_epilogue =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename SageAttentionFwdTypeConfig<sageattn_dtype>::OaccDataType,
|
||||
typename SageAttentionFwdTypeConfig<sageattn_dtype>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using sageattn_kernel = {F_kernel}<sageattn_pipeline, sageattn_epilogue>;
|
||||
|
||||
|
||||
using trait = sageattn_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, sageattn_mask_type, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
|
||||
template<>
|
||||
float sageattn_fwd_<trait, {F_arch.tag}>(const ck_tile::stream_config& s, sageattn_fwd_args a)
|
||||
{{
|
||||
using k_ = sageattn_kernel;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", {F_kname}" << std::flush;
|
||||
auto [kargs, grids] = {F_kargs_creator}<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_FILENAME = "sageattn_fwd_api.cpp"
|
||||
SAGEATTN_FWD_API_HEADER = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include <cstdio>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "sageattn_fwd.hpp"
|
||||
|
||||
namespace {
|
||||
bool get_num_cus(unsigned& num_cus) {
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}
|
||||
|
||||
hipDeviceProp_t props{};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}
|
||||
|
||||
num_cus = props.multiProcessorCount;
|
||||
return true;
|
||||
}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}
|
||||
} // namespace
|
||||
"""
|
||||
SAGEATTN_FWD_API_FUNC_TEMPLATE = """
|
||||
namespace {{
|
||||
float {F_func_name}([[maybe_unused]] sageattn_fwd_traits t, [[maybe_unused]] sageattn_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if(!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
}} // namespace
|
||||
"""
|
||||
SAGEATTN_FWD_API_FOOTER_TEMPLATE = """
|
||||
// Public API entry point - unified for SageAttention
|
||||
float sageattn_fwd(sageattn_fwd_traits traits, sageattn_fwd_args args, const ck_tile::stream_config& config) {
|
||||
return sageattn_fwd_impl(traits, args, config);
|
||||
}
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{
|
||||
{F_dtype_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_PER_DTYPE = """{F_if}(t.data_type.compare(\"{F_dtype}\") == 0) {{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
SAGEATTN_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = sageattn_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
return sageattn_fwd_<trait_, {F_arch.tag}>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
class CppConstraint:
    """An optional C++ boolean expression; None renders as the trivially-true check."""

    bool_expr: str = None

    def __str__(self):
        return "true" if self.bool_expr is None else str(self.bool_expr)

    def __and__(self, other):
        # Conjunction: parenthesize both operands so C++ operator precedence
        # cannot change the meaning of the combined expression.
        return CppConstraint(f"({str(self)}) && ({str(other)})")
||||
|
||||
|
||||
@dataclass
class SageAttnFwdApiTrait:
    """One dispatchable kernel configuration plus the C++ runtime checks
    guarding its selection in the generated API.
    """

    arch: ArchTrait
    pipeline_tag: str
    # sync with sageattn_fwd_traits<>, to generate fallback calls
    hdim: str
    dtype: str  # data type key (see FWD_DTYPE_MAP)
    mode: str  # value from MODE_MAP ("batch"/"group")
    bm0: int  # tile size along q seqlen (block size)
    bn0: int  # tile size along qk seqlen
    bk0: int  # tile size along qk gemm unroll
    bn1: int  # tile size along v head_dim
    bk1: int  # tile size along kv gemm unroll
    bk0max: int
    vlayout: str
    mask: str
    qscale: str  # quant-scale granularity key (see QSCALE_MAP)
    spad: str
    skpad: str
    dpad: str
    dvpad: str
    skip: str
    constraint: CppConstraint

    @property
    def name(self) -> str:
        # Unique, stable identifier built from every trait field.
        parts = (
            self.hdim, self.dtype, self.mode, self.bm0, self.bn0, self.bk0,
            self.bn1, self.bk1, self.bk0max, self.vlayout, self.mask,
            self.qscale, self.spad, self.skpad, self.dpad, self.dvpad, self.skip,
        )
        return "-".join(str(p) for p in parts)

    @property
    def scheck(self) -> str:
        """C++ expression checking seqlen-q padding compatibility."""
        if self.mode == "group":
            # group mode only generate spad/skpad == true
            return "true/*group mode spad always true*/"
        if self.pipeline_tag == "qr_async":
            return "true"
        assert self.pipeline_tag in ["qr", "qs"]
        if self.spad == "t":
            # TODO: order of get_pipelines() matters! (ugly)
            return f"true /*a.seqlen_q % {self.bm0} != 0*/"
        return f"a.seqlen_q % {self.bm0} == 0"

    def seqtune(self, max_bm0: int) -> str:
        """C++ expression steering short sequences toward smaller bm0 tiles."""
        if self.bm0 in (max_bm0, 64):
            return "true/*fall back to largest tile*/"
        return f"a.seqlen_q <= {self.bm0}"

    @property
    def skcheck(self) -> str:
        """C++ expression checking seqlen-k padding compatibility."""
        if self.mode == "group":
            # group mode only generate spad/skpad == true
            return "true/*group mode skpad always true*/"
        padded = self.skpad == "t"
        if self.pipeline_tag == "qr_async":
            if padded:
                return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
            return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
        assert self.pipeline_tag in ["qr", "qs"]
        if padded:
            # TODO: order of get_pipelines() matters! (ugly)
            return f"true /*a.seqlen_k % {self.bn0} != 0*/"
        return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"

    @property
    def dcheck(self) -> str:
        """C++ expression checking hdim_q divisibility for the chosen pipeline."""
        if self.pipeline_tag == "qr_async":
            # vector width in elements from a 128-bit load at this dtype width
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            assert self.dpad == "t"  # async pipeline only generated with dpad
            return f"a.hdim_q % {vec} == 0"
        assert self.pipeline_tag in ["qr", "qs"]
        bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
        if self.dpad == "t":
            # TODO: order of get_pipelines() matters! (ugly)
            return f"true /*a.hdim_q % {bk0submax} != 0*/"
        return f"a.hdim_q % {bk0submax} == 0"

    @property
    def dvcheck(self) -> str:
        """C++ expression checking hdim_v divisibility for the chosen pipeline."""
        if self.pipeline_tag == "qr_async":
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            assert self.dvpad == "t"  # async pipeline only generated with dvpad
            return f"a.hdim_v % {vec} == 0"
        assert self.pipeline_tag in ["qr", "qs"]
        bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
        if self.dvpad == "t":
            # TODO: order of get_pipelines() matters! (ugly)
            return f"true /*a.hdim_v % {bk0submax} != 0*/"
        # F_dvpad="f": Causal mask requires hdim_v <= kN1 (num_tile_n1 == 1 for tile reversal)
        # Non-causal requires hdim_v % kN1 == 0 (epilogue writes full tiles)
        if self.mask == "causal":
            return f"(a.hdim_v % {bk0submax} == 0) && (a.hdim_v <= {self.bn1})"
        return f"(a.hdim_v % {bk0submax} == 0) && (a.hdim_v % {self.bn1} == 0)"
||||
|
||||
|
||||
@dataclass
class SageAttnFwdPipeline:
    """Pipeline variant selection: vector layout, padding flags, quant-scale
    granularity, mask kind, and the skip-short-seqlen toggle.
    """

    tag: str

    F_vlayout: str  # "row"/"col"
    F_spad: str  # "t"/"f"
    F_skpad: str  # "t"/"f"
    F_dpad: str  # "t"/"f"
    F_dvpad: str  # "t"/"f"
    F_qscale: str  # no/pertensor/blockscale/perwarp/perthread
    F_mask: str  # key into the mask map (see get_mask_map)
    F_skip: str  # "t"/"f"
    F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())

    @property
    def name(self) -> str:
        # Padding segment: one letter code per padded dimension, or "npad".
        pad_codes = (
            ("s", self.F_spad),
            ("sk", self.F_skpad),
            ("d", self.F_dpad),
            ("dv", self.F_dvpad),
        )
        pad = "".join(code for code, flag in pad_codes if flag == "t")

        segments = [f"{self.tag}_v{self.F_vlayout[0]}"]
        segments.append(f"p{pad}" if pad else "npad")

        # Mask segment: simplified masks keep their own naming scheme.
        if self.F_mask[0:2] == "s_":
            segments.append("mask" if self.F_mask == "s_mask" else "nmask")
        elif self.F_mask != "no":
            segments.append(f"m{self.F_mask[0]}")
        else:
            segments.append("nmask")

        segments.append("skip" if self.F_skip == "t" else "nskip")
        segments.append(self.F_qscale if self.F_qscale != "no" else "nqscale")
        return "_".join(segments)
||||
|
||||
|
||||
class SageAttnFwdApiPool:
|
||||
def __init__(self):
|
||||
self.pool = OrderedDict()
|
||||
|
||||
def register_traits(self, trait: SageAttnFwdApiTrait) -> None:
|
||||
hdim = trait.hdim, trait.bn1
|
||||
ts = (
|
||||
self.pool.setdefault(trait.arch, OrderedDict())
|
||||
.setdefault(trait.dtype, OrderedDict())
|
||||
.setdefault(hdim, [])
|
||||
)
|
||||
check_duplicates_and_paddings(ts, trait)
|
||||
ts.append(copy.copy(trait))
|
||||
|
||||
def get_num_traits(
|
||||
self, filter_fn: Optional[Callable[[SageAttnFwdApiTrait], bool]] = None
|
||||
) -> int:
|
||||
if filter_fn is None:
|
||||
|
||||
def accept_all(trait: SageAttnFwdApiTrait) -> bool:
|
||||
return True
|
||||
|
||||
filter_fn = accept_all
|
||||
|
||||
return sum(
|
||||
sum(1 for trait in pool_by_hdim if filter_fn(trait))
|
||||
for pool_by_arch in self.pool.values()
|
||||
for pool_by_dtype in pool_by_arch.values()
|
||||
for pool_by_hdim in pool_by_dtype.values()
|
||||
)
|
||||
|
||||
def render(
|
||||
self,
|
||||
func_name,
|
||||
filter_fn: Optional[Callable[[SageAttnFwdApiTrait], bool]] = None,
|
||||
) -> str:
|
||||
if filter_fn is None:
|
||||
|
||||
def accept_all(trait: SageAttnFwdApiTrait) -> bool:
|
||||
return True
|
||||
|
||||
filter_fn = accept_all
|
||||
|
||||
def has_traits(node) -> bool:
|
||||
"""Recursively traverse nested OrderedDicts and lists to determine if any SageAttnFwdApiTrait satisfies filter_fn()."""
|
||||
if isinstance(node, list):
|
||||
return any(filter_fn(elem) for elem in node)
|
||||
elif isinstance(node, OrderedDict):
|
||||
return any(has_traits(val) for val in node.values())
|
||||
return False
|
||||
|
||||
per_arch = str()
|
||||
for i_arch, (arch, pool_by_arch) in enumerate(
|
||||
item for item in self.pool.items() if has_traits(item[1])
|
||||
):
|
||||
per_dtypes = str()
|
||||
for i_dtype, (dtype, pool_by_dtype) in enumerate(
|
||||
item for item in pool_by_arch.items() if has_traits(item[1])
|
||||
):
|
||||
per_hdim_case = str()
|
||||
for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate(
|
||||
item for item in pool_by_dtype.items() if has_traits(item[1])
|
||||
):
|
||||
max_bm0 = max(
|
||||
(t.bm0 for t in pool_by_hdim if filter_fn(t)), default=0
|
||||
)
|
||||
inners = str()
|
||||
for i_trait, trait in enumerate(
|
||||
[trait for trait in pool_by_hdim if filter_fn(trait)]
|
||||
):
|
||||
inners += SAGEATTN_FWD_API_INNER_DISPATCH.format(
|
||||
F_if=if_(i_trait),
|
||||
F_arch=arch,
|
||||
F_mode=MODE_MAP[trait.mode],
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
|
||||
F_mask=get_mask_cpp_type(trait.mask),
|
||||
F_mask_check=get_mask_cpp_check_expr(trait.mask),
|
||||
F_skip=BOOL_MAP[trait.skip],
|
||||
F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
|
||||
F_qscale=QSCALE_MAP[trait.qscale],
|
||||
F_scheck=trait.scheck,
|
||||
F_seqtune=trait.seqtune(max_bm0),
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0,
|
||||
F_bn0=trait.bn0,
|
||||
F_bk0=trait.bk0,
|
||||
F_bn1=trait.bn1,
|
||||
F_bk1=trait.bk1,
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
per_hdim_case += SAGEATTN_FWD_API_PER_HDIM_CASE.format(
|
||||
F_if=if_(i_hdim),
|
||||
F_hdim=hdim,
|
||||
F_hdim_v=hdim_v,
|
||||
F_inner_dispatch=indent(inners),
|
||||
)
|
||||
per_dtypes += SAGEATTN_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
|
||||
)
|
||||
per_arch += SAGEATTN_FWD_API_PER_ARCH.format(
|
||||
F_if=if_(i_arch),
|
||||
F_arch=arch,
|
||||
F_dtype_case=indent(per_dtypes),
|
||||
)
|
||||
return SAGEATTN_FWD_API_FUNC_TEMPLATE.format(
|
||||
F_func_name=func_name, F_dispatch=indent(per_arch)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class SageAttnFwdTileSize:
    """Tile/warp partitioning parameters for one generated kernel instance."""

    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along qk gemm unroll
    F_bn1: int  # tile size along v head_dim
    F_bk1: int  # tile size along kv gemm unroll
    F_bk0max: int  # total length of K0, for pipelines that load Q at once (or repeatedly load Q as a whole tile)
    F_rm0: int  # number of warps for gemm0 along q seqlen
    F_rn0: int  # number of warps for gemm0 along k seqlen
    F_rk0: int  # number of warps for gemm0 along head dim q (not used)
    F_rm1: int  # number of warps for gemm1 along q seqlen
    F_rn1: int  # number of warps for gemm1 along head dim v
    F_rk1: int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0: int  # gemm0 warp size along m
    F_wn0: int  # gemm0 warp size along n
    F_wk0: int  # gemm0 warp size along k
    F_wm1: int  # gemm1 warp size along m
    F_wn1: int  # gemm1 warp size along n
    F_wk1: int  # gemm1 warp size along k
    F_occupancy: int  # -1 lets the pipeline decide; any other value overrides occupancy
    F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())

    @property
    def name(self) -> str:
        # Encodes the full tile configuration, e.g. "b128x64x32x128x32x128_r4x1x1_...".
        block = "x".join(
            str(v)
            for v in (self.F_bm0, self.F_bn0, self.F_bk0, self.F_bn1, self.F_bk1, self.F_bk0max)
        )
        warps0 = f"{self.F_rm0}x{self.F_rn0}x{self.F_rk0}"
        warps1 = f"{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
        wsize0 = f"{self.F_wm0}x{self.F_wn0}x{self.F_wk0}"
        wsize1 = f"{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
        occ = "" if self.F_occupancy == -1 else f"_o{self.F_occupancy}"
        return f"b{block}_r{warps0}_r{warps1}_w{wsize0}_w{wsize1}{occ}"
||||
|
||||
|
||||
@dataclass
class SageAttnFwdKernel:
    """One generated forward-kernel instance; renders to a single .cpp file."""

    F_arch: ArchTrait
    F_hdim: int  # hdim
    F_dtype: str  # data type
    F_mode: str  # value from MODE_MAP
    F_tile: SageAttnFwdTileSize
    F_pipeline: SageAttnFwdPipeline

    # Shared templates; ClassVar so dataclass does not treat them as fields.
    _KERNEL_HEADER: ClassVar[str] = SAGEATTN_FWD_KERNEL_HEADER
    _KERNEL_BODY_TEMPLATE: ClassVar[str] = SAGEATTN_FWD_KERNEL_BODY_TEMPLATE

    @classmethod
    def _get_cpp_kernel_class_name(cls, pipeline_tag):
        # pipeline_tag is currently unused: every pipeline shares one kernel class
        return "ck_tile::SageAttnFwdKernel"

    @classmethod
    def _get_cpp_kargs_creator_func_name(cls, pipeline_tag):
        # pipeline_tag is currently unused: every pipeline shares one kargs creator
        return "sageattn_fwd_create_kargs_and_grids"

    def render(self) -> str:
        """Render the full C++ source (header + body) for this kernel instance."""
        return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format(
            F_kname=self.name,
            F_arch=self.F_arch,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            # block tile sizes
            F_bm0=self.F_tile.F_bm0,
            F_bn0=self.F_tile.F_bn0,
            F_bk0=self.F_tile.F_bk0,
            F_bn1=self.F_tile.F_bn1,
            F_bk1=self.F_tile.F_bk1,
            F_bk0max=self.F_tile.F_bk0max,
            # warp counts for gemm0/gemm1
            F_rm0=self.F_tile.F_rm0,
            F_rn0=self.F_tile.F_rn0,
            F_rk0=self.F_tile.F_rk0,
            F_rm1=self.F_tile.F_rm1,
            F_rn1=self.F_tile.F_rn1,
            F_rk1=self.F_tile.F_rk1,
            # warp tile sizes for gemm0/gemm1
            F_wm0=self.F_tile.F_wm0,
            F_wn0=self.F_tile.F_wn0,
            F_wk0=self.F_tile.F_wk0,
            F_wm1=self.F_tile.F_wm1,
            F_wn1=self.F_tile.F_wn1,
            F_wk1=self.F_tile.F_wk1,
            # pipeline traits (layout / padding / quant scale / masking)
            F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
            F_spad=BOOL_MAP[self.F_pipeline.F_spad],
            F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
            F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
            F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
            F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
            F_skip=BOOL_MAP[self.F_pipeline.F_skip],
            F_occupancy=self.F_tile.F_occupancy,
            F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
            F_mask=get_mask_cpp_type(self.F_pipeline.F_mask),
            F_mode=MODE_MAP[self.F_mode],
            F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
            F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag),
            F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag),
        )

    @property
    def name(self) -> str:
        """Unique kernel name combining hdim, dtype, mode, tile and pipeline tags."""
        # TODO: we don't encode idx here
        return (
            f"sageattn_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
            + self.F_tile.name
            + "_"
            + self.F_pipeline.name
        )

    @property
    def filename(self) -> str:
        """Generated source file name; arch suffix disambiguates per-target builds."""
        return f"{self.name}{self.F_arch.filename_suffix}.cpp"

    def api_trait(self) -> SageAttnFwdApiTrait:
        """Summarize the dispatch-relevant traits of this kernel for the API pool."""
        return SageAttnFwdApiTrait(
            arch=self.F_arch,
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=self.F_tile.F_bm0,
            bn0=self.F_tile.F_bn0,
            bk0=self.F_tile.F_bk0,
            bn1=self.F_tile.F_bn1,
            bk1=self.F_tile.F_bk1,
            bk0max=self.F_tile.F_bk0max,
            vlayout=self.F_pipeline.F_vlayout,
            mask=self.F_pipeline.F_mask,
            qscale=self.F_pipeline.F_qscale,
            spad=self.F_pipeline.F_spad,
            skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad,
            dvpad=self.F_pipeline.F_dvpad,
            skip=self.F_pipeline.F_skip,
            # dispatch constraint is the conjunction of tile and pipeline constraints
            constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
        )
|
||||
|
||||
|
||||
@dataclass
class ProblemContext:
    """Problem-side parameters a kernel must serve, independent of tiling."""

    dtype: str  # key into FWD_DTYPE_MAP, e.g. "bf16", "fp8bf16"
    mode: str  # key into MODE_MAP; "group" forces mandatory seqlen padding
    hdim: int  # head dim for q/k
    hdim_v: int  # head dim for v
|
||||
|
||||
|
||||
@dataclass
class KernelContext:
    """Candidate kernel-side configuration being evaluated against a problem."""

    tile: SageAttnFwdTileSize
    pipeline: SageAttnFwdPipeline
    mask_impl: str  # mask implementation flavor, forwarded from the CLI


# A rule decides whether a kernel configuration may serve a given problem.
CompatibilityRule = Callable[[ProblemContext, KernelContext], bool]
|
||||
|
||||
|
||||
def is_compatible(
    problem_ctx: ProblemContext,
    kernel_ctx: KernelContext,
    rules: Iterable[CompatibilityRule],
) -> bool:
    """Return True only if every rule accepts the (problem, kernel) pair."""
    for single_rule in rules:
        if not single_rule(problem_ctx, kernel_ctx):
            return False
    return True
|
||||
|
||||
|
||||
def create_kernel(
    arch: ArchTrait, problem_ctx: ProblemContext, kernel_ctx: KernelContext
) -> SageAttnFwdKernel:
    # Assemble a kernel description from the problem (dtype/mode/hdim) and
    # kernel (tile/pipeline) contexts. kernel_ctx.mask_impl is not consumed
    # here; the mask choice is already baked into the pipeline.
    return SageAttnFwdKernel(
        F_arch=arch,
        F_dtype=problem_ctx.dtype,
        F_mode=problem_ctx.mode,
        F_hdim=problem_ctx.hdim,
        F_tile=kernel_ctx.tile,
        F_pipeline=kernel_ctx.pipeline,
    )
|
||||
|
||||
|
||||
class CompatibilityRuleFactory:
    """Baseline compatibility rules shared by all architectures."""

    @staticmethod
    def get_rules() -> List[CompatibilityRule]:
        # in group mode, spad/skpad must be true, since we can't predict if
        # the seqlen of the current batch needs padding or not
        def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
            if problem_ctx.mode != "group":
                return True
            pipeline = kernel_ctx.pipeline
            return pipeline.F_spad == "t" and pipeline.F_skpad == "t"

        return [check_mode]
|
||||
|
||||
|
||||
class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
    """gfx9-specific compatibility rules (currently just the shared baseline)."""

    _AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"})

    @classmethod
    def get_rules(cls) -> List[CompatibilityRule]:
        # no gfx9-specific rules yet; forward the baseline set unchanged
        return CompatibilityRuleFactory.get_rules()
|
||||
|
||||
|
||||
class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9):
    # gfx950 currently adds no compatibility rules beyond the gfx9 baseline
    pass
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
    """Kernel components (dtypes, tile sizes, pipelines) for gfx9 targets."""

    # excludes gfx950, which has its own factory below
    arch = ArchTrait(
        "gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)"
    )

    # Note: fp16 is not supported by SageAttention (only bf16 + fp8/int quantization)
    _DT_BF16 = ("bf16",)
    _DT_FP8BF16 = ("fp8bf16",)
    _DT_I8FP8BF16 = ("i8fp8bf16",)
    _DT_I4FP8BF16 = ("i4fp8bf16",)

    @classmethod
    def supported_dtypes(cls) -> Tuple[str, ...]:
        """All dtype tags this factory can generate kernels for."""
        return cls._DT_BF16 + cls._DT_FP8BF16 + cls._DT_I8FP8BF16 + cls._DT_I4FP8BF16

    # TODO: design a more practical way to do it
    # this is current supported tile size per hdim
    @classmethod
    def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
        """Map (hdim, hdim_v) -> list of tile sizes for the given dtype.

        Raises ValueError for dtypes not in supported_dtypes().
        """
        if dtype in cls._DT_BF16:
            return {
                (128, 128) : [SageAttnFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
            } # fmt: skip
        elif (
            dtype in cls._DT_FP8BF16
            or dtype in cls._DT_I8FP8BF16
            or dtype in cls._DT_I4FP8BF16
        ):
            # gfx9 (non-gfx950): only F_bn0=64; F_bn0=128 variant is gfx950-only (see Gfx950).
            return {
                (128, 128): [
                    SageAttnFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), # fmt: skip
                ],
            }
        else:
            raise ValueError(f"unsupported dtype={dtype}")

    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    # support this in future
    @classmethod
    def get_pipelines(
        cls, dtype, hdim, hdim_v, receipt, mask_impl
    ) -> List[SageAttnFwdPipeline]:
        """Enumerate candidate pipelines for the given dtype/hdim combination.

        receipt is currently unused here; kept in the signature for API symmetry.
        """
        # this function will populate a list possible pipelines
        # TODO: the order of List matters! the later in this list will be also be checked later
        # TODO: currently for qr pipeline, let "t" padding to appear later!!
        # TODO: how to design this more generic?
        pipelines = []
        if dtype in cls._DT_BF16:
            qscale = "no"  # bf16 runs without quantization
            skip = "f"  # skip: only false
            for mask, vlayout in itertools.product(
                get_mask_map(mask_impl).keys(),
                ["row", "col"],
            ):
                if hdim == 256 and hdim_v == 256:
                    # sync pipeline only for the largest head dim
                    pipelines.append(SageAttnFwdPipeline("qr", vlayout, "f", "f", "f", "f", qscale, mask, skip)) # fmt: skip
                    # the below two is used for hdim vectorize load
                    pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
                    pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip
                else:
                    pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "f", "t", "t", qscale, mask, skip)) # fmt: skip
                    pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip
        elif (
            dtype in cls._DT_FP8BF16
            or dtype in cls._DT_I8FP8BF16
            or dtype in cls._DT_I4FP8BF16
        ):
            # no need lse kernels
            skip = "f"  # skip: only false
            # quantized dtypes enumerate every scale granularity
            for mask, qscale, vlayout in itertools.product(
                get_mask_map(mask_impl).keys(),
                ["no", "pertensor", "blockscale", "perwarp", "perthread"],
                ["row", "col"],  # Support both row and col major layouts
            ):
                if dtype in cls._DT_I4FP8BF16:
                    # int4 only uses sync pipeline (qr), pad_d="f" because packed types
                    # require alignment >= PackedSize which conflicts with kPadHeadDimQ=true
                    # forcing alignment to 1. Safe since hdim always matches tile size.
                    pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "f", "f", "f", qscale, mask, skip)) # fmt: skip
                    pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
                elif hdim == 64:
                    pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "f", "f", "f", qscale, mask, skip)) # fmt: skip
                    pipelines.append(SageAttnFwdPipeline("qr", vlayout, "t", "t", "f", "f", qscale, mask, skip)) # fmt: skip
                else:
                    pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "f", "t", "t", qscale, mask, skip)) # fmt: skip
                    pipelines.append(SageAttnFwdPipeline("qr_async", vlayout, "t", "t", "t", "t", qscale, mask, skip)) # fmt: skip

        # Packed types (int4) cannot use head-dim padding: the tile_window infrastructure
        # forces alignment=1 when padding is enabled, but packed types need alignment >= PackedSize.
        if dtype in cls._DT_I4FP8BF16:
            for p in pipelines:
                assert p.F_dpad == "f", (
                    f"int4 dtype '{dtype}' requires pad_d=false, got '{p.F_dpad}'"
                )
                assert p.F_dvpad == "f", (
                    f"int4 dtype '{dtype}' requires pad_dv=false, got '{p.F_dvpad}'"
                )

        return pipelines
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx950(
    KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx950
):
    """Kernel components for gfx950 (CDNA4); overrides the fp8-family tiles."""

    arch = ArchTrait("gfx950")

    @classmethod
    def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
        """Return gfx950-specific tiles for fp8-family dtypes, else defer to gfx9."""
        fp8_family = cls._DT_FP8BF16 + cls._DT_I8FP8BF16 + cls._DT_I4FP8BF16
        if dtype in fp8_family:
            # gfx950 fp8-family: F_bn0=128 tile only (gfx9 uses F_bn0=64 only).
            return {
                (128, 128): [
                    SageAttnFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), # fmt: skip
                ],
            }
        return super().get_hdim_tile_size_dict(dtype)
|
||||
|
||||
|
||||
class CustomFactory(KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9):
    """Opt-in factory (env-gated) that prepends an extra constrained bf16 tile."""

    @classmethod
    def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
        tiles = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
        # For bf16, prepend a smaller-bm0 tile guarded by a CU-utilization
        # constraint, so it is preferred for small grids.
        if dtype in cls._DT_BF16 and (128, 128) in tiles:
            tiles[(128, 128)].insert(0, SageAttnFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
        return tiles
|
||||
|
||||
|
||||
def get_factory(target: str):
    """Select the kernel component factory for a GPU target string.

    Raises Exception for targets outside the gfx9 family.
    """
    # env override for experimentation with the custom tile set
    custom_requested = os.environ.get("CK_TILE_SAGEATTN_FWD_CUSTOM_FACTORY", "0")
    if custom_requested == "1":
        return CustomFactory

    # Match the most specific architecture prefix first.
    if target.startswith("gfx950"):
        return KernelComponentFactoryGfx950
    if target.startswith("gfx9"):
        return KernelComponentFactoryGfx9

    raise Exception(f"Unsupported device target {target}")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Product:
    """A named compatibility rule; calling the instance delegates to the rule."""

    name: str  # human-readable label, e.g. "All tiles"
    rule: CompatibilityRule

    def __call__(self, problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
        return self.rule(problem_ctx, kernel_ctx)
|
||||
|
||||
|
||||
def get_product(receipt: int) -> Product:
    """Build the product-level filter rule; receipt is currently unused."""

    def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
        # bf16 (no quantization) should not have qscale
        uses_qscale = kernel_ctx.pipeline.F_qscale != "no"
        return not (problem_ctx.dtype == "bf16" and uses_qscale)

    return Product(name="All tiles", rule=fit)
|
||||
|
||||
|
||||
def get_fwd_blobs(
    targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[SageAttnFwdApiPool, List[SageAttnFwdKernel]]:
    """Enumerate all compatible forward kernels for the given targets.

    Returns (api_pool, kernels): the API pool with every kernel's dispatch
    traits registered, and the list of kernel instances to generate.
    kernel_filter is an fnmatch pattern on kernel names ("" disables filtering);
    optdim_list restricts generation to the listed hdims ([-1] means all).
    """
    gen = list()
    api_pool = SageAttnFwdApiPool()

    factories = get_factories_for_targets(targets, get_factory)

    # Hoisted out of the loops: the product rule depends only on the receipt.
    product = get_product(receipt)

    for factory, dtype in ((f, t) for f in factories for t in f.supported_dtypes()):
        d = factory.get_hdim_tile_size_dict(dtype)
        # Hoisted per factory: the rule set does not depend on tile/pipeline/mode.
        rules = factory.get_rules()
        # for hdim_str, mode, mask, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
        for ((hdim, hdim_v), tiles), mode in itertools.product(
            d.items(), MODE_MAP.keys()
        ):
            if optdim_list != [-1]:
                if hdim not in optdim_list:
                    continue
            # Enforce the tile ordering invariant before enumerating pipelines.
            for tile, next_tile in zip(tiles, tiles[1:]):
                assert next_tile.F_bm0 >= tile.F_bm0, (
                    "Tiles must be ordered by increasing bm0"
                )

            for tile, pipeline in itertools.product(
                tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
            ):
                problem_ctx = ProblemContext(
                    dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v
                )
                kernel_ctx = KernelContext(
                    tile=tile, pipeline=pipeline, mask_impl=mask_impl
                )

                if not is_compatible(problem_ctx, kernel_ctx, [*rules, product]):
                    continue

                k = create_kernel(factory.arch, problem_ctx, kernel_ctx)
                if kernel_filter != "":
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue

                api_pool.register_traits(k.api_trait())
                gen.append(k)

    return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: SageAttnFwdKernel, autogen_dir: Path) -> None:
    # Render one kernel instance; update_file only rewrites when the content
    # changed, which keeps incremental rebuilds minimal.
    update_file(autogen_dir / kernel.filename, kernel.render())
|
||||
|
||||
|
||||
def write_fwd_api(
    api_pool: SageAttnFwdApiPool,
    autogen_dir: Path,
) -> None:
    """Render the forward dispatch API source and write it into autogen_dir."""
    rendered_dispatch = api_pool.render("sageattn_fwd_impl")
    content = (
        SAGEATTN_FWD_API_HEADER + rendered_dispatch + SAGEATTN_FWD_API_FOOTER_TEMPLATE
    )
    update_file(autogen_dir / SAGEATTN_FWD_API_FILENAME, content)
|
||||
|
||||
|
||||
def write_blobs(
    targets: List[str],
    output_dir: Path,
    kernel_filter: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    """Generate every forward kernel source plus the API dispatcher into output_dir."""
    api_pool, kernels = get_fwd_blobs(
        targets, kernel_filter, receipt, optdim_list, mask_impl
    )
    for k in kernels:
        write_single_fwd_kernel(k, output_dir)
    write_fwd_api(api_pool, output_dir)
|
||||
|
||||
|
||||
def list_blobs(
    targets: List[str],
    file_path: Path,
    kernel_filter: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    """Append the POSIX paths of every file codegen would emit to file_path."""
    _, kernels = get_fwd_blobs(
        targets, kernel_filter, receipt, optdim_list, mask_impl
    )
    gen_root = file_path.parent / GEN_DIR
    entries = [(gen_root / k.filename).as_posix() + "\n" for k in kernels]
    entries.append((gen_root / SAGEATTN_FWD_API_FILENAME).as_posix() + "\n")
    with file_path.open("a") as f:
        f.writelines(entries)
|
||||
70
example/ck_tile/49_sageattention/codegen/utils.py
Normal file
70
example/ck_tile/49_sageattention/codegen/utils.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# generate kernel instances to speed up compilation
|
||||
import dataclasses
|
||||
import os.path as path
|
||||
import textwrap
|
||||
|
||||
|
||||
def update_file(file_path, content):
    """Write content to file_path only when it differs from what is on disk.

    Skipping identical writes keeps the file's mtime stable, which avoids
    triggering unnecessary rebuilds in the build system.
    """
    if path.exists(file_path):
        with open(file_path, "r") as existing:
            if existing.read() == content:
                return
    with open(file_path, "w") as out:
        out.write(content)
|
||||
|
||||
|
||||
def indent(code: str, indent: str = "    ") -> str:
    """Prefix each non-whitespace line of code with the given indent string."""
    return textwrap.indent(code, prefix=indent)
|
||||
|
||||
|
||||
def if_(i: int) -> str:
    """C++ chain keyword for the i-th dispatch branch: "if" first, "else if" after."""
    return "else if" if i != 0 else "if"
|
||||
|
||||
|
||||
def check_duplicates_and_paddings(traits, trait):
    """Check
    * if the traits list does not contain a trait with the same parameters;
    * if paddings are consistent: the previous kernel can be incorrectly called before the new one,
      for example, f, _t_, f, t cannot be before f, _f_, f, t.

    Raises Exception when a duplicate or a padding-ordering violation is found.
    """

    # Split the dataclass fields into padding flags and everything else.
    fields = [f.name for f in dataclasses.fields(trait)]
    pad_fields = [f for f in fields if "pad" in f]
    non_pad_fields = [f for f in fields if "pad" not in f]
    for prev_trait in traits:
        # Only traits identical in all non-pad fields can shadow each other.
        if any(getattr(trait, f) != getattr(prev_trait, f) for f in non_pad_fields):
            continue
        if all(getattr(trait, f) == getattr(prev_trait, f) for f in pad_fields):
            raise Exception(f"Duplicate found {trait}")
        # Check if the previous kernel can be incorrectly used before the current one
        # for example, f, _t_, f, t cannot be before f, _f_, f, t
        is_prev_more_restrictive = False
        is_curr_more_restrictive = False
        for f in pad_fields:
            prev_pad = getattr(prev_trait, f)
            pad = getattr(trait, f)
            # Normalize to a comparable "restrictiveness" scale: the sentinel
            # 1000000 marks "no padding required" (least restrictive); smaller
            # values are more restrictive.
            if isinstance(prev_pad, str):
                # "f" = no padding -> sentinel; "t" = padding required -> 1
                prev_pad = 1000000 if prev_pad == "f" else 1
                pad = 1000000 if pad == "f" else 1
            elif isinstance(prev_pad, int):
                # 0 means "no padding" -> sentinel; otherwise keep the value
                prev_pad = 1000000 if prev_pad == 0 else prev_pad
                pad = 1000000 if pad == 0 else pad
            else:
                assert False
            if prev_pad < pad:
                is_prev_more_restrictive = True
            elif prev_pad > pad:
                is_curr_more_restrictive = True
        # If the earlier trait is strictly more restrictive on some pad field
        # and never less restrictive, it would always match first and shadow
        # the new trait in the generated dispatch chain.
        if is_prev_more_restrictive and not is_curr_more_restrictive:
            raise Exception(
                f"Kernel will never be used because paddings are not ordered correctly:\n{prev_trait} supersedes\n{trait}"
            )
|
||||
202
example/ck_tile/49_sageattention/example_sageattn_fwd.cpp
Normal file
202
example/ck_tile/49_sageattention/example_sageattn_fwd.cpp
Normal file
@@ -0,0 +1,202 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "sageattn_fwd.hpp"
|
||||
#include "sageattn_fwd_runner.hpp"
|
||||
|
||||
#include <string>
|
||||
|
||||
// Build the command-line schema for the example driver and parse argv.
// Returns (parse_ok, parser).
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
        .insert("mode", "0", "kernel mode. 0:batch, 1:group")
        .insert("b", "2", "batch size")
        .insert("h", "8", "num of head, for q")
        .insert("h_k",
                "-1",
                "num of head, for k/v, -1 means equal to h\n"
                "if not equal to h, then this is GQA/MQA case")
        .insert("s",
                "3328",
                "seqlen_q. if group-mode, means the average value of seqlen_q\n"
                "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
                "also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
                "(group mode)")
        .insert("s_k",
                "-1",
                "seqlen_k (including new key/value), -1 means equal to s\n"
                "also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
                "(group mode)")
        .insert("s_qpad",
                "-1",
                "seqlen_q stride between 2 batches (group-mode optional).\n"
                "Provide positive strides per-batch to simulate physical padding on Q.")
        .insert("s_kpad",
                "-1",
                "seqlen_k stride between 2 batches, currently used in group-mode only\n"
                "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
                "along seqlen, instead of packed, same as xformer kv_padding,\n"
                "must be greater than or equal to s_k")
        .insert("d", "128", "head dim for q, k")
        .insert("d_v", "-1", "head dim for v, -1 means equal to d")
        .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
        // quantization scale granularity; maps to QSCALE_MAP keys in codegen
        .insert("qscale",
                "n",
                "n or 0, no scale\n"
                "pt or 1, per-tensor scale\n"
                "bs or 2, block scale (Q:128, KV:128)\n"
                "pw or 3, per-warp scale (Q:32, KV:64)\n"
                "pth or 4, per-thread scale (Q:4, KV:16)\n")
        .insert("iperm",
                "1",
                "permute input\n"
                "if true, will be b*h*s*d, else b*s*h*d")
        .insert("operm", "1", "permute output")
        .insert("prec",
                "fp8bf16",
                "Primary: fp8bf16, i8fp8bf16, i4fp8bf16. Also bf16 (keep): pipeline validation "
                "with qscale=n (no quant); not the quantized Sage product path.")
        .insert("mask",
                "0",
                "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
                "'t', top-left causal mask, 'b', bottom-r causal mask\n"
                "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
                "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
                "'xt:window_size', xformer style masking from top-left, window_size negative is "
                "causal, positive is swa\n"
                "'xb:window_size', xformer style masking from bottom-r, window_size negative is "
                "causal, positive is swa\n"
                "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
                "now)")
        .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
        .insert("kname", "0", "if set to 1 will print kernel name")
        // NOTE(review): the second "tf" below duplicates the "tf or 2" key above;
        // presumably the alias for init=3 was meant to be different — verify.
        .insert("init",
                "uf",
                "init method:\n ui or 0 - uniform random int\n ni - normalized random int"
                "\n uf or 1 - uniform random float\n nf - normalized random float"
                "\n tf or 2 - trig float"
                "\n tf or 3 - uniform random float, min max is the max of the type\n")
        .insert("seed",
                "11939",
                "random seed used for initializing input tensors. 0 for "
                "non-deterministic seed")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert("warmup", "5", "number of iterations before benchmark the kernel")
        .insert("repeat", "20", "number of iterations to benchmark the kernel")
        .insert("json", "0", "0: No Json, 1: Dump Results in Json format")
        .insert("jsonfile", "sageattn_fwd.json", "json file name to dump results")
        .insert("q_eff_lens",
                "",
                "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
                "Comma-separated list of length 'b'. If empty, no override.")
        .insert("kv_eff_lens",
                "",
                "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
                "Comma-separated list of length 'b'. If empty, no override.");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
|
||||
|
||||
// Extract all runtime parameters from the parsed arguments and invoke the
// typed runner for the selected precision configuration.
template <typename DataTypeConfig>
auto run(const ck_tile::ArgParser& arg_parser)
{
    int do_validation   = arg_parser.get_int("v");
    mode_enum mode      = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
    ck_tile::index_t batch   = arg_parser.get_int("b");
    ck_tile::index_t nhead   = arg_parser.get_int("h");
    ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
    // seqlens may be per-batch comma-separated lists (group mode)
    auto seqlen_qs           = arg_parser.get_int_vec("s");
    auto seqlen_ks           = arg_parser.get_int_vec("s_k");
    ck_tile::index_t hdim_q  = arg_parser.get_int("d");
    ck_tile::index_t hdim_v  = arg_parser.get_int("d_v");
    auto seqlen_kpads        = arg_parser.get_int_vec("s_kpad");
    auto seqlen_qpads        = arg_parser.get_int_vec("s_qpad");
    auto q_eff_lens_per_batch  = arg_parser.get_int_vec("q_eff_lens");
    auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens");
    bool i_perm              = arg_parser.get_bool("iperm");
    bool o_perm              = arg_parser.get_bool("operm");
    float scale_s            = arg_parser.get_float("scale_s");
    bool is_v_rowmajor       = arg_parser.get_str("vlayout") == "r";
    std::string qscale_str   = arg_parser.get_str("qscale");
    std::string mask_str     = arg_parser.get_str("mask");
    std::string init_method  = arg_parser.get_str("init");
    uint32_t seed            = arg_parser.get_uint32("seed");

    // kname=1 bumps the log level so kernel names are printed
    ck_tile::stream_config stream_config{nullptr,
                                         true,
                                         /* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0),
                                         arg_parser.get_int("warmup"),
                                         arg_parser.get_int("repeat"),
                                         arg_parser.get_str("timer") == std::string("gpu")};

    // optional JSON result dump path
    auto json = arg_parser.get_int("json") == 1
                    ? std::optional<std::string>{arg_parser.get_str("jsonfile")}
                    : std::nullopt;

    return sageattn_fwd_run<DataTypeConfig>(mode,
                                            batch,
                                            nhead,
                                            nhead_k,
                                            seqlen_qs,
                                            seqlen_ks,
                                            hdim_q,
                                            hdim_v,
                                            seqlen_qpads,
                                            seqlen_kpads,
                                            q_eff_lens_per_batch,
                                            kv_eff_lens_per_batch,
                                            i_perm,
                                            o_perm,
                                            scale_s,
                                            is_v_rowmajor,
                                            mask_str,
                                            qscale_str,
                                            init_method,
                                            seed,
                                            do_validation,
                                            stream_config,
                                            json);
}
|
||||
|
||||
// Entry point: parse arguments, dispatch on the requested precision config.
// Exit codes: 0 success, -1 usage/argument error, -2 run failure.
int main(int argc, char* argv[])
{
    try
    {
        auto [result, arg_parser] = create_args(argc, argv);
        if(!result)
            return -1;

        const std::string data_type = arg_parser.get_str("prec");
        if(data_type == "bf16")
        {
            return run<SageAttentionFwdBf16>(arg_parser) == fwd_result::success ? 0 : -2;
        }
        else if(data_type == "fp8bf16")
        {
            return run<SageAttentionFwdFp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
        }
        else if(data_type == "i8fp8bf16")
        {
            return run<SageAttentionFwdI8Fp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
        }
        else if(data_type == "i4fp8bf16")
        {
            return run<SageAttentionFwdI4Fp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
        }
        std::cerr << "Unsupported precision: " << data_type << std::endl;
        return -1;
    }
    catch(const std::invalid_argument& e)
    {
        // bad numeric conversion etc. — treat as a usage error
        std::cerr << "Invalid argument: " << e.what() << std::endl;
        return -1;
    }
    catch(const std::exception& e)
    {
        std::cerr << "Error: " << e.what() << std::endl;
        return -2;
    }
}
|
||||
173
example/ck_tile/49_sageattention/generate.py
Normal file
173
example/ck_tile/49_sageattention/generate.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import pkgutil
|
||||
from typing import List, Optional
|
||||
|
||||
import codegen.ops
|
||||
from codegen.cmake_config import GEN_DIR
|
||||
|
||||
|
||||
class HandlerId(IntEnum):
    """Index into a handlers[...] entry, which is (list_blobs, write_blobs)."""

    LIST_BLOBS = 0  # handler that lists the files codegen would produce
    WRITE_BLOBS = 1  # handler that writes the generated sources
|
||||
|
||||
|
||||
# inspect all modules under 'codegen.ops' and register API handlers
import importlib

ops = []
_op_names = []
for _importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
    full_module_name = "%s.%s" % (codegen.ops.__name__, module_name)
    # importlib.import_module replaces the deprecated loader.load_module()
    # (removed in Python 3.12) and registers the module under its full name.
    ops.append(importlib.import_module(full_module_name))
    _op_names.append(module_name)
# Strip "sageattn_" so module sageattn_fwd registers as CLI key "fwd".
unwanted_prefix = "sageattn_"
handlers = {}
for _name, _op in zip(_op_names, ops):
    _key = _name[len(unwanted_prefix):] if _name.startswith(unwanted_prefix) else _name
    handlers[_key] = (_op.list_blobs, _op.write_blobs)
assert 0 < len(handlers)
|
||||
|
||||
|
||||
def write_blobs(
    targets: List[str],
    output_dir: Optional[str],
    api_list: List[str],
    filters_list: List[str],
    optdim_list: List[int],
    receipt,
    mask_impl,
) -> None:
    """Invoke the WRITE_BLOBS handler of every requested API into output_dir."""
    if output_dir is None:
        destination = Path(__file__).parent
    else:
        destination = Path(output_dir) / GEN_DIR

    destination.mkdir(parents=True, exist_ok=True)

    for api, kernel_filter in zip(api_list, filters_list):
        writer = handlers[api][HandlerId.WRITE_BLOBS]
        writer(targets, destination, kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
|
||||
# list all the files that will be generated
|
||||
def list_blobs(
    targets: List[str],
    output_file: Optional[str],
    api_list: List[str],
    filters_list: List[str],
    optdim_list: List[int],
    receipt,
    mask_impl,
) -> None:
    """Write the list of files each requested API would generate into output_file."""
    assert output_file is not None
    file_path = Path(output_file)

    # truncate the file so each handler can append its entries
    file_path.write_text("")

    for api, kernel_filter in zip(api_list, filters_list):
        lister = handlers[api][HandlerId.LIST_BLOBS]
        lister(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: either list the files codegen would produce (-l) or
    # actually write the generated sources (-o / default).
    parser = argparse.ArgumentParser(
        prog="generate",
        description="Generate SageAttention CK_tile kernel/API blobs.",
    )
    parser.add_argument(
        "--targets",
        default="gfx9,gfx950",
        required=False,
        help="list of GPU targets, separated by comma.",
    )
    parser.add_argument(
        "-a",
        "--api",
        default="fwd",
        required=False,
        help="Codegen API key(s), comma-separated (e.g. fwd -> module codegen.ops.sageattn_fwd).",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="write all the blobs into a directory",
    )
    parser.add_argument(
        "-l", "--list_blobs", required=False, help="list all the kernels to a file"
    )
    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        default="",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module",
    )

    parser.add_argument(
        "-m",
        "--mask",
        default="simplified",
        required=False,
        help="mask implementation, simplified/generic",
    )

    parser.add_argument(
        "-r",
        "--receipt",
        default=0,
        required=False,
        help="Codegen receipt index. SageAttention forward currently uses receipt 0 only; "
        "the value is passed through to ops (see get_product in sageattn_fwd.py).",
    )

    parser.add_argument(
        "--optdim",
        default="-1",
        required=False,
        help="only optimize the hdim in the list. separated by comma. -1 is the default choice. "
        "e.g. --optdim=32,64,128,256",
    )

    args = parser.parse_args()
    targets = args.targets.split(",")
    api_list = args.api.split(",")
    filter_list = args.filter.split(",")
    # pad the filter list with "" so every API has a (possibly empty) filter
    filter_list.extend([""] * (len(api_list) - len(filter_list)))
    optdim_list = [int(hdim) for hdim in args.optdim.split(",")]

    if args.list_blobs is not None:
        list_blobs(
            targets,
            args.list_blobs,
            api_list,
            filter_list,
            optdim_list,
            int(args.receipt),
            mask_impl=args.mask,
        )
    else:
        write_blobs(
            targets,
            args.output_dir,
            api_list,
            filter_list,
            optdim_list,
            int(args.receipt),
            mask_impl=args.mask,
        )
|
||||
169
example/ck_tile/49_sageattention/mask.hpp
Normal file
169
example/ck_tile/49_sageattention/mask.hpp
Normal file
@@ -0,0 +1,169 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
|
||||
// keep this in sync with ck_tile::GenericAttentionMaskEnum
|
||||
enum class mask_enum
|
||||
{
|
||||
no_mask = 0,
|
||||
mask_top_left,
|
||||
mask_bottom_right,
|
||||
window_generic,
|
||||
};
|
||||
|
||||
struct mask_info
|
||||
{
|
||||
mask_enum type;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t y, x;
|
||||
ck_tile::index_t left, right; // FA style SWA left/right
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == mask_enum::no_mask)
|
||||
os << "n";
|
||||
else if(type == mask_enum::mask_top_left)
|
||||
os << "t(" << left << ":" << right << ")";
|
||||
else if(type == mask_enum::mask_bottom_right)
|
||||
os << "b(" << left << ":" << right << ")";
|
||||
else
|
||||
{
|
||||
os << "g(" << y << ":" << x << ")";
|
||||
}
|
||||
}
|
||||
|
||||
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
|
||||
{
|
||||
ck_tile::index_t x_total = seqlen_k;
|
||||
ck_tile::index_t y_total = seqlen_q;
|
||||
mask_info tmp;
|
||||
tmp.seqlen_q = seqlen_q;
|
||||
tmp.seqlen_k = seqlen_k;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string t = str.substr(0, found_0);
|
||||
std::string v = str.substr(found_0 + 1);
|
||||
if(t == "xt" || t == "xb")
|
||||
{
|
||||
// xformer style sliding window attn from top-left
|
||||
ck_tile::index_t window_size = std::stoi(v);
|
||||
ck_tile::index_t left_size = -1;
|
||||
ck_tile::index_t right_size = 0;
|
||||
if(window_size > 0)
|
||||
{
|
||||
left_size = window_size / 2;
|
||||
right_size = window_size - 1 - left_size;
|
||||
}
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, 0, y_total, x_total, t == "xt");
|
||||
|
||||
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = left_size;
|
||||
tmp.right = right_size;
|
||||
}
|
||||
else if(t == "t" || t == "b" || t == "g")
|
||||
{
|
||||
auto found_1 = v.find(",");
|
||||
if(found_1 == std::string::npos)
|
||||
{
|
||||
throw std::invalid_argument("invalid mask value: " + str);
|
||||
}
|
||||
tmp.type = mask_enum::window_generic;
|
||||
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
|
||||
ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||
// TODO: some validation
|
||||
if(t == "t")
|
||||
{
|
||||
tmp.type = mask_enum::mask_top_left;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, 0, y_total, x_total, true);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
}
|
||||
else if(t == "b")
|
||||
{
|
||||
tmp.type = mask_enum::mask_bottom_right;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, 0, y_total, x_total, false);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
}
|
||||
else if(t == "g")
|
||||
{
|
||||
tmp.type = mask_enum::window_generic;
|
||||
tmp.y = v0;
|
||||
tmp.x = v1;
|
||||
tmp.left = v0; // TODO: don't use this?
|
||||
tmp.right = v1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid mask value: " + str);
|
||||
}
|
||||
}
|
||||
else if(str == "0")
|
||||
{
|
||||
tmp.type = mask_enum::no_mask;
|
||||
}
|
||||
else if(str == "1" || str == "t")
|
||||
{
|
||||
tmp.type = mask_enum::mask_top_left;
|
||||
tmp.y = seqlen_q;
|
||||
tmp.x = 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
}
|
||||
else if(str == "2" || str == "b")
|
||||
{
|
||||
tmp.type = mask_enum::mask_bottom_right;
|
||||
tmp.y = seqlen_q;
|
||||
tmp.x = seqlen_k - seqlen_q + 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid mask value: " + str);
|
||||
}
|
||||
return tmp;
|
||||
}
|
||||
|
||||
std::size_t get_unmaskarea() const
|
||||
{
|
||||
if(type == mask_enum::no_mask)
|
||||
return static_cast<std::size_t>(seqlen_q) * seqlen_k;
|
||||
std::size_t area = 0;
|
||||
for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
|
||||
{
|
||||
ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
|
||||
ck_tile::index_t x_end = std::min(i_y + x, seqlen_k);
|
||||
if(x_end > x_start)
|
||||
{
|
||||
area += (x_end - x_start);
|
||||
}
|
||||
}
|
||||
return area;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
|
||||
{
|
||||
mi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
74
example/ck_tile/49_sageattention/quant.hpp
Normal file
74
example/ck_tile/49_sageattention/quant.hpp
Normal file
@@ -0,0 +1,74 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
|
||||
|
||||
// keep sync with BlockSageAttentionQuantScaleEnum
|
||||
enum class quant_scale_enum
|
||||
{
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale = 2,
|
||||
perwarp = 3,
|
||||
perthread = 4,
|
||||
};
|
||||
|
||||
struct quant_scale_info
|
||||
{
|
||||
quant_scale_enum type;
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == quant_scale_enum::no_scale)
|
||||
os << "n";
|
||||
else if(type == quant_scale_enum::pertensor)
|
||||
os << "pt";
|
||||
else if(type == quant_scale_enum::blockscale)
|
||||
os << "bs";
|
||||
else if(type == quant_scale_enum::perwarp)
|
||||
os << "pw";
|
||||
else if(type == quant_scale_enum::perthread)
|
||||
os << "pth";
|
||||
}
|
||||
|
||||
static quant_scale_info decode(std::string str)
|
||||
{
|
||||
quant_scale_info info{quant_scale_enum::no_scale};
|
||||
if(str == "n" || str == "0")
|
||||
{
|
||||
info.type = quant_scale_enum::no_scale;
|
||||
}
|
||||
else if(str == "pt" || str == "1")
|
||||
{
|
||||
info.type = quant_scale_enum::pertensor;
|
||||
}
|
||||
else if(str == "bs" || str == "2")
|
||||
{
|
||||
info.type = quant_scale_enum::blockscale;
|
||||
}
|
||||
else if(str == "pw" || str == "3")
|
||||
{
|
||||
info.type = quant_scale_enum::perwarp;
|
||||
}
|
||||
else if(str == "pth" || str == "4")
|
||||
{
|
||||
info.type = quant_scale_enum::perthread;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid quant scale value: " + str);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const quant_scale_info& qsi)
|
||||
{
|
||||
qsi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
384
example/ck_tile/49_sageattention/sageattn_fwd.hpp
Normal file
384
example/ck_tile/49_sageattention/sageattn_fwd.hpp
Normal file
@@ -0,0 +1,384 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/sageattn.hpp"
|
||||
|
||||
#include "mask.hpp"
|
||||
#include "quant.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
// SageAttention data type configs (must match codegen FWD_DTYPE_MAP + SageAttentionFwdTypeConfig)
|
||||
struct SageAttentionFwdFp16
|
||||
{
|
||||
};
|
||||
|
||||
struct SageAttentionFwdBf16
|
||||
{
|
||||
};
|
||||
|
||||
struct SageAttentionFwdFp8Bf16
|
||||
{
|
||||
};
|
||||
|
||||
struct SageAttentionFwdI8Fp8Bf16
|
||||
{
|
||||
};
|
||||
|
||||
struct SageAttentionFwdI4Fp8Bf16
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct SageAttentionFwdTypeConfig;
|
||||
|
||||
// fp16/bf16 are not Sage product dtypes; bf16 is intentionally kept in tile_example_sageattn_fwd
|
||||
// for pipeline validation with qscale=n (no quant).
|
||||
template <>
|
||||
struct SageAttentionFwdTypeConfig<SageAttentionFwdFp16>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using ScaleType = float; // scale type for quantized inputs
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct SageAttentionFwdTypeConfig<SageAttentionFwdBf16>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using ScaleType = float; // scale type for quantized inputs
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct SageAttentionFwdTypeConfig<SageAttentionFwdFp8Bf16>
|
||||
{
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using ScaleType = float; // scale type for quantized inputs
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct SageAttentionFwdTypeConfig<SageAttentionFwdI8Fp8Bf16>
|
||||
{
|
||||
using QDataType = ck_tile::int8_t;
|
||||
using KDataType = ck_tile::int8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using ScaleType = float; // scale type for Q and K
|
||||
using SaccDataType = float; // Keep as float for softmax computation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // P in FP8 for 2nd gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct SageAttentionFwdTypeConfig<SageAttentionFwdI4Fp8Bf16>
|
||||
{
|
||||
using QDataType = ck_tile::pk_int4_t;
|
||||
using KDataType = ck_tile::pk_int4_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using ScaleType = float;
|
||||
using SaccDataType = float;
|
||||
using SMPLComputeDataType = float;
|
||||
using PDataType = ck_tile::fp8_t;
|
||||
using OaccDataType = float;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
struct SageAttnMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
|
||||
// runtime args, some will passed to karg, some will used to compute grids/blocks
|
||||
struct sageattn_fwd_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* q_descale_ptr;
|
||||
const void* k_descale_ptr;
|
||||
const void* v_descale_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
// Usage notes for sequence length pointer parameters:
|
||||
//
|
||||
// [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles
|
||||
// MQA/GQA..."]
|
||||
//
|
||||
// With padding:
|
||||
// Group mode:
|
||||
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
|
||||
// lengths. [array size: batch + 1]
|
||||
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
|
||||
// sequence. [array size: batch]
|
||||
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
|
||||
// sequence lengths. [array size: batch + 1]
|
||||
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
|
||||
// exclusive. Use one set, not both.
|
||||
//
|
||||
// Batch mode:
|
||||
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
|
||||
// sequence lengths. [array size: batch + 1]
|
||||
// - seqstart_* and seqlen_* pointers must be nullptr.
|
||||
//
|
||||
// Without padding:
|
||||
// (Note: Physical length equals logical length)
|
||||
//
|
||||
// Group mode:
|
||||
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
|
||||
// size: batch + 1]
|
||||
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
|
||||
//
|
||||
// Batch mode:
|
||||
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
|
||||
//
|
||||
const void* seqstart_q_ptr =
|
||||
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
|
||||
const void* seqstart_k_ptr =
|
||||
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
|
||||
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
|
||||
// [batch]. (Used in Group mode with padding)
|
||||
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
|
||||
// [batch]. (Used in Group mode with padding)
|
||||
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
float scale_s;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
ck_tile::index_t min_seqlen_q;
|
||||
|
||||
// BLOCKSCALE parameters
|
||||
ck_tile::index_t nhead_stride_q_descale = 0;
|
||||
ck_tile::index_t nhead_stride_k_descale = 0;
|
||||
ck_tile::index_t nhead_stride_v_descale = 0;
|
||||
ck_tile::index_t batch_stride_q_descale = 0;
|
||||
ck_tile::index_t batch_stride_k_descale = 0;
|
||||
ck_tile::index_t batch_stride_v_descale = 0;
|
||||
ck_tile::index_t block_scale_size_q = 0;
|
||||
ck_tile::index_t block_scale_size_k = 0;
|
||||
const void* block_scale_seqstart_q_ptr = nullptr;
|
||||
const void* block_scale_seqstart_k_ptr = nullptr;
|
||||
};
|
||||
|
||||
template <typename SageAttnKernel>
|
||||
auto sageattn_fwd_create_kargs_and_grids(sageattn_fwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(SageAttnKernel::kIsGroupMode)
|
||||
{
|
||||
return SageAttnKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.q_descale_ptr,
|
||||
args.k_descale_ptr,
|
||||
args.v_descale_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale,
|
||||
args.batch_stride_v_descale,
|
||||
args.block_scale_size_q,
|
||||
args.block_scale_size_k,
|
||||
args.block_scale_seqstart_q_ptr,
|
||||
args.block_scale_seqstart_k_ptr,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return SageAttnKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.q_descale_ptr,
|
||||
args.k_descale_ptr,
|
||||
args.v_descale_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_o,
|
||||
args.batch_stride_q_descale,
|
||||
args.batch_stride_k_descale,
|
||||
args.batch_stride_v_descale,
|
||||
args.block_scale_size_q,
|
||||
args.block_scale_size_k,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(SageAttnKernel::kIsGroupMode)
|
||||
{
|
||||
dim3 grids = SageAttnKernel::GridSize(
|
||||
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
else
|
||||
{
|
||||
dim3 grids = SageAttnKernel::GridSize(
|
||||
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
}
|
||||
|
||||
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockSageAttnPipelineEnum SageAttnPipelineEnum_,
|
||||
typename AttnMask_,
|
||||
ck_tile::BlockSageAttentionQuantScaleEnum QScaleEnum_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kSkipMinSeqlenQ_ = false>
|
||||
struct sageattn_fwd_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN0 = kN0_;
|
||||
static constexpr ck_tile::index_t kK0 = kK0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
static constexpr ck_tile::index_t kK1 = kK1_;
|
||||
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr auto SageAttnPipelineEnum = SageAttnPipelineEnum_;
|
||||
using AttnMask = ck_tile::remove_cvref_t<AttnMask_>;
|
||||
static constexpr auto QScaleEnum = QScaleEnum_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
float sageattn_fwd_(const ck_tile::stream_config&, sageattn_fwd_args);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct sageattn_fwd_traits
|
||||
{
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
bool is_v_rowmajor;
|
||||
mask_enum mask_type;
|
||||
quant_scale_enum qscale_type;
|
||||
bool skip_min_seqlen_q = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float sageattn_fwd(sageattn_fwd_traits, sageattn_fwd_args, const ck_tile::stream_config&);
|
||||
1154
example/ck_tile/49_sageattention/sageattn_fwd_runner.hpp
Normal file
1154
example/ck_tile/49_sageattention/sageattn_fwd_runner.hpp
Normal file
File diff suppressed because it is too large
Load Diff
162
example/ck_tile/49_sageattention/script/smoke_test_sageattn_fwd.sh
Executable file
162
example/ck_tile/49_sageattention/script/smoke_test_sageattn_fwd.sh
Executable file
@@ -0,0 +1,162 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
# SageAttention forward smoke tests - structure mirrors
|
||||
# example/ck_tile/01_fmha/script/smoke_test_fwd.sh
|
||||
#
|
||||
# Run from the ComposableKernel *build* directory (after ninja), same as FMHA:
|
||||
# cd build && ninja tile_example_sageattn_fwd
|
||||
# bash ../example/ck_tile/49_sageattention/script/smoke_test_sageattn_fwd.sh
|
||||
#
|
||||
# Optional: VERBOSE=1 enables bash -x. CURR_FAILS_FILE / KNOWN_FAILS_FILE override fail logs.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
||||
EXE_NAME=tile_example_sageattn_fwd
|
||||
EXE="$(find . -name "$EXE_NAME" -type f 2>/dev/null | head -n 1)"
|
||||
KNAME=1
|
||||
GPU_arch=${GPU_arch:-}
|
||||
if [ -z "$GPU_arch" ]; then
|
||||
GPU_arch=$(rocminfo 2>/dev/null | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}' || echo "unknown")
|
||||
fi
|
||||
|
||||
export CK_WARMUP=0
|
||||
export CK_REPEAT=1
|
||||
|
||||
CURR_FAILS_FILE=${CURR_FAILS_FILE:-"sageattn_fwd_fails_${GPU_arch}.txt"}
|
||||
rm -f "$CURR_FAILS_FILE"
|
||||
touch "$CURR_FAILS_FILE"
|
||||
KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/sageattn_fwd_known_fails_${GPU_arch}.txt"}
|
||||
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
|
||||
if [ -z "${EXE:-}" ] || [ ! -x "$EXE" ]; then
|
||||
echo "ERROR: $EXE_NAME not found under cwd ($(pwd)). Build with: ninja $EXE_NAME" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
run_exe() {
|
||||
set +e
|
||||
$EXE "$@"
|
||||
local ret=$?
|
||||
if [ $ret -ne 0 ]; then
|
||||
echo "$EXE_NAME $*" >>"$CURR_FAILS_FILE"
|
||||
fi
|
||||
set -e
|
||||
}
|
||||
|
||||
# Core FP8xBF16 cases aligned with FMHA smoke_test_fwd.sh (lines 80-87): batch/group shapes,
|
||||
# masks, GQA, short seqlen, k-only pad. Sweeps blockscale (2) vs per-warp (3) and layouts.
|
||||
run_fp8bf16_smoke() {
|
||||
local qscale
|
||||
local perm
|
||||
for qscale in 2 3; do
|
||||
for perm in 0 1; do
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=2 -h_k=1 -d=128 -d_v=128 -s=55 -s_k=256 \
|
||||
-mask=1
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=3 -d=128 -s=100 -s_k=51 -mask=0
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=1 -d=128 -d_v=128 -s=99 -s_k=256 \
|
||||
-mask=1
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1024 -s_k=256 \
|
||||
-mask=2
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=1 -d=128 -d_v=128 -s=3 -s_k=99 -mask=2
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=3 -h=2 -h_k=1 -d=128 -s=200 -s_k=520 \
|
||||
-mask=t:128,30
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=2 -h=1 -d=128 -s=99 -s_k=32 -mask=b:4,35
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=33 -s_k=0 -mask=2
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=$qscale -iperm=$perm -operm=$perm -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 \
|
||||
-s_kpad=32 -mask=2
|
||||
done
|
||||
done
|
||||
}
|
||||
|
||||
# Extra FP8: explicit causal string, xformer window, per-tensor / per-thread quant, V col-major.
|
||||
run_fp8bf16_extras() {
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=4 -h=8 -d=128 -s=1024 -s_k=1024 -mask=t:-1,0
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=1 -operm=1 -vlayout=c -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=2 -h=4 -d=128 -s=256 -s_k=256 -mask=t
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=256 -s_k=256 -mask=xt:64
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=128 -s_k=128 -mask=0
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=4 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=64 -s_k=64 -mask=0
|
||||
}
|
||||
|
||||
# Group mode + physical padding (same intent as FMHA run_padding_smoke_tests, Sage-only flags).
|
||||
run_group_and_padding_smoke() {
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=1 -b=3 -h=2 -h_k=1 -d=128 -s=50,60,40 -s_k=128,256,192 -mask=1
|
||||
# group + PERTHREAD: block_scale_seqstart_* must be allocated (same as bs/pw)
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=4 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=1 -b=3 -h=2 -h_k=1 -d=128 -s=50,60,40 -s_k=128,256,192 -mask=1
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=1 -b=4 -h=8 -h_k=8 -d=128 -s=1024,768,512,256 -s_k=1024,768,512,256 \
|
||||
-mask=0 -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
|
||||
run_exe -prec=fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=4 -h=8 -d=128 -s=1024 -s_k=1024 -mask=0 \
|
||||
-q_eff_lens=960,512,384,256 -kv_eff_lens=960,512,384,256
|
||||
}
|
||||
|
||||
# BF16 (no quant): pipeline sanity only; not a shipped Sage mode (see example --help prec).
|
||||
run_bf16_pipeline_smoke() {
|
||||
run_exe -prec=bf16 -init=1 -qscale=n -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=2 -h=2 -d=128 -s=128 -s_k=128 -mask=1
|
||||
run_exe -prec=bf16 -init=1 -qscale=n -iperm=1 -operm=1 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=1 -h=4 -h_k=1 -d=128 -s=256 -s_k=128 -mask=t:32,32
|
||||
}
|
||||
|
||||
# int8 / int4 x fp8xbf16 (hdim divisible by 8 for int4)
|
||||
run_int_quant_smoke() {
|
||||
run_exe -prec=i8fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=2 -h=2 -d=128 -s=128 -s_k=128 -mask=1
|
||||
run_exe -prec=i4fp8bf16 -init=3 -qscale=3 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME \
|
||||
$COMMON_ARGS -mode=0 -b=1 -h=2 -d=128 -s=128 -s_k=128 -mask=t
|
||||
}
|
||||
|
||||
if [ "${VERBOSE:-0}" = 1 ]; then
|
||||
set -x
|
||||
fi
|
||||
|
||||
run_fp8bf16_smoke
|
||||
run_fp8bf16_extras
|
||||
run_group_and_padding_smoke
|
||||
run_bf16_pipeline_smoke
|
||||
run_int_quant_smoke
|
||||
|
||||
set +x
|
||||
|
||||
new_fails_count=0
|
||||
known_fails_count=0
|
||||
if [ -f "$KNOWN_FAILS_FILE" ]; then
|
||||
echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):"
|
||||
while IFS= read -r line; do
|
||||
if grep -Fxq "$line" "$KNOWN_FAILS_FILE"; then
|
||||
echo "Known fail: $line"
|
||||
known_fails_count=$((known_fails_count + 1))
|
||||
else
|
||||
echo "New fail: $line"
|
||||
new_fails_count=$((new_fails_count + 1))
|
||||
fi
|
||||
done <"$CURR_FAILS_FILE"
|
||||
else
|
||||
new_fails_count=$(wc -l <"$CURR_FAILS_FILE")
|
||||
echo "No known fails file, all fails ($new_fails_count) are new:"
|
||||
if [ "$new_fails_count" -gt 0 ]; then
|
||||
cat "$CURR_FAILS_FILE"
|
||||
fi
|
||||
fi
|
||||
echo "New fails count: $new_fails_count; Known fails count: $known_fails_count"
|
||||
exit $((new_fails_count != 0))
|
||||
254
example/ck_tile/49_sageattention/utils.hpp
Normal file
254
example/ck_tile/49_sageattention/utils.hpp
Normal file
@@ -0,0 +1,254 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/core/container/span.hpp"
|
||||
|
||||
enum class mode_enum
|
||||
{
|
||||
batch = 0,
|
||||
group
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& stream, mode_enum mode)
|
||||
{
|
||||
return stream << (mode == mode_enum::batch ? "batch" : "group");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
using size_type = typename std::vector<T>::size_type;
|
||||
|
||||
os << "[";
|
||||
for(size_type idx = 0; idx < v.size(); ++idx)
|
||||
{
|
||||
if(0 < idx)
|
||||
{
|
||||
os << ", ";
|
||||
}
|
||||
os << v[idx];
|
||||
}
|
||||
return os << "]";
|
||||
}
|
||||
|
||||
inline std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
|
||||
{
|
||||
std::vector<int32_t> seqstarts = {0};
|
||||
for(int32_t seqlen : seqlens)
|
||||
{
|
||||
seqstarts.push_back(seqstarts.back() + seqlen);
|
||||
}
|
||||
assert(seqstarts.size() == seqlens.size() + 1);
|
||||
return seqstarts;
|
||||
}
|
||||
|
||||
template <typename RandomEngine>
|
||||
std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
unsigned count,
|
||||
int32_t seqlen_avg,
|
||||
int32_t seqlen_min, // if not negative, clamp min
|
||||
int32_t seqlen_max, // if not negative, clamp max
|
||||
RandomEngine& random_engine)
|
||||
{
|
||||
assert(0 < count);
|
||||
|
||||
seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
|
||||
seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
|
||||
assert(seqlen_min <= seqlen_max);
|
||||
|
||||
std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));
|
||||
|
||||
if(mode == mode_enum::group && 1 < count)
|
||||
{
|
||||
using size_type = std::vector<int32_t>::size_type;
|
||||
|
||||
std::uniform_int_distribution<size_type> idx_dist(0, count - 1);
|
||||
auto next_idx = std::bind(idx_dist, std::ref(random_engine));
|
||||
|
||||
std::uniform_int_distribution<size_type> step_dist(1, count - 1);
|
||||
auto next_step = std::bind(step_dist, std::ref(random_engine));
|
||||
|
||||
for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
|
||||
{
|
||||
const size_type to_decrease = next_idx();
|
||||
// make sure each elements of seqlens is in range [seqlen_min, seqlen_max]
|
||||
if(seqlens[to_decrease] == seqlen_min)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
const size_type to_increase = (to_decrease + next_step()) % count;
|
||||
|
||||
if(seqlens[to_increase] >= seqlen_max)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
--seqlens[to_decrease];
|
||||
++seqlens[to_increase];
|
||||
}
|
||||
}
|
||||
|
||||
return seqlens;
|
||||
}
|
||||
|
||||
// Return one integer drawn uniformly from the closed interval [low, high].
template <typename Int = int, typename RandomEngine>
auto randint(Int low,
             Int high,
             RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>, Int>
{
    return std::uniform_int_distribution<Int>{low, high}(random_engine);
}
|
||||
|
||||
// Fill the range [first, last) with integers drawn uniformly from [low, high].
template <typename Int, typename ForwardIterator, typename RandomEngine>
auto randints(ForwardIterator first,
              ForwardIterator last,
              Int low,
              Int high,
              RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>>
{
    std::uniform_int_distribution<Int> dist(low, high);

    for(; first != last; ++first)
    {
        *first = dist(random_engine);
    }
}
|
||||
|
||||
/*
 * Generate missing values in *_val randomly when the number of values is smaller than batch.
 * Example (assume batch=3):
 *   q_val=1,2,3 k_val=4,5,6 -> OK
 *   q_val=1,2,3             -> OK, k same as q
 *   q_val=1,2               -> OK, q will randomize the remaining 1 element, k same as q
 *   q_val=1,2 k_val=4,5     -> OK, q/k will randomize the remaining 1 element
 *   q_val=1,2,3,4           -> OK, but the extra element is ignored
 *
 *   q_val=1,2 k_val=4,5,6   -> not OK, k must have the same splits as q
 *   q_val=1,2 k_val=4       -> not OK, k must have the same splits as q
 */
// Returns the 4-tuple (seqlen_q, seqlen_k, seqlen_q_pad, seqlen_k_pad), each of
// size 'batch'. A value of -1 in a pad vector means "no padding requested".
// Throws std::runtime_error when a resolved seqlen_k is below seqlen_k_min.
// NOTE(review): assumes q_val (and, in batch mode, k_val) has at least one
// element — q_val[0]/k_val[0] are read unchecked; confirm callers guarantee this.
template <typename RandomEngine>
std::tuple<std::vector<ck_tile::index_t>,
           std::vector<ck_tile::index_t>,
           std::vector<ck_tile::index_t>,
           std::vector<ck_tile::index_t>>
generate_missing_seqlens(mode_enum mode,
                         ck_tile::index_t batch,
                         const std::vector<ck_tile::index_t>& q_val,
                         const std::vector<ck_tile::index_t>& k_val,
                         const std::vector<ck_tile::index_t>& q_pad_val,
                         const std::vector<ck_tile::index_t>& k_pad_val,
                         ck_tile::index_t seqlen_k_min,
                         bool need_append_kvcache,
                         RandomEngine& random_engine)
{
    if(mode == mode_enum::batch)
    {
        // Batch mode: every sequence shares the first q/k value; a negative k
        // means "same as q".
        ck_tile::index_t q = q_val[0];
        ck_tile::index_t k = k_val[0];

        auto s_q = std::vector<ck_tile::index_t>(batch, q);
        auto s_k = [&] {
            const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
            std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);

            if(1 < batch && need_append_kvcache)
            {
                // to keep the original s_k value, we always use seqlen_k_max in first batch;
                // the remaining batches get a random length in [seqlen_k_min, seqlen_k_max]
                randints(std::next(seqlen_ks.begin()),
                         seqlen_ks.end(),
                         seqlen_k_min,
                         seqlen_k_max,
                         random_engine);
                return seqlen_ks;
            }

            return seqlen_ks;
        }();
        auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
        auto s_qpad = std::vector<ck_tile::index_t>(batch, -1);
        // s_k should be greater than or equal to seqlen_k_min if provided
        if(s_k.back() < seqlen_k_min)
        {
            std::ostringstream msg;
            msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
                << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
            throw std::runtime_error(msg.str());
        }

        return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
    }
    else
    {
        // Group mode: take explicit values first (k/q_pad/k_pad indices are
        // clamped to their vectors' last element), then randomize the rest.
        std::vector<ck_tile::index_t> s_q;
        std::vector<ck_tile::index_t> s_k;
        std::vector<ck_tile::index_t> s_kpad;
        std::vector<ck_tile::index_t> s_qpad;
        ck_tile::index_t idx = 0;
        for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
        {
            ck_tile::index_t q = q_val[idx];
            ck_tile::index_t k =
                k_val[std::min(idx, static_cast<ck_tile::index_t>(k_val.size()) - 1)];
            ck_tile::index_t kp =
                k_pad_val.empty()
                    ? -1
                    : k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];

            ck_tile::index_t qp =
                q_pad_val.empty()
                    ? -1
                    : q_pad_val[std::min(idx, static_cast<ck_tile::index_t>(q_pad_val.size()) - 1)];

            s_q.push_back(q);
            s_k.push_back(k < 0 ? q : k); // negative k means "same as q"
            s_kpad.push_back(kp);
            s_qpad.push_back(qp);

            // s_k should be greater than or equal to seqlen_k_min
            if(s_k.back() < seqlen_k_min)
            {
                std::ostringstream msg;
                msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
                    << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
                throw std::runtime_error(msg.str());
            }
        }
        if(idx < batch)
        {
            // Fill the remaining batches with random lengths centered on the
            // last explicit value. rem_k's upper bound is s_kpad.back(), which
            // may be -1 (generate_seqlens then treats it as unbounded).
            auto rem_q =
                generate_seqlens(mode, batch - idx, s_q.back(), 1, s_q.back(), random_engine);
            auto rem_k = generate_seqlens(
                mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), random_engine);

            s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
            s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
            // pad values for the generated batches repeat the last explicit pad
            s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
            s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back());
        }
        return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
    }
}
|
||||
|
||||
// Write a random permutation of {value, value+1, ...} into [first, last).
template <typename RandomAccessIterator, typename Int, typename RandomEngine>
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
                                                       RandomAccessIterator last,
                                                       Int value,
                                                       RandomEngine& random_engine)
{
    // fill with consecutive values starting at 'value', then shuffle in place
    Int next = value;
    for(auto it = first; it != last; ++it)
    {
        *it = next++;
    }
    std::shuffle(first, last, random_engine);
}
|
||||
@@ -31,6 +31,7 @@ add_subdirectory(38_block_scale_gemm)
|
||||
# Tile example subdirectories, ordered by example number.
add_subdirectory(40_streamk_gemm)
add_subdirectory(41_batched_contraction)
add_subdirectory(42_mx_gemm)
add_subdirectory(49_sageattention)
add_subdirectory(50_sparse_attn)
add_subdirectory(51_tile_distr_enc_reg_map)
|
||||
if(BUILD_CK_TILE_CSHUFFLE_LDS_BENCHMARKS)
|
||||
|
||||
@@ -530,4 +530,10 @@ using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
// int8 warp GEMM built on the 32x32x16 i8 MFMA, iterating K twice (effective K=32)
// with a swizzled B operand and a transposed C distribution.
// NOTE(review): operand/iteration semantics inferred from the identifier and the
// referenced attribute types — confirm against the WarpGemmAttribute implementations.
template <index_t swizzle_factor = 2>
using WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution =
    WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
        WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>,
        2,
        swizzle_factor>>;
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck_tile {

// Granularity of the quantization scale tensors applied to Q/K in SageAttention.
// This class is used for codegen pattern matching.
enum class BlockSageAttentionQuantScaleEnum
{
    NO_SCALE   = 0, // no descale tensor (non-quantized path)
    PERTENSOR  = 1, // a single scale for the whole tensor
    BLOCKSCALE = 2, // one scale per token block
    PERWARP    = 3, // one scale per warp-sized token group
    PERTHREAD  = 4, // one scale per thread-sized token group
};

// Maps each quant-scale enum value to the string tag used in generated kernel names.
template <BlockSageAttentionQuantScaleEnum>
struct BlockSageAttentionQuantScaleEnumToStr;

template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::NO_SCALE>
{
    static constexpr const char* name = "";
};
template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::PERTENSOR>
{
    static constexpr const char* name = "pertensor";
};
template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::BLOCKSCALE>
{
    static constexpr const char* name = "blockscale";
};
template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::PERWARP>
{
    static constexpr const char* name = "perwarp";
};
template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::PERTHREAD>
{
    static constexpr const char* name = "perthread";
};

} // namespace ck_tile
|
||||
1026
include/ck_tile/ops/sageattention/kernel/sageattn_fwd_kernel.hpp
Normal file
1026
include/ck_tile/ops/sageattention/kernel/sageattn_fwd_kernel.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {

// Selects which SageAttention forward pipeline implementation to instantiate.
// This class is used for codegen pattern matching.
enum class BlockSageAttnPipelineEnum
{
    QRKSVS = 0,   // Q in registers, K/S/V staged through LDS (synchronous copies)
    QRKSVS_ASYNC, // same dataflow using async global->LDS copies
};

// Maps each pipeline enum value to the short tag used in generated kernel names.
template <BlockSageAttnPipelineEnum>
struct BlockSageAttnPipelineEnumToStr;

template <>
struct BlockSageAttnPipelineEnumToStr<BlockSageAttnPipelineEnum::QRKSVS>
{
    static constexpr const char* name = "qr";
};
template <>
struct BlockSageAttnPipelineEnumToStr<BlockSageAttnPipelineEnum::QRKSVS_ASYNC>
{
    static constexpr const char* name = "qr_async";
};

} // namespace ck_tile
|
||||
@@ -0,0 +1,60 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
|
||||
|
||||
namespace ck_tile {

// Compile-time problem descriptor consumed by the SageAttention block pipelines.
// Bundles the data type of every intermediate tensor, the block tile shape, the
// masking/variant types, and the pipeline traits into a single template argument.
template <typename QDataType_,
          typename KDataType_,
          typename VDataType_,
          typename SaccDataType_,        // accumulator type of the S = Q*K^T gemm
          typename SMPLComputeDataType_, // compute type for the softmax statistics (m/l)
          typename PDataType_,           // type of P fed into the P*V gemm
          typename OaccDataType_,        // accumulator type of the O = P*V gemm
          typename ODataType_,           // output tensor type
          typename BlockSageAttnShape_,  // block/warp tile shape description
          bool kIsGroupMode_,            // group (varlen) mode vs batch mode
          typename AttentionVariant_,
          typename AttnMask_,
          typename Traits_>
struct BlockSageAttnPipelineProblem
{
    using QDataType           = remove_cvref_t<QDataType_>;
    using KDataType           = remove_cvref_t<KDataType_>;
    using VDataType           = remove_cvref_t<VDataType_>;
    using SaccDataType        = remove_cvref_t<SaccDataType_>;
    using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
    using PDataType           = remove_cvref_t<PDataType_>;
    using OaccDataType        = remove_cvref_t<OaccDataType_>;
    using ODataType           = remove_cvref_t<ODataType_>;
    using BlockSageAttnShape  = remove_cvref_t<BlockSageAttnShape_>;
    using AttentionVariant    = remove_cvref_t<AttentionVariant_>;
    using AttnMask            = remove_cvref_t<AttnMask_>;
    using Traits              = remove_cvref_t<Traits_>;

    // warp counts for the QK (gemm0) and PV (gemm1) gemms, from the tile shape
    static constexpr index_t kNumGemm0Warps = BlockSageAttnShape::NumGemm0Warps;
    static constexpr index_t kNumGemm1Warps = BlockSageAttnShape::NumGemm1Warps;
    // workgroup size in threads
    static constexpr index_t kBlockSize = BlockSageAttnShape::NumWarps * get_warp_size();

    static constexpr bool kIsGroupMode = kIsGroupMode_;

    // attributes from traits
    static constexpr bool kPadSeqLenQ    = Traits::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK    = Traits::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ   = Traits::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV   = Traits::kPadHeadDimV;
    static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
    // quantization scale granularity (BlockSageAttentionQuantScaleEnum)
    static constexpr auto QScaleEnum      = Traits::QScaleEnum;
    static constexpr index_t kBlockPerCu  = Traits::kBlockPerCu;

    /// Must match host scale tensor layout (same values as TileSageAttnTraits for Sage kernels).
    static constexpr index_t kBlockScaleSizeQ = Traits::kBlockScaleSizeQ;
    static constexpr index_t kBlockScaleSizeK = Traits::kBlockScaleSizeK;
};

} // namespace ck_tile
|
||||
@@ -0,0 +1,861 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is qkv all located in LDS
|
||||
template <typename Problem_, typename Policy_ = BlockSageAttentionPipelineQRKSVSDefaultPolicy>
|
||||
struct BlockSageAttentionPipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using QGemmDataType = SageAttnQKGemmQDataType<Problem>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using KLdsDataType = SageAttnQKGemmKDataType<Problem>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
// fp16/bf16 example configs use P=V=fp16/bf16 (qscale=no). Quantized Sage paths use fp8 P/V;
|
||||
// FP8 softmax shift, v_descale, and PV-gemm LDS layout assume fp8_t for those cases.
|
||||
static_assert(std::is_same_v<PDataType, VDataType>,
|
||||
"SageAttention pipeline requires PDataType == VDataType for the PV gemm");
|
||||
static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
|
||||
std::is_same_v<PDataType, fp8_t>,
|
||||
"SageAttention pipeline requires PDataType = fp8_t");
|
||||
static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
|
||||
std::is_same_v<VDataType, fp8_t>,
|
||||
"SageAttention pipeline requires VDataType = fp8_t");
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
using AttnMask = remove_cvref_t<typename Problem::AttnMask>;
|
||||
|
||||
using BlockSageAttnShape = remove_cvref_t<typename Problem::BlockSageAttnShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockSageAttnShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockSageAttnShape::kM0;
|
||||
static constexpr index_t kN0 = BlockSageAttnShape::kN0;
|
||||
static constexpr index_t kK0 = BlockSageAttnShape::kK0;
|
||||
static constexpr index_t kN1 = BlockSageAttnShape::kN1;
|
||||
static constexpr index_t kK1 = BlockSageAttnShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockSageAttnShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockSageAttnShape::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
|
||||
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
|
||||
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
|
||||
|
||||
// FP8 softmax shift constants to map softmax output into representable FP8 range
|
||||
// OCP E4M3 FP8: max exponent = 8, max value ~240 (2^8 * 1.875)
|
||||
// Use shift=8.0 so exp2(s - m - 8) maps softmax to [0, 2^8] range
|
||||
// FNUZ E4M3 FP8: max exponent = 7, max value ~120 (2^7 * 1.875)
|
||||
// Use shift=7.0 so exp2(s - m - 7) maps softmax to [0, 2^7] range
|
||||
static constexpr float OCP_FP8_SHIFT = 8.0f;
|
||||
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kQKHeaddim <= 32)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 64)
|
||||
{
|
||||
return 3;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 128)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 256)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "qr";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
AttnMask mask,
|
||||
PositionEncoding /*position_encoding*/,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
[[maybe_unused]] const float* q_descale_ptr = nullptr,
|
||||
const float* k_descale_ptr = nullptr,
|
||||
const float* v_descale_ptr = nullptr,
|
||||
[[maybe_unused]] float q_descale_value = 1.0f) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// K tile in LDS
|
||||
KLdsDataType* k_lds_ptr = static_cast<KLdsDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
auto q_dram_window_reg =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
|
||||
auto q = load_tile(q_dram_window_reg);
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
// infer Sacc, S, P, M, L, Oacc type
|
||||
using SBlockTileType =
|
||||
std::conditional_t<std::is_same_v<typename SaccBlockTileType::DataType, SaccDataType>,
|
||||
SaccBlockTileType,
|
||||
decltype(cast_tile<SaccDataType>(SaccBlockTileType{}))>;
|
||||
|
||||
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
|
||||
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
|
||||
// init Oacc, M, L
|
||||
auto o_acc = OaccBlockTileType{};
|
||||
auto m = MLBlockTileType{};
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
auto [start, end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
return ck_tile::make_tuple(start, end);
|
||||
}();
|
||||
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto kv_load_start = seqlen_k_start > 0 ? seqlen_k_start : 0;
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(AttnMask::IsMasking || kPadSeqLenK)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here occ are all cleard, return it
|
||||
// Note: q loaded but no fence, ignore it.
|
||||
return o_acc;
|
||||
}
|
||||
}
|
||||
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{kv_load_start, 0});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, kv_load_start},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = [&]() {
|
||||
if constexpr(std::is_same_v<QDataType, QGemmDataType>)
|
||||
return tile_elementwise_in(q_element_func, q);
|
||||
else
|
||||
{
|
||||
auto q_tile_tmp = make_static_distributed_tensor<QGemmDataType>(
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
constexpr index_t kPackedSize = numeric_traits<QDataType>::PackedSize;
|
||||
constexpr index_t kUnaryOpSize = 8;
|
||||
static_assert(std::is_same_v<QDataType, ck_tile::pk_int4_t>);
|
||||
static_assert(kPackedSize == 2);
|
||||
static_assert(decltype(q_tile_tmp)::get_thread_buffer_size() ==
|
||||
decltype(q)::get_thread_buffer_size() * kPackedSize);
|
||||
static_assert(decltype(q_tile_tmp)::get_thread_buffer_size() % kUnaryOpSize == 0);
|
||||
|
||||
using RawQType = typename QDataType::type;
|
||||
using SrcVectorType = ext_vector_t<RawQType, kUnaryOpSize / kPackedSize>;
|
||||
using DstVectorType = ext_vector_t<QGemmDataType, kUnaryOpSize>;
|
||||
constexpr index_t kVecSize =
|
||||
decltype(q_tile_tmp)::get_thread_buffer_size() / kUnaryOpSize;
|
||||
static_assert(decltype(q)::get_thread_buffer_size() ==
|
||||
kVecSize * (kUnaryOpSize / kPackedSize));
|
||||
|
||||
const element_wise::PassThroughPack8 pass_through_pack8{};
|
||||
static_for<0, kVecSize, 1>{}([&](auto i) {
|
||||
pass_through_pack8(
|
||||
q_tile_tmp.get_thread_buffer().template get_as<DstVectorType>()(i),
|
||||
q.get_thread_buffer().template get_as<SrcVectorType>()[i]);
|
||||
});
|
||||
return q_tile_tmp;
|
||||
}
|
||||
}();
|
||||
|
||||
// prefetch K tile
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
// Use compile-time conditional for group barrier sequence
|
||||
// (No runtime lambda selection)
|
||||
auto schedule_gemm0 = [] {
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
constexpr auto WarpGemmConfig =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
|
||||
constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
|
||||
constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
|
||||
constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
|
||||
constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
|
||||
constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
|
||||
constexpr index_t NumMfmaInsts = (kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) *
|
||||
(kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
|
||||
if constexpr(get_warp_size() == 64 && kQKHeaddim == 256)
|
||||
{
|
||||
static_assert(NumMfmaInsts % 8 == 0);
|
||||
static_for<0, NumMfmaInsts / 8, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(DS_READ, 2, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(MFMA, 4, 0); // MFMA
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
static_assert(2 <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
constexpr index_t kGemm0MPerWarp = BlockSageAttnShape::Gemm0WarpTile::at(number<0>{});
|
||||
static_assert(get_warp_size() % kGemm0MPerWarp == 0);
|
||||
constexpr index_t kWarpSz = get_warp_size();
|
||||
// sub_warp_idx is 0 or 1, indicating which half of the warp (used for PERTHREAD K-scale
|
||||
// indexing)
|
||||
index_t sub_warp_idx = (threadIdx.x % kWarpSz) / kGemm0MPerWarp;
|
||||
// main loop
|
||||
do
|
||||
{
|
||||
float k_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
const index_t kv_idx =
|
||||
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
|
||||
k_descale = k_descale_ptr[kv_idx];
|
||||
}
|
||||
constexpr index_t kNumKScalesPW =
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP
|
||||
? kN0 / Problem::kBlockScaleSizeK
|
||||
: 1;
|
||||
constexpr index_t kNumKScalesPT =
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD
|
||||
? kN0 / Problem::kBlockScaleSizeK / 2
|
||||
: 1;
|
||||
float k_scales_perwarp[kNumKScalesPW > 0 ? kNumKScalesPW : 1] = {};
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
|
||||
{
|
||||
const index_t kv_idx =
|
||||
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPW; i++)
|
||||
k_scales_perwarp[i] = k_descale_ptr[kv_idx + i];
|
||||
}
|
||||
float k_scales_reg[kNumKScalesPT > 0 ? kNumKScalesPT : 1] = {};
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
const index_t k_global_start = seqlen_k_start + i_total_loops * kN0;
|
||||
const index_t k_scale_start_idx = k_global_start / Problem::kBlockScaleSizeK;
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPT; i++)
|
||||
k_scales_reg[i] = k_descale_ptr[k_scale_start_idx + 2 * i + sub_warp_idx];
|
||||
}
|
||||
|
||||
// STAGE 1, QK gemm
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
auto s_acc_gemm = SaccBlockTileType{};
|
||||
const auto store_k_block_tile_to_lds = [&](const auto& k_block_tile_) {
|
||||
if constexpr(std::is_same_v<KDataType, KLdsDataType>)
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile_));
|
||||
else
|
||||
{
|
||||
auto k_block_tile_tmp = make_static_distributed_tensor<KLdsDataType>(
|
||||
k_dram_window.get_tile_distribution());
|
||||
using KBlockTileType = remove_cvref_t<decltype(k_block_tile_)>;
|
||||
constexpr index_t kPackedSize = numeric_traits<KDataType>::PackedSize;
|
||||
constexpr index_t kUnaryOpSize = 8;
|
||||
static_assert(std::is_same_v<KDataType, ck_tile::pk_int4_t>);
|
||||
static_assert(kPackedSize == 2);
|
||||
static_assert(decltype(k_block_tile_tmp)::get_thread_buffer_size() ==
|
||||
KBlockTileType::get_thread_buffer_size() * kPackedSize);
|
||||
static_assert(
|
||||
decltype(k_block_tile_tmp)::get_thread_buffer_size() % kUnaryOpSize == 0);
|
||||
|
||||
using RawKType = typename KDataType::type;
|
||||
using SrcVectorType = ext_vector_t<RawKType, kUnaryOpSize / kPackedSize>;
|
||||
using DstVectorType = ext_vector_t<KLdsDataType, kUnaryOpSize>;
|
||||
constexpr index_t kVecSize =
|
||||
decltype(k_block_tile_tmp)::get_thread_buffer_size() / kUnaryOpSize;
|
||||
static_assert(KBlockTileType::get_thread_buffer_size() ==
|
||||
kVecSize * (kUnaryOpSize / kPackedSize));
|
||||
|
||||
const element_wise::PassThroughPack8 pass_through_pack8{};
|
||||
static_for<0, kVecSize, 1>{}([&](auto i) {
|
||||
pass_through_pack8(
|
||||
k_block_tile_tmp.get_thread_buffer().template get_as<DstVectorType>()(
|
||||
i),
|
||||
k_block_tile_.get_thread_buffer().template get_as<SrcVectorType>()[i]);
|
||||
});
|
||||
store_tile(k_lds_window, k_block_tile_tmp);
|
||||
}
|
||||
};
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
{
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
clear_tile(s_acc_gemm); // initialize C
|
||||
store_k_block_tile_to_lds(k_block_tile);
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc_gemm,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_lds_window);
|
||||
schedule_gemm0();
|
||||
block_sync_lds();
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
store_k_block_tile_to_lds(k_block_tile); // LDS write i + 1
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc_gemm,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 2) * kK0>{},
|
||||
sequence<kM0, (k0_loops - 1) * kK0>{}),
|
||||
k_lds_window);
|
||||
schedule_gemm0();
|
||||
block_sync_lds();
|
||||
|
||||
store_k_block_tile_to_lds(k_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(s_acc_gemm,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_lds_window);
|
||||
schedule_gemm0();
|
||||
}
|
||||
|
||||
// Convert GEMM output to SaccDataType for softmax (if needed)
|
||||
auto s_acc = [&]() {
|
||||
using GemmDataType = typename decltype(s_acc_gemm)::DataType;
|
||||
if constexpr(std::is_same_v<GemmDataType, SaccDataType>)
|
||||
{
|
||||
return s_acc_gemm; // No conversion needed (e.g., float -> float)
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<SaccDataType>(s_acc_gemm); // Convert (e.g., int32 -> float)
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
// PERTHREAD: kBlockScaleSizeK=16
|
||||
// The s_acc tile distribution is determined by
|
||||
// WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution, which guarantees
|
||||
// each thread processes exactly 16 consecutive elements in the K dimension. This
|
||||
// distribution is inherent to the MFMA 32x32x16 instruction with kKIter=2 and
|
||||
// TransposedC layout. Therefore, col_offset >> 4 correctly maps thread-local
|
||||
// elements to K scale indices.
|
||||
static_assert(Problem::kBlockScaleSizeK == 16,
|
||||
"PERTHREAD: kBlockScaleSizeK must be 16");
|
||||
|
||||
// Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB +
|
||||
// TransposedC This ensures the distribution has 16 consecutive K elements per
|
||||
// thread
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
constexpr auto WarpGemmCfg =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
|
||||
using ExpectedWarpGemmI8 =
|
||||
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
using ExpectedWarpGemmFp8 =
|
||||
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
static_assert(
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
|
||||
"PERTHREAD requires "
|
||||
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
|
||||
"16 consecutive K elements");
|
||||
|
||||
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
|
||||
float combined_scales_reg[kNumKScalesPT] = {};
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPT; i++)
|
||||
combined_scales_reg[i] = q_descale_value * k_scales_reg[i];
|
||||
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
|
||||
index_t col_offset = 0;
|
||||
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// col_offset counts columns in distributed view
|
||||
// Divide by 16 (>>4) to map to K scale groups (kBlockScaleSizeK=16)
|
||||
const index_t scale_idx = col_offset >> 4;
|
||||
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
|
||||
col_offset++;
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
|
||||
{
|
||||
// PERWARP: kBlockScaleSizeK=64, i.e., 64 global K elements share one scale
|
||||
// Distribution: thread_i and thread_(i+32) interleave to cover K dimension
|
||||
// In each thread's view, every 32 idx1 steps correspond to 64 global K elements
|
||||
|
||||
// Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB +
|
||||
// TransposedC This ensures each thread has 16 consecutive elements, and warp-level
|
||||
// grouping is correct
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
constexpr auto WarpGemmCfg =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
|
||||
using ExpectedWarpGemmI8 =
|
||||
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
using ExpectedWarpGemmFp8 =
|
||||
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
static_assert(
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
|
||||
"PERWARP requires "
|
||||
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
|
||||
"correct K element grouping");
|
||||
|
||||
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
|
||||
float combined_scales_reg[kNumKScalesPW] = {};
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPW; i++)
|
||||
combined_scales_reg[i] = q_descale_value * k_scales_perwarp[i];
|
||||
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
|
||||
index_t col_offset = 0;
|
||||
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// col_offset counts columns in distributed view
|
||||
// When N0=64: each thread has 32 elements; when N0=128: each thread has 64
|
||||
// elements Divide by 32 (>>5) to map to K scale groups
|
||||
// (kBlockScaleSizeK=64)
|
||||
const index_t scale_idx = col_offset >> 5;
|
||||
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
|
||||
col_offset++;
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// dequant: combine q_descale (in s_acc_element_func) with k_descale
|
||||
auto s_acc_element_func_ = [&]() {
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
return s_acc_element_func * k_descale;
|
||||
}
|
||||
else
|
||||
return s_acc_element_func;
|
||||
}();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
}
|
||||
// STAGE 2, scale_s, mask, softmax
|
||||
if constexpr(kPadSeqLenK || AttnMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask_func(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
};
|
||||
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
|
||||
s.get_tile_distribution()); // Pcompute{j}
|
||||
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
if constexpr(AttnMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_m;
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
// For BLOCKSCALE: precompute (m - shift) once per row
|
||||
// exp2(s - m + shift) = exp2(s - (m - shift)); pertensor path uses scale_s on s,m
|
||||
auto validated_m = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * validated_m;
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP ||
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
validated_m -= OCP_FP8_SHIFT; // OCP FP8 softmax shift
|
||||
row_max -= OCP_FP8_SHIFT; // for else branch
|
||||
#else
|
||||
validated_m -= FNUZ_FP8_SHIFT;
|
||||
row_max -= FNUZ_FP8_SHIFT;
|
||||
#endif
|
||||
}
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
});
|
||||
});
|
||||
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
const auto m_new = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * m_new;
|
||||
const auto tmp = exp2(scale_s * m_old[i_idx] - row_max);
|
||||
// Update l and rescale o_acc
|
||||
l(i_idx) = tmp * l(i_idx) + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_prefetch);
|
||||
store_tile(
|
||||
v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
|
||||
}
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
// For BLOCKSCALE, PERWARP, and PERTHREAD modes, accumulate directly to o_acc
|
||||
// Apply per-channel v_descale after the loop (before normalization)
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v);
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v)); // store next v
|
||||
}
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// Apply per-channel v_descale for BLOCKSCALE, PERWARP, and PERTHREAD modes (after loop,
|
||||
// before normalization)
|
||||
if constexpr(Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
|
||||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP ||
|
||||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
// Ensure all V LDS reads from the last gemm_1 complete before reusing K/V LDS space
|
||||
block_sync_lds();
|
||||
|
||||
// V is col-major, each column (channel) has its own scale
|
||||
// o_acc shape: [M0, N1] where N1 is hdim_v
|
||||
// v_descale_ptr points to per-channel scales [hdim_v]
|
||||
// Load v_descale to LDS for better memory access pattern
|
||||
// Reuse K/V LDS space (they're no longer needed)
|
||||
auto v_descale_lds = reinterpret_cast<float*>(smem_ptr);
|
||||
|
||||
// Cooperatively load v_descale to LDS
|
||||
const index_t num_threads = kBlockSize;
|
||||
for(index_t i = threadIdx.x; i < kN1; i += num_threads)
|
||||
{
|
||||
v_descale_lds[i] = v_descale_ptr[i];
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
constexpr auto o_tmp_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_tmp_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(o_tmp_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// Get the global tile index for the N1 (channel) dimension
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
o_acc.get_tile_distribution(), i_j_idx);
|
||||
const index_t channel_idx = tile_idx.at(number<1>{});
|
||||
const float v_scale = v_descale_lds[channel_idx];
|
||||
o_acc(i_j_idx) *= v_scale;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// finally, O
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
// When masking, the denominator can be zero; guard the normalization
|
||||
// so we do not divide by zero after a fully masked row.
|
||||
if constexpr(AttnMask::IsMasking)
|
||||
{
|
||||
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
|
||||
}
|
||||
else
|
||||
return 1 / l[i_idx];
|
||||
}();
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
AttnMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
[[maybe_unused]] const float* q_descale_ptr = nullptr,
|
||||
const float* k_descale_ptr = nullptr,
|
||||
const float* v_descale_ptr = nullptr,
|
||||
[[maybe_unused]] float q_descale_value = 1.0f) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
q_descale_ptr,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
q_descale_value);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
// ==== NOTE(review): stray diff hunk header removed; the lines below appear to begin a
// separate new header file (presumably block_sageattn_pipeline_qr_ks_vs_async.hpp) ====
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
|
||||
template <typename Problem_, typename Policy_ = BlockSageAttentionPipelineQRKSVSAsyncDefaultPolicy>
|
||||
struct BlockSageAttentionPipelineQRKSVSAsync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
// fp16/bf16 example configs use P=V=fp16/bf16 (qscale=no). Quantized Sage paths use fp8 P/V;
|
||||
// FP8 softmax shift, v_descale, and PV-gemm LDS layout assume fp8_t for those cases.
|
||||
static_assert(std::is_same_v<PDataType, VDataType>,
|
||||
"SageAttention pipeline requires PDataType == VDataType for the PV gemm");
|
||||
static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
|
||||
std::is_same_v<PDataType, fp8_t>,
|
||||
"SageAttention pipeline requires PDataType = fp8_t");
|
||||
static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
|
||||
std::is_same_v<VDataType, fp8_t>,
|
||||
"SageAttention pipeline requires VDataType = fp8_t");
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
using AttnMask = remove_cvref_t<typename Problem::AttnMask>;
|
||||
|
||||
using BlockSageAttnShape = remove_cvref_t<typename Problem::BlockSageAttnShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockSageAttnShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockSageAttnShape::kM0;
|
||||
static constexpr index_t kN0 = BlockSageAttnShape::kN0;
|
||||
static constexpr index_t kK0 = BlockSageAttnShape::kK0;
|
||||
static constexpr index_t kN1 = BlockSageAttnShape::kN1;
|
||||
static constexpr index_t kK1 = BlockSageAttnShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockSageAttnShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockSageAttnShape::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
|
||||
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
|
||||
static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
|
||||
Problem::kPadHeadDimV == true);
|
||||
static constexpr bool kPadSeqLenQ = true;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
|
||||
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
|
||||
|
||||
// FP8 softmax shift constants to map softmax output into representable FP8 range
|
||||
// OCP E4M3 FP8: max exponent = 8, max value ~240 (2^8 * 1.875)
|
||||
// Use shift=8.0 so exp2(s - m - 8) maps softmax to [0, 2^8] range
|
||||
// FNUZ E4M3 FP8: max exponent = 7, max value ~120 (2^7 * 1.875)
|
||||
// Use shift=7.0 so exp2(s - m - 7) maps softmax to [0, 2^7] range
|
||||
static constexpr float OCP_FP8_SHIFT = 8.0f;
|
||||
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kQKHeaddim <= 32)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 64)
|
||||
{
|
||||
return 3;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 128)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 192)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 256)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
};
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "qr_async";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const KElementFunction& /*k_element_func*/,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
AttnMask mask,
|
||||
PositionEncoding /*position_encoding*/,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
[[maybe_unused]] const float* q_descale_ptr = nullptr,
|
||||
const float* k_descale_ptr = nullptr,
|
||||
const float* v_descale_ptr = nullptr,
|
||||
[[maybe_unused]] float q_descale_value = 1.0f) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
|
||||
auto k_lds_store = generate_tuple(
|
||||
[&](auto i_buf) {
|
||||
return make_tile_window(
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
|
||||
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
|
||||
{0, 0, 0});
|
||||
},
|
||||
number<Policy::NumKVLdsBuffers>{});
|
||||
|
||||
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
|
||||
|
||||
auto k_lds_load =
|
||||
make_tile_window(k_lds_Load_view,
|
||||
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0});
|
||||
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
q_dram_window.init_raw();
|
||||
|
||||
// TODO: we use async Copy for K, which is inline asm
|
||||
// a side effect is we have to use inline asm for q as well
|
||||
auto q = decltype(load_tile(q_dram_window)){};
|
||||
// TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
|
||||
// however, q would be cleared in the constructor of static distributed tensor
|
||||
// set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
|
||||
load_tile_raw(q, q_dram_window);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
// infer Sacc, S, P, M, L, Oacc type
|
||||
using SBlockTileType =
|
||||
std::conditional_t<std::is_same_v<typename SaccBlockTileType::DataType, SaccDataType>,
|
||||
SaccBlockTileType,
|
||||
decltype(cast_tile<SaccDataType>(SaccBlockTileType{}))>;
|
||||
|
||||
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
|
||||
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
|
||||
// init Oacc, M, L
|
||||
auto o_acc = OaccBlockTileType{};
|
||||
auto m = MLBlockTileType{};
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
auto [start, end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
return ck_tile::make_tuple(start, end);
|
||||
}();
|
||||
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto kv_load_start = seqlen_k_start > 0 ? seqlen_k_start : 0;
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(AttnMask::IsMasking || kPadSeqLenK)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
|
||||
// otherwise will have compute error(maybe compiler bug?)
|
||||
|
||||
// Note: here occ are all cleard, return it
|
||||
return o_acc;
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
|
||||
}
|
||||
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{kv_load_start, 0});
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
k_dram_window.init_raw();
|
||||
constexpr auto k_oob_ck = bool_constant<true>{};
|
||||
constexpr auto k_pre_np = bool_constant<false>{};
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, kv_load_start},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// prefetch K tile
|
||||
async_load_tile_raw(
|
||||
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
|
||||
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
|
||||
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(1 <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
constexpr index_t kGemm0MPerWarp = BlockSageAttnShape::Gemm0WarpTile::at(number<0>{});
|
||||
static_assert(kGemm0MPerWarp == 32);
|
||||
constexpr index_t kWarpSz = get_warp_size();
|
||||
// sub_warp_idx is 0 or 1, indicating which half of the warp (used for PERTHREAD K-scale
|
||||
// indexing)
|
||||
index_t sub_warp_idx = (threadIdx.x % kWarpSz) / kGemm0MPerWarp;
|
||||
// main loop
|
||||
do
|
||||
{
|
||||
float k_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
const index_t kv_idx =
|
||||
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
|
||||
k_descale = k_descale_ptr[kv_idx];
|
||||
}
|
||||
constexpr index_t kNumKScalesPW =
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP
|
||||
? kN0 / Problem::kBlockScaleSizeK
|
||||
: 1;
|
||||
constexpr index_t kNumKScalesPT =
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD
|
||||
? kN0 / Problem::kBlockScaleSizeK / 2
|
||||
: 1;
|
||||
float k_scales_perwarp[kNumKScalesPW > 0 ? kNumKScalesPW : 1] = {};
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
|
||||
{
|
||||
const index_t kv_idx =
|
||||
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPW; i++)
|
||||
k_scales_perwarp[i] = k_descale_ptr[kv_idx + i];
|
||||
}
|
||||
float k_scales_reg[kNumKScalesPT > 0 ? kNumKScalesPT : 1] = {};
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
const index_t k_global_start = seqlen_k_start + i_total_loops * kN0;
|
||||
const index_t k_scale_start_idx = k_global_start / Problem::kBlockScaleSizeK;
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPT; i++)
|
||||
k_scales_reg[i] = k_descale_ptr[k_scale_start_idx + 2 * i + sub_warp_idx];
|
||||
}
|
||||
|
||||
// STAGE 1, QK gemm
|
||||
auto s_acc_gemm = SaccBlockTileType{};
|
||||
clear_tile(s_acc_gemm); // initialize C
|
||||
if constexpr(k0_loops > 1)
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
|
||||
k_dram_window,
|
||||
number<-1>{},
|
||||
k_oob_ck,
|
||||
k_pre_np);
|
||||
if constexpr(i_k0 < k0_loops - 1)
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
async_load_fence(k_dram_window.get_num_of_access());
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
gemm_0(s_acc_gemm,
|
||||
get_slice_tile(
|
||||
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
get_slice_tile(k_lds_load,
|
||||
sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
|
||||
sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
|
||||
});
|
||||
}
|
||||
|
||||
// TODO: this to fix a bug when loop smaller than 2,
|
||||
// the following fence/barrier will be scheduled inside 1st loop
|
||||
if constexpr(k0_loops <= 2)
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
async_load_fence();
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
{ // tail
|
||||
gemm_0(
|
||||
s_acc_gemm,
|
||||
get_slice_tile(
|
||||
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
|
||||
get_slice_tile(k_lds_load,
|
||||
sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(1);
|
||||
|
||||
// Convert GEMM output to SaccDataType for softmax (if needed)
|
||||
auto s_acc = [&]() {
|
||||
using GemmDataType = typename decltype(s_acc_gemm)::DataType;
|
||||
if constexpr(std::is_same_v<GemmDataType, SaccDataType>)
|
||||
{
|
||||
return s_acc_gemm; // No conversion needed (e.g., float -> float)
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<SaccDataType>(s_acc_gemm); // Convert (e.g., int32 -> float)
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
// PERTHREAD: kBlockScaleSizeK=16
|
||||
// The s_acc tile distribution is determined by
|
||||
// WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution, which guarantees
|
||||
// each thread processes exactly 16 consecutive elements in the K dimension. This
|
||||
// distribution is inherent to the MFMA 32x32x16 instruction with kKIter=2 and
|
||||
// TransposedC layout. Therefore, col_offset >> 4 correctly maps thread-local
|
||||
// elements to K scale indices.
|
||||
static_assert(Problem::kBlockScaleSizeK == 16,
|
||||
"PERTHREAD: kBlockScaleSizeK must be 16");
|
||||
|
||||
// Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB +
|
||||
// TransposedC This ensures the distribution has 16 consecutive K elements per
|
||||
// thread
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
constexpr auto WarpGemmCfg =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
|
||||
using ExpectedWarpGemmI8 =
|
||||
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
using ExpectedWarpGemmFp8 =
|
||||
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
static_assert(
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
|
||||
"PERTHREAD requires "
|
||||
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
|
||||
"16 consecutive K elements");
|
||||
|
||||
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
|
||||
float combined_scales_reg[kNumKScalesPT] = {};
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPT; i++)
|
||||
combined_scales_reg[i] = q_descale_value * k_scales_reg[i];
|
||||
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
|
||||
index_t col_offset = 0;
|
||||
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// col_offset counts columns in distributed view
|
||||
// Divide by 16 (>>4) to map to K scale groups (kBlockScaleSizeK=16)
|
||||
const index_t scale_idx = col_offset >> 4;
|
||||
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
|
||||
col_offset++;
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
|
||||
{
|
||||
// PERWARP: kBlockScaleSizeK=64, i.e., 64 global K elements share one scale
|
||||
// Distribution: thread_i and thread_(i+32) interleave to cover K dimension
|
||||
// In each thread's view, every 32 idx1 steps correspond to 64 global K elements
|
||||
|
||||
// Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB +
|
||||
// TransposedC This ensures each thread has 16 consecutive elements, and warp-level
|
||||
// grouping is correct
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
constexpr auto WarpGemmCfg =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
|
||||
using ExpectedWarpGemmI8 =
|
||||
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
using ExpectedWarpGemmFp8 =
|
||||
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
static_assert(
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
|
||||
"PERWARP requires "
|
||||
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
|
||||
"correct K element grouping");
|
||||
|
||||
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
|
||||
float combined_scales_reg[kNumKScalesPW] = {};
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPW; i++)
|
||||
combined_scales_reg[i] = q_descale_value * k_scales_perwarp[i];
|
||||
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
|
||||
index_t col_offset = 0;
|
||||
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// col_offset counts columns in distributed view
|
||||
// When N0=64: each thread has 32 elements; when N0=128: each thread has 64
|
||||
// elements Divide by 32 (>>5) to map to K scale groups
|
||||
// (kBlockScaleSizeK=64)
|
||||
const index_t scale_idx = col_offset >> 5;
|
||||
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
|
||||
col_offset++;
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// dequant: combine q_descale (in s_acc_element_func) with k_descale
|
||||
auto s_acc_element_func_ = [&]() {
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
return s_acc_element_func * k_descale;
|
||||
}
|
||||
else
|
||||
return s_acc_element_func;
|
||||
}();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
}
|
||||
// STAGE 2, scale_s, mask, softmax
|
||||
// logits_soft_cap is always disabled
|
||||
if constexpr(kPadSeqLenK || AttnMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask_func(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
};
|
||||
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
|
||||
s.get_tile_distribution()); // Pcompute{j}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F);
|
||||
// Ensure gemm_0's LDS reads (K tile) from all threads are completed before V store
|
||||
// Only needed when K tail and V use the same LDS buffer
|
||||
if constexpr(LdsSeq.at(number<k0_loops - 1>{}) == LdsSeq.at(number<k0_loops>{}))
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
// store & prefetch next v, after the max reduction
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_buf);
|
||||
|
||||
auto v_lds_window_tmp =
|
||||
get_slice_tile(v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
|
||||
|
||||
store_tile(
|
||||
v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
auto v_lds_window_tmp =
|
||||
get_slice_tile(v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
|
||||
store_tile(v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
|
||||
}
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
move_tile_window(
|
||||
v_dram_window,
|
||||
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
if constexpr(AttnMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_m;
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
// For BLOCKSCALE: precompute (m - shift) once per row
|
||||
// exp2(s - m + shift) = exp2(s - (m - shift))
|
||||
// else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift))
|
||||
auto validated_m = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * validated_m;
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP ||
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
validated_m -= OCP_FP8_SHIFT; // OCP FP8 softmax shift
|
||||
row_max -= OCP_FP8_SHIFT; // for else branch
|
||||
#else
|
||||
validated_m -= FNUZ_FP8_SHIFT;
|
||||
row_max -= FNUZ_FP8_SHIFT;
|
||||
#endif
|
||||
}
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// logits_soft_cap is always disabled
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
});
|
||||
});
|
||||
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
const auto m_new = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * m_new;
|
||||
const auto tmp = exp2(scale_s * m_old[i_idx] - row_max);
|
||||
// Update l and rescale o_acc
|
||||
l(i_idx) = tmp * l(i_idx) + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
const auto p = [&]() {
|
||||
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
// For fp32 to fp16,
|
||||
// impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue,
|
||||
// since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero.
|
||||
return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#else
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
return impl::cast_tile_pkrtz_fp16_fp32<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
else
|
||||
return cast_tile<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#endif
|
||||
}();
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
// For BLOCKSCALE, PERWARP, and PERTHREAD modes, accumulate directly to o_acc
|
||||
// Apply per-channel v_descale after the loop (before normalization)
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
|
||||
{
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
}
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_buf);
|
||||
auto v_lds_window_tmp = get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
|
||||
store_tile(v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
auto v_lds_window_tmp = get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
|
||||
store_tile(v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
|
||||
}
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
i_total_loops++;
|
||||
if(i_total_loops < num_total_loop)
|
||||
{
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
|
||||
__builtin_amdgcn_s_barrier();
|
||||
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
|
||||
k_dram_window,
|
||||
number<-1>{},
|
||||
k_oob_ck,
|
||||
k_pre_np);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
|
||||
}
|
||||
|
||||
} while(i_total_loops < num_total_loop);
|
||||
|
||||
// Apply per-channel v_descale for BLOCKSCALE, PERWARP, and PERTHREAD modes (after loop,
|
||||
// before normalization)
|
||||
if constexpr(Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
|
||||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP ||
|
||||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
// Ensure all V LDS reads from the last gemm_1 complete before reusing K/V LDS space
|
||||
block_sync_lds();
|
||||
|
||||
// V is col-major, each column (channel) has its own scale
|
||||
// o_acc shape: [M0, N1] where N1 is hdim_v
|
||||
// v_descale_ptr points to per-channel scales [hdim_v]
|
||||
// Load v_descale to LDS for better memory access pattern
|
||||
// Reuse K/V LDS space (they're no longer needed)
|
||||
auto v_descale_lds = reinterpret_cast<float*>(smem_ptr);
|
||||
|
||||
// Cooperatively load v_descale to LDS
|
||||
const index_t num_threads = kBlockSize;
|
||||
for(index_t i = threadIdx.x; i < kN1; i += num_threads)
|
||||
{
|
||||
v_descale_lds[i] = v_descale_ptr[i];
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
constexpr auto o_tmp_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_tmp_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(o_tmp_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// Get the global tile index for the N1 (channel) dimension
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
o_acc.get_tile_distribution(), i_j_idx);
|
||||
const index_t channel_idx = tile_idx.at(number<1>{});
|
||||
const float v_scale = v_descale_lds[channel_idx];
|
||||
o_acc(i_j_idx) *= v_scale;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// finally, O
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(AttnMask::IsMasking)
|
||||
{
|
||||
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
|
||||
}
|
||||
else
|
||||
return 1 / l[i_idx];
|
||||
}();
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
// Convenience overload: runs the SageAttention forward pipeline without any
// element-wise pre/post functors. All six functor slots of the full
// operator() are filled with identity{}, and the Q/K/V descale pointers and
// the scalar Q descale default to "no quantization scaling"
// (nullptr / 1.0f).
//
// Parameters mirror the full overload:
//   q/k/v_dram_block_window_tmp - DRAM tile windows (M0*K0 / N0*K0 / N1*K1)
//   mask                        - attention mask descriptor
//   position_encoding           - positional-bias functor
//   scale_s                     - softmax scale applied to S = Q*K^T
//   variant / variant_params    - attention variant (e.g. logits masking)
//   block_indices               - batch / head indices for this tile
//   smem_ptr                    - LDS scratch for the pipeline
//   q/k/v_descale_ptr           - per-block/warp/thread/channel descale arrays
//   q_descale_value             - per-tensor Q descale scalar
template <typename QDramBlockWindowTmp,
          typename KDramBlockWindowTmp,
          typename VDramBlockWindowTmp,
          typename PositionEncoding,
          typename AttentionVariantParams,
          typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
           const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
           const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
           AttnMask mask,
           PositionEncoding position_encoding,
           float scale_s,
           const AttentionVariant& variant,
           const AttentionVariantParams& variant_params,
           const BlockIndices& block_indices,
           void* smem_ptr,
           const float* q_descale_ptr = nullptr,
           const float* k_descale_ptr = nullptr,
           const float* v_descale_ptr = nullptr,
           [[maybe_unused]] float q_descale_value = 1.0f) const
{
    // Forward verbatim; the identity{} arguments occupy the element-wise
    // functor parameters of the main overload, in its declaration order.
    return operator()(q_dram_block_window_tmp,
                      identity{},
                      k_dram_block_window_tmp,
                      identity{},
                      v_dram_block_window_tmp,
                      identity{},
                      identity{},
                      identity{},
                      identity{},
                      mask,
                      position_encoding,
                      scale_s,
                      variant,
                      variant_params,
                      block_indices,
                      smem_ptr,
                      q_descale_ptr,
                      k_descale_ptr,
                      v_descale_ptr,
                      q_descale_value);
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,18 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_custom_policy.hpp"

namespace ck_tile {

// Default policy for the async (QRKSVS_ASYNC) SageAttention forward pipeline:
// Q is loaded once into registers (QLoadOnce = true), K/V are staged through
// LDS via the async-copy path, and both the K and V prefetch pipelines are
// three buffers deep.
using BlockSageAttentionPipelineQRKSVSAsyncDefaultPolicy =
    BlockSageAttnPipelineQRKSVSCustomPolicy</* QLoadOnce = */ true,
                                            /* AsyncCopy = */ true,
                                            /* NumPrefetchK = */ 3,
                                            /* NumPrefetchV = */ 3>;

} // namespace ck_tile
|
||||
@@ -0,0 +1,857 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetPackedSize()
|
||||
{
|
||||
return numeric_traits<remove_cvref_t<T>>::PackedSize;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetLogicalVectorSize(index_t bytes)
|
||||
{
|
||||
return (bytes / sizeof(remove_cvref_t<T>)) * GetPackedSize<T>();
|
||||
}
|
||||
|
||||
// Effective Q operand type fed to the QK gemm: packed (sub-byte) Q storage
// types are mapped to fp8_t for the gemm operand; otherwise Q's own
// (cv/ref-stripped) type is used directly.
template <typename Problem>
using SageAttnQKGemmQDataType =
    std::conditional_t<is_packed_type_v<remove_cvref_t<typename Problem::QDataType>>,
                       fp8_t,
                       remove_cvref_t<typename Problem::QDataType>>;
|
||||
|
||||
// Effective K operand type fed to the QK gemm, mirroring
// SageAttnQKGemmQDataType: packed (sub-byte) K storage types map to fp8_t,
// otherwise K's own (cv/ref-stripped) type is used directly.
template <typename Problem>
using SageAttnQKGemmKDataType =
    std::conditional_t<is_packed_type_v<remove_cvref_t<typename Problem::KDataType>>,
                       fp8_t,
                       remove_cvref_t<typename Problem::KDataType>>;
|
||||
|
||||
// Policy fragment controlling how Q is staged for the pipeline, parameterized
// on whether Q is loaded once into registers; specialized below.
template <bool QLoadOnce_>
struct BlockSageAttnPipelineQRCustomPolicy;
|
||||
|
||||
// Specialization for QLoadOnce == true: Q is kept in registers for the whole
// pipeline, so no LDS space is reserved for it.
template <>
struct BlockSageAttnPipelineQRCustomPolicy</* QLoadOnce = */ true>
{
    static constexpr bool QLoadOnce = true;

    // Q lives entirely in registers under this policy -> zero LDS footprint.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
    {
        return 0;
    }

    // TODO: GetAlignment*() currently didn't consider if need padding or not
    // so in pipeline still need check padding requirement
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
    {
        // Widest logical vector that fits a 16-byte load of Q's storage type.
        constexpr index_t MaxVectorSize = GetLogicalVectorSize<typename Problem::QDataType>(16);

        using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
        constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        // Also bounded by the number of K elements each lane owns in the
        // selected warp gemm (kK divided by the A/B K-lane count).
        return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
    }

    // Register-tile distribution for Q, matching the A-operand layout of the
    // QK block gemm over a kM0 x kSubQKHeaddim tile.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
    {
        using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;

        return BlockGemm::template MakeABlockTileDistribution<
            Problem::BlockSageAttnShape::kM0,
            Problem::BlockSageAttnShape::kSubQKHeaddim>();
    }

    // Select the block-gemm implementation for the Q*K^T (gemm_0) stage:
    //   - fp8 x fp8 with float Sacc -> 32x32x32 MFMA, SwizzleB + transposed-C
    //   - int8 x int8               -> 32x32x32 int8 MFMA variant (int32 acc)
    //   - otherwise                 -> generic WarpGemmDispatcher
    // The result is wrapped in a one-warp or multi-warp block gemm depending
    // on kNumGemm0Warps.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
    {
        using QKGemmQDataType = SageAttnQKGemmQDataType<Problem>;
        using QKGemmKDataType = SageAttnQKGemmKDataType<Problem>;
        // int8 MFMA accumulates to int32, but SaccDataType is float for softmax
        using GemmAccDataType =
            std::conditional_t<(std::is_same_v<QKGemmQDataType, int8_t> ||
                                std::is_same_v<QKGemmQDataType, signed char>) &&
                                   (std::is_same_v<QKGemmKDataType, int8_t> ||
                                    std::is_same_v<QKGemmKDataType, signed char>),
                               int32_t,
                               typename Problem::SaccDataType>;

        using GemmProblem =
            BlockGemmProblem<QKGemmQDataType,
                             QKGemmKDataType,
                             GemmAccDataType,
                             Problem::kNumGemm0Warps * get_warp_size(),
                             TileGemmShape<sequence<Problem::BlockSageAttnShape::kM0,
                                                    Problem::BlockSageAttnShape::kN0,
                                                    Problem::BlockSageAttnShape::kK0>,
                                           typename Problem::BlockSageAttnShape::Gemm0BlockWarps,
                                           typename Problem::BlockSageAttnShape::Gemm0WarpTile>>;

        constexpr auto warp_gemm = []() {
            if constexpr(get_warp_size() == 64 && std::is_same_v<QKGemmQDataType, fp8_t> &&
                         std::is_same_v<QKGemmKDataType, fp8_t> &&
                         std::is_same_v<typename Problem::SaccDataType, float>)
            {
                // fp8 path requires the 32x32x32 warp tile shape.
                static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32);
                static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}) == 32);
                static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}) == 32);

                // TODO: hard coded here. Otherwise, it produces incorrect results
                constexpr index_t swizzle_factor = 4;
                return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
                    swizzle_factor>{};
            }
            else if constexpr(get_warp_size() == 64 &&
                              (std::is_same_v<QKGemmQDataType, int8_t> ||
                               std::is_same_v<QKGemmQDataType, signed char>) &&
                              (std::is_same_v<QKGemmKDataType, int8_t> ||
                               std::is_same_v<QKGemmKDataType, signed char>))
            {
                // int8 path also requires the 32x32x32 warp tile shape.
                static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32);
                static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}) == 32);
                static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}) == 32);

                // Use special int8 MFMA with K iteration (similar to FP8)
                constexpr index_t swizzle_factor = 4;
                return WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<
                    swizzle_factor>{};
            }
            else
            {
                // Generic fallback; SwizzleA only when M-per-warp is 32.
                constexpr bool SwizzleA =
                    Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32;
                return WarpGemmDispatcher<
                    QKGemmQDataType,
                    QKGemmKDataType,
                    GemmAccDataType,
                    Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}),
                    Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}),
                    Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}),
                    true, // TransposeC
                    SwizzleA>{};
            }
        }();

        using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
            QKGemmQDataType,
            QKGemmKDataType,
            GemmAccDataType,
            typename Problem::BlockSageAttnShape::Gemm0BlockWarps,
            decltype(warp_gemm)>;

        // Multi-warp gemm_0 uses the V2 block gemm; a single warp uses the
        // lighter one-warp V1 variant.
        if constexpr(1 < Problem::kNumGemm0Warps)
            return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
        else
            return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
    }
};
|
||||
|
||||
// This pipeline is qkv all located in LDS
|
||||
template <bool QLoadOnce_, bool AsyncCopy_, index_t NumPrefetchK_, index_t NumPrefetchV_>
|
||||
struct BlockSageAttnPipelineQRKSVSCustomPolicy : BlockSageAttnPipelineQRCustomPolicy<QLoadOnce_>
|
||||
{
|
||||
static constexpr bool AsyncCopy = AsyncCopy_; // use the async (buffer_load -> LDS) copy path

static constexpr index_t NumPrefetchK = NumPrefetchK_; // depth of the K-tile prefetch pipeline
static constexpr index_t NumPrefetchV = NumPrefetchV_; // depth of the V-tile prefetch pipeline

// K and V rotate through one shared pool of LDS buffers, sized by the deeper
// of the two prefetch pipelines.
static constexpr index_t NumKVLdsBuffers = max(NumPrefetchK, NumPrefetchV);

// Q staging policy (register-resident Q when QLoadOnce_) this policy extends.
using QXPolicy = BlockSageAttnPipelineQRCustomPolicy<QLoadOnce_>;
|
||||
|
||||
// Compile-time schedule that assigns an LDS buffer index to each of the
// k_loops_ K-load iterations followed by each of the v_loops_ V-load
// iterations, using a pool of max(k_prefetches_, v_prefetches_) buffers.
template <index_t k_prefetches_, index_t v_prefetches_, index_t k_loops_, index_t v_loops_>
struct LdsBufferSequence
{
    static constexpr index_t num_lds_buffers_ = max(k_prefetches_, v_prefetches_);
    // (v_loops_ - 1) rounded DOWN to a multiple of num_lds_buffers_.
    // NOTE(review): despite the name "ceil_", this is floor-style rounding.
    static constexpr index_t ceil_ = ((v_loops_ - 1) / num_lds_buffers_) * num_lds_buffers_;

    // for qr_ks_vs_async, the Lds buffer assigned to last gemm_1 iteration of V should not
    // overlap with the Lds buffers used by first two gemm_0 iterations of K
    static constexpr auto Make()
    {
        // ensure v_loop_-1 is assigned to num_lds_buffers-1
        return transform_sequences(
            [&](auto i) {
                // K iterations: plain round-robin over the buffer pool.
                if(i < k_loops_)
                    return i % num_lds_buffers_;
                else
                    // V iterations: offset so the final V iteration lands on
                    // buffer num_lds_buffers_ - 1.
                    return ((num_lds_buffers_ - 1) + (i - k_loops_ + ceil_ - (v_loops_ - 1))) %
                           num_lds_buffers_;
            },
            typename arithmetic_sequence_gen<0, k_loops_ + v_loops_, 1>::type{});
    };

    // The resulting buffer-index sequence, one entry per K/V load iteration.
    using type = remove_cvref_t<decltype(Make())>;
};
|
||||
|
||||
// clang-format off
|
||||
template<> struct
|
||||
LdsBufferSequence<3, 3, 4, 4> { using type = sequence<1, 2, 0, 1, 0, 1, 2, 0>; };
|
||||
|
||||
template<> struct
|
||||
LdsBufferSequence<3, 3, 4, 2> { using type = sequence<1, 2, 0, 1, 2, 0>; };
|
||||
|
||||
template<> struct
|
||||
LdsBufferSequence<3, 3, 2, 4> { using type = sequence<1, 2, 0, 1, 2, 0>; };
|
||||
|
||||
template<> struct
|
||||
LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; };
|
||||
|
||||
template<> struct
|
||||
LdsBufferSequence<3, 3, 3, 4> { using type = sequence<1, 2, 0, 0, 1, 2, 0>; };
|
||||
|
||||
template<> struct
|
||||
LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>;};
|
||||
// clang-format on
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetLdsBufferSequence()
|
||||
{
|
||||
using BlockSageAttnShape = remove_cvref_t<typename Problem::BlockSageAttnShape>;
|
||||
|
||||
constexpr index_t kN0 = BlockSageAttnShape::kN0;
|
||||
constexpr index_t kK0 = BlockSageAttnShape::kK0;
|
||||
constexpr index_t kK1 = BlockSageAttnShape::kK1;
|
||||
constexpr index_t kQKHeaddim = BlockSageAttnShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
return typename LdsBufferSequence<NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>::type{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
|
||||
{
|
||||
// TODO: this is for 3d layout
|
||||
using KDataType = SageAttnQKGemmKDataType<Problem>;
|
||||
return GetLogicalVectorSize<KDataType>(16);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
if constexpr(AsyncCopy)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
constexpr index_t MaxLoadSizeInBytes = 4 * 4; // dwordx4
|
||||
#else
|
||||
constexpr index_t MaxLoadSizeInBytes = 4; // dword
|
||||
#endif
|
||||
|
||||
return GetLogicalVectorSize<KDataType>(MaxLoadSizeInBytes);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
|
||||
|
||||
constexpr index_t MaxVectorSize = GetLogicalVectorSize<KDataType>(16);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
|
||||
return min(MaxVectorSize, ElemPerThread);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
|
||||
{
|
||||
// TODO: this is for 3d layout
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t kMaxVecLoad =
|
||||
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
|
||||
|
||||
return kMaxVecLoad;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockSageAttnShape::VLayout>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t kMaxVecLoad =
|
||||
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
|
||||
|
||||
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
|
||||
? kMaxVecLoad
|
||||
: (total_pixels / kMinVecLoad);
|
||||
|
||||
return kVecLoad;
|
||||
}
|
||||
else
|
||||
{
|
||||
return kMaxVecLoad;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::ODataType);
|
||||
return min(MaxVectorSize, WG::WarpGemmAttribute::Impl::kCM1PerLane);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
|
||||
{
|
||||
// this function assume K/V can share smem
|
||||
constexpr index_t SingleKSize = [&]() {
|
||||
if constexpr(!AsyncCopy)
|
||||
{
|
||||
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
|
||||
constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad = KPack;
|
||||
|
||||
static_assert(WarpSize * KVector >= kKPerBlock &&
|
||||
WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector;
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK;
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
|
||||
return NumIssues * NumWarps * (WarpSize * KVector + kPad);
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr index_t SingleVSize = [&]() {
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t Banks = get_n_lds_banks();
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
constexpr index_t NPerRow = PixelsPerRow / kKPack;
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
|
||||
static_assert(kNPerBlock % NPerRow == 0);
|
||||
static_assert(kKPerBlock % kKPack == 0);
|
||||
|
||||
return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
|
||||
}();
|
||||
|
||||
return max(SingleKSize, SingleVSize);
|
||||
}
|
||||
|
||||
// TODO: this is used for non async copy desc. unify in the future
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
|
||||
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<kNPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return k_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem, index_t IBuf = 0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{})
|
||||
{
|
||||
// K is always k-major, we use async-copy to load into LDS
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad =
|
||||
KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed
|
||||
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK =
|
||||
kKPerBlock / KVector; // how many lane (within a wave) to load K
|
||||
constexpr index_t LaneGroups =
|
||||
WarpSize /
|
||||
LanesPerK; // how many groups (within a wave), they may load different N, but same K
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
|
||||
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
|
||||
make_tuple(number<NumIssues>{}, // n0
|
||||
number<LaneGroups>{}, // n1
|
||||
number<NumWarps>{}, // n2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<kKPerBlock>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<KVector>{},
|
||||
number<1>{}),
|
||||
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
|
||||
number<KVector>{},
|
||||
number<1>{});
|
||||
|
||||
// TODO this layout is hard coded, and will be used in async copy buffer view load
|
||||
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
|
||||
constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<NumIssues>{}),
|
||||
make_pass_through_transform(number<NumWarps>{}),
|
||||
make_merge_transform(make_tuple(
|
||||
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
return k_lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor()
|
||||
{
|
||||
// K is always k-major, we use async-copy to load into LDS
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
|
||||
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
// constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad);
|
||||
// constexpr index_t SingleVSize =
|
||||
// MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
constexpr index_t BufferSize =
|
||||
GetSingleSmemElementSpaceSize<Problem>(); // max(SingleKSize, SingleVSize);
|
||||
|
||||
constexpr auto k_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumKVLdsBuffers>{}, // num_buffers
|
||||
number<NumIssues>{}, // n0
|
||||
number<NumWarps>{}, // n2
|
||||
number<LaneGroups>{}, // n1
|
||||
number<kKPerBlock / KPack>{}, // k0
|
||||
number<KPack>{}), // k1
|
||||
make_tuple(number<BufferSize>{},
|
||||
number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<kKPerBlock>{},
|
||||
number<KPack>{},
|
||||
number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<NumKVLdsBuffers>{},
|
||||
number<NumIssues>{},
|
||||
number<LaneGroups>{},
|
||||
number<NumWarps>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<0, 1, 3, 2>{}, sequence<4, 5>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return k_lds_block_desc;
|
||||
}
|
||||
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
|
||||
{
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t Banks = get_n_lds_banks();
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
constexpr index_t NPerRow = PixelsPerRow / kKPack;
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
|
||||
static_assert(kNPerBlock % NPerRow == 0);
|
||||
static_assert(kKPerBlock % kKPack == 0);
|
||||
|
||||
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumKVLdsBuffers>{},
|
||||
number<kKPerBlock / kKPack>{},
|
||||
number<kNPerBlock / NPerRow>{},
|
||||
number<NPerRow>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<GetSingleSmemElementSpaceSize<Problem>()>{},
|
||||
number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
|
||||
number<PixelsPerRow + kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(
|
||||
number<NumKVLdsBuffers>{}, number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return v_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
|
||||
{
|
||||
// TODO: assume Q is in register
|
||||
// TODO: assume K and V share smem buffers
|
||||
using KLdsDataType = SageAttnQKGemmKDataType<Problem>;
|
||||
constexpr index_t single_smem_size =
|
||||
GetSingleSmemElementSpaceSize<Problem>() * sizeof(KLdsDataType);
|
||||
|
||||
return QXPolicy::template GetSmemSizeQ<Problem>() + single_smem_size * NumKVLdsBuffers;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return GetSmemSizeKV<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
|
||||
{
|
||||
if constexpr(!AsyncCopy)
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
|
||||
|
||||
constexpr index_t MaxVectorSize = GetLogicalVectorSize<KDataType>(16);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
|
||||
constexpr index_t K1 = min(MaxVectorSize, ElemPerThread);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
|
||||
constexpr index_t N0 = NumIssues;
|
||||
constexpr index_t N1 = LaneGroups;
|
||||
constexpr index_t N2 = NumWarps;
|
||||
constexpr index_t K0 = LanesPerK;
|
||||
constexpr index_t K1 = KVector;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<2>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockSageAttnShape::VLayout>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1; // P
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
|
||||
{
|
||||
static_assert(kNPerBlock % 16 == 0);
|
||||
constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = 2;
|
||||
constexpr index_t N1_m = kNPack / N2;
|
||||
constexpr index_t N0_m = kNPerBlock / kNPack;
|
||||
constexpr index_t K1 = get_warp_size() / N1_m;
|
||||
constexpr index_t K2_m = kKPerBlock / K1 / K0;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<1, 2, 1>, // N0 K2 N2
|
||||
sequence<0, 2, 2>>{});
|
||||
}
|
||||
else if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * N0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
|
||||
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
static_assert(N0 != 0);
|
||||
|
||||
constexpr auto dstr = make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K0
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>, // N0 K1
|
||||
sequence<0, 1>>{});
|
||||
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kNPerBlock * kKPerBlock)
|
||||
{
|
||||
return dstr;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKPerBlock % 16 == 0);
|
||||
constexpr index_t kKPerIter = kKPerBlock % 32 == 0 ? 32 : 16;
|
||||
constexpr index_t K0_m = kKPerBlock / kKPerIter;
|
||||
constexpr index_t K2 = 2;
|
||||
constexpr index_t K1_m = kKPerIter / K2;
|
||||
constexpr index_t N2_m = get_warp_size() / K1_m;
|
||||
constexpr index_t N0_m = kNPerBlock / (N2_m * N1);
|
||||
constexpr auto dstr_m = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<N0_m, N1, N2_m>, sequence<K0_m, K1_m, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K1
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<2, 1, 2>, // K0 N0 K2
|
||||
sequence<0, 0, 2>>{});
|
||||
static_assert(container_reduce(dstr_m.get_lengths(),
|
||||
std::multiplies<index_t>{},
|
||||
1) == kNPerBlock * kKPerBlock);
|
||||
return dstr_m;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegBlockDescriptor()
|
||||
{
|
||||
// This descriptor only used when V layout is seqlen * hdim
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockSageAttnShape::VLayout>;
|
||||
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
|
||||
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
|
||||
{
|
||||
static_assert(kNPerBlock % 16 == 0);
|
||||
constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = 2;
|
||||
constexpr index_t N1_m = kNPack / N2;
|
||||
constexpr index_t N0_m = kNPerBlock / kNPack;
|
||||
constexpr index_t K1 = get_warp_size() / N1_m;
|
||||
constexpr index_t K2_m = kKPerBlock / K1 / K0;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<1, 1, 2>, // N0 K2 <-> N2
|
||||
sequence<0, 2, 2>>{});
|
||||
}
|
||||
else if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * N0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
|
||||
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
|
||||
{
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::kNumGemm1Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::BlockSageAttnShape::kM0,
|
||||
Problem::BlockSageAttnShape::kN1,
|
||||
Problem::BlockSageAttnShape::kK1>,
|
||||
typename Problem::BlockSageAttnShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockSageAttnShape::Gemm1WarpTile>>;
|
||||
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr(get_warp_size() == 64 &&
|
||||
std::is_same_v<typename Problem::PDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::VDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::OaccDataType, float>)
|
||||
{
|
||||
static_assert(Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<0>{}) == 32);
|
||||
static_assert(Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<1>{}) == 32);
|
||||
static_assert(Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<2>{}) == 32);
|
||||
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return WarpGemmDispatcher<
|
||||
typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true>{};
|
||||
}
|
||||
}();
|
||||
|
||||
using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
|
||||
typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
typename Problem::BlockSageAttnShape::Gemm1BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for the synchronous QRKSVS pipeline: Q loaded once into
// registers, no async copy, and a single prefetch buffer each for K and V.
using BlockSageAttentionPipelineQRKSVSDefaultPolicy =
    BlockSageAttnPipelineQRKSVSCustomPolicy</* QLoadOnce = */ true,
                                            /* AsyncCopy = */ false,
                                            /* NumPrefetchK = */ 1,
                                            /* NumPrefetchV = */ 1>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,71 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t Headdim>
|
||||
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length()
|
||||
{
|
||||
if constexpr(Headdim == 48)
|
||||
return 48;
|
||||
else if constexpr(Headdim == 80)
|
||||
return 96;
|
||||
else if constexpr(Headdim == 96)
|
||||
return 128;
|
||||
else if constexpr(Headdim == 160)
|
||||
return 256;
|
||||
else if constexpr(Headdim == 192)
|
||||
return 192;
|
||||
else if constexpr(is_power_of_two_integer(Headdim))
|
||||
return Headdim;
|
||||
else
|
||||
static_assert(Headdim == 0,
|
||||
"only Headdim of 48, 96, 160, 192 and power-of-two is supported");
|
||||
};
|
||||
|
||||
// Compile-time tile-shape descriptor for the SageAttention kernel: unpacks the
// six-entry BlockTile sequence into named extents and carries the warp layouts
// and warp tiles of both gemms plus the V memory layout.
template <typename BlockTile_, // sequence<kM0, kN0, kK0, kN1, kK1, kQKHeaddim>
          typename Gemm0BlockWarps_,
          typename Gemm0WarpTile_,
          typename Gemm1BlockWarps_,
          typename Gemm1WarpTile_,
          bool IsVLayoutRowMajor_>
struct TileSageAttnShape
{
    using BlockTile       = remove_cvref_t<BlockTile_>;
    using Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>;
    using Gemm0WarpTile   = remove_cvref_t<Gemm0WarpTile_>;
    using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
    using Gemm1WarpTile   = remove_cvref_t<Gemm1WarpTile_>;

    // warp counts are the product of each gemm's block-warp layout
    static constexpr index_t NumGemm0Warps =
        reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
    static constexpr index_t NumGemm1Warps =
        reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{});
    static_assert(NumGemm1Warps % NumGemm0Warps == 0);

    // warps launched per block: enough for the larger of the two gemms
    static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);

    static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
    static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
    static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
    static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
    static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
    static constexpr index_t kQKHeaddim =
        BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
                                    // once (or repeatedly load Q as a whole tile)
    static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");

    // head-dim rounded up to a tile length supported by the kernels
    static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length<kQKHeaddim>();

    // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
    static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
    using VLayout = std::conditional_t<IsVLayoutRowMajor,
                                       ck_tile::tensor_layout::gemm::RowMajor,
                                       ck_tile::tensor_layout::gemm::ColumnMajor>;
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* padding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* padding for hdim_v */,
|
||||
BlockSageAttentionQuantScaleEnum QScaleEnum_,
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
|
||||
struct TileSageAttnTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
|
||||
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr auto QScaleEnum = QScaleEnum_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
|
||||
/// Tokens per Q/K descale along seqlen. Fine-to-coarse: PERTHREAD, PERWARP, then 128 for Q
|
||||
/// (BLOCKSCALE / no_scale / pertensor). K: PERWARP 64, BLOCKSCALE 128, else 128.
|
||||
static constexpr index_t kBlockScaleSizeQ =
|
||||
QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERTHREAD ? 4
|
||||
: QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERWARP ? 32
|
||||
: 128;
|
||||
static constexpr index_t kBlockScaleSizeK =
|
||||
QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERTHREAD ? 16
|
||||
: QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERWARP ? 64
|
||||
: 128;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
17
include/ck_tile/ops/sageattn.hpp
Normal file
17
include/ck_tile/ops/sageattn.hpp
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/sageattention/kernel/sageattn_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/tile_sageattn_shape.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/tile_sageattn_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
Reference in New Issue
Block a user