[CK_TILE][FMHA] Add sparse attention VSA (#3341)

* add sparse attention VSA

* fix the pre-commit

* Add jenga test and pre-commit

* add bf16 for vsa

* add jenga support bf16

* remove lse arg

* split kernel code to block & kernel

* fix the pre-commit

* fix the pre-commit

* fix the copyrights

* fix the copyright

* fix the copyright & rename block to pipeline

* fix the copyright and pipeline

* remove lse & dropout & add fmt

* fix the jenga&VSA code review

* remove the useless code & resolved the comments

* remove useless code

* remove useless code

* Clean up code

* Remove more unused code

* Re-format .hpp

* Refactor codegen scripts

---------

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
jiangyon.ren
2026-01-31 00:59:47 +08:00
committed by GitHub
parent 2377a62837
commit 4d2f8c111e
22 changed files with 6058 additions and 0 deletions

View File

@@ -0,0 +1,156 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CMakeLists.txt for sparse attention (Jenga and VSA)
# Use SUPPORTED_GPU_TARGETS directly
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS})
message(STATUS "Sparse Attention: SUPPORTED_GPU_TARGETS=${SUPPORTED_GPU_TARGETS}, INST_TARGETS=${INST_TARGETS}")
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12")
if(NOT INST_TARGETS)
message(WARNING "Skipping Tile Engine Sparse Attention: No supported GPU targets found")
return()
endif()
message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}")
# Code generation scripts
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
${CMAKE_CURRENT_LIST_DIR}/generate.py
${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
)
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")
# ============================================================================
# Jenga Sparse Attention
# ============================================================================
set(SPARSE_ATTN_JENGA_CODE_GEN_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--api fwd_jenga
--receipt 600
)
# Generate list of Jenga kernels (at configure time, only list)
execute_process(
COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_JENGA_CODE_GEN_ARGS}
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/jenga_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to generate Jenga kernel list")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/jenga_blob_list.txt SPARSE_ATTN_JENGA_GEN_BLOBS)
# Generate Jenga kernel source files at build time
add_custom_command(
OUTPUT ${SPARSE_ATTN_JENGA_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_JENGA_CODE_GEN_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
COMMENT "Generate CK Tile Jenga Sparse Attention kernels"
)
message(STATUS "Jenga kernel files to be generated: ${SPARSE_ATTN_JENGA_GEN_BLOBS}")
# Jenga Instances
set(SPARSE_ATTN_JENGA_INSTANCES "tile_sparse_attn_jenga_instances")
add_library(${SPARSE_ATTN_JENGA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
${SPARSE_ATTN_JENGA_GEN_BLOBS}
${CMAKE_CURRENT_LIST_DIR}/jenga_sparse_attention.cpp
)
target_include_directories(${SPARSE_ATTN_JENGA_INSTANCES} PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
)
set_source_files_properties(${SPARSE_ATTN_JENGA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/jenga_sparse_attention.cpp PROPERTIES LANGUAGE HIP)
set_property(TARGET ${SPARSE_ATTN_JENGA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
target_compile_options(${SPARSE_ATTN_JENGA_INSTANCES} PRIVATE
-DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
-DCK_TILE_FMHA_FWD_FAST_EXP2
-Wno-undefined-func-template
-Wno-float-equal
)
# Jenga Example executable
set(EXAMPLE_JENGA_SPARSE_ATTN "tile_example_jenga_sparse_attn")
message(DEBUG "adding example ${EXAMPLE_JENGA_SPARSE_ATTN}")
add_executable(${EXAMPLE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_jenga_sparse_attn.cpp)
target_link_libraries(${EXAMPLE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES})
target_include_directories(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
)
# ============================================================================
# VSA Sparse Attention
# ============================================================================
set(SPARSE_ATTN_VSA_CODE_GEN_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--api fwd_vsa
--receipt 600
)
# Generate list of VSA kernels (at configure time, only list)
execute_process(
COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_VSA_CODE_GEN_ARGS}
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/vsa_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to generate VSA kernel list")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/vsa_blob_list.txt SPARSE_ATTN_VSA_GEN_BLOBS)
# Generate VSA kernel source files at build time
add_custom_command(
OUTPUT ${SPARSE_ATTN_VSA_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_VSA_CODE_GEN_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
COMMENT "Generate CK Tile VSA Sparse Attention kernels"
)
message(STATUS "VSA kernel files to be generated: ${SPARSE_ATTN_VSA_GEN_BLOBS}")
# VSA Instances
set(SPARSE_ATTN_VSA_INSTANCES "tile_sparse_attn_vsa_instances")
add_library(${SPARSE_ATTN_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
${SPARSE_ATTN_VSA_GEN_BLOBS}
${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cpp
)
target_include_directories(${SPARSE_ATTN_VSA_INSTANCES} PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
)
set_source_files_properties(${SPARSE_ATTN_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cpp PROPERTIES LANGUAGE HIP)
set_property(TARGET ${SPARSE_ATTN_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
target_compile_options(${SPARSE_ATTN_VSA_INSTANCES} PRIVATE
-DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
-DCK_TILE_FMHA_FWD_FAST_EXP2
-Wno-undefined-func-template
-Wno-float-equal
)
# VSA Example executable
set(EXAMPLE_VSA_SPARSE_ATTN "tile_example_vsa_sparse_attn")
message(DEBUG "adding example ${EXAMPLE_VSA_SPARSE_ATTN}")
add_executable(${EXAMPLE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_vsa_sparse_attn.cpp)
target_link_libraries(${EXAMPLE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES})
target_include_directories(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
)
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -0,0 +1,3 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,73 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
FWD_DTYPE_MAP = {
"fp16": "FmhaSparseFwdFp16",
"bf16": "FmhaSparseFwdBf16",
}
_MASK_SIMPLIFIED_MAP = {
"s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
}
_MASK_MAP = {
"no": "FmhaMasks::NoMask",
"causal": "FmhaMasks::CausalMask",
"generic": "FmhaMasks::GenericMask",
}
def get_mask_map(mask: str):
if mask == "generic":
return _MASK_MAP
elif mask == "simplified":
return _MASK_SIMPLIFIED_MAP
else:
assert False
return None
_MASK_CHECK_MAP = {
"no": "t.mask_type == mask_enum::no_mask",
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic": "t.mask_type == mask_enum::window_generic",
}
_MASK_SIMPLIFIED_CHECK_MAP = {
"s_no": "t.mask_type == mask_enum::no_mask",
"s_mask": "t.mask_type != mask_enum::no_mask",
}
def get_mask_check_map(mask: str):
if mask == "generic":
return _MASK_CHECK_MAP
elif mask == "simplified":
return _MASK_SIMPLIFIED_CHECK_MAP
else:
assert False
return None
MODE_MAP = {"batch": "false"}
LAYOUT_MAP = {"row": "true", "col": "false"}
PIPELINE_MAP = {
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga",
"qr_async_vsa": "ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA",
}
PIPELINE_ENUM_MAP = {
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_async_vsa": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}
BOOL_MAP = {
"t": "true",
"f": "false",
True: "true",
False: "false",
}

View File

@@ -0,0 +1,3 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,867 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass, field
import fnmatch
import itertools
import os
import os.path as path
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cpp_symbol_map import (
BOOL_MAP,
FWD_DTYPE_MAP,
LAYOUT_MAP,
MODE_MAP,
PIPELINE_ENUM_MAP,
PIPELINE_MAP,
get_mask_check_map,
get_mask_map,
)
GEN_DIR = ""
def update_file(file_path, content):
"""Update the file at file_path with the given content if it differs from the existing content.
It avoids unnecessary touching of the file which triggers rebuilds
"""
existing_content = ""
if path.exists(file_path):
with open(file_path, "r") as file:
existing_content = file.read()
if existing_content == content:
return
with open(file_path, "w") as file:
file.write(content)
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd_trek.hpp"
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp"
#include "kernel/fmha_fwd_jenga_kernel.hpp"
"""
# NOTE: Jenga sparse attention kernel has the following restrictions enforced by static_assert:
# - Group mode: NOT supported (batch mode only)
# - Bias: NOT supported (NO_BIAS only)
# - LSE output: NOT supported (false only)
# - Dropout: NOT supported (false only)
# - Logits soft-cap: NOT supported (false only)
# - FP8 static quantization: NOT supported (NO_SCALE only)
# The template below hardcodes these unsupported features accordingly.
FMHA_FWD_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
false, // has_logits_soft_cap - NOT supported
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
false, // store_lse - NOT supported
false, // has_dropout - NOT supported
false, // has_randval - NOT supported
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
{F_occupancy},
false>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
{F_trload},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdJengaKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
#include <iostream>
template<>
float fmha_jenga_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_jenga_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << "{F_kernel_name}" << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_API_FILENAME = "fmha_jenga_fwd_api.cpp"
FMHA_FWD_API = """
#include <cstdio>
#include <hip/hip_runtime.h>
namespace {{
bool get_num_cus(unsigned& num_cus) {{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device");
return false;
}}
hipDeviceProp_t props{{}};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device properties");
return false;
}}
num_cus = props.multiProcessorCount;
return true;
}}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace
float fmha_jenga_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
return r;
}}
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
const bool has_load_tr = ck_tile::is_load_tr_supported();
{F_dispatch}
return r;
}}
"""
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
{F_dtype_case}
}}
"""
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
return fmha_jenga_fwd_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
bool_expr: str = None
def __str__(self):
if self.bool_expr is None:
return "true"
else:
return f"{self.bool_expr}"
def __and__(self, other):
return CppConstraint(f"({str(self)}) && ({str(other)})")
@dataclass
class FmhaFwdApiTrait:
pipeline_tag: str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim: str
dtype: str # data type
mode: str # value from MODE_MAP
bm0: int # tile size along q seqlen (block size)
bn0: int # tile size along qk seqlen
bk0: int # tile size along qk gemm unroll
bn1: int # tile size along v head_dim
bk1: int # tile size along kv gemm unroll
bk0max: int
vlayout: str
logits: str
mask: str
spad: str
skpad: str
dpad: str
dvpad: str
tr_load: str
constraint: CppConstraint
@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
)
@property
def scheck(self) -> str:
if self.mode == "group":
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
if self.spad == "t":
return "true" # always support
return "true"
@property
def seqtune(self) -> str:
if self.bm0 == 128:
return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true
else:
return f"a.seqlen_q <= {self.bm0}"
@property
def skcheck(self) -> str:
if self.mode == "group":
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
if self.skpad == "t":
return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
@property
def dcheck(self) -> str:
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == "t":
return f"a.hdim_q % {vec} == 0"
assert False
@property
def dvcheck(self) -> str:
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == "t":
return f"a.hdim_v % {vec} == 0"
assert False
@dataclass
class FmhaFwdPipeline:
tag: str
F_vlayout: str # row/col
F_spad: str # true/false
F_skpad: str #
F_dpad: str #
F_dvpad: str #
F_logits: str # t/f
F_mask: str # value from MASK_MAP
F_trload: str # true/false
F_constraint: CppConstraint = field(default_factory=CppConstraint)
@property
def name(self) -> str:
def pad_name() -> str:
n = ""
if self.F_spad == "t":
n += "s"
if self.F_skpad == "t":
n += "sk"
if self.F_dpad == "t":
n += "d"
if self.F_dvpad == "t":
n += "dv"
if n != "":
n = "p" + n
return n
pn = pad_name()
n = f"{self.tag}_v{self.F_vlayout[0]}"
if pn != "":
n += f"_{pn}"
else:
n += "_npad"
if self.F_logits == "t":
n += "_logits"
else:
n += "_nlogits"
n += "_nbias"
if self.F_mask[0:2] == "s_":
if self.F_mask == "s_mask":
n += "_mask"
else:
n += "_nmask"
else:
if self.F_mask != "no":
n += f"_m{self.F_mask[0]}"
else:
n += "_nmask"
n += "_nskip"
n += "_nsquant"
if self.F_trload == "t":
n += "_trload"
else:
n += "_ntrload"
return n
class FmhaFwdApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
hdim = trait.hdim, trait.bn1
if hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][hdim] = list()
self.pool[trait.dtype][hdim].append(copy.copy(trait))
@property
def api(self) -> str:
tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
per_tr_load = str()
for tr_load in ["t", "f"]:
per_dtypes = str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case = str()
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
traits = [
t
for t in self.pool[dtype][(hdim, hdim_v)]
if tr_load == t.tr_load
]
inners = str()
for k, trait in enumerate(traits):
if_k = "if" if k == 0 else "else if"
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
F_if=if_k,
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
# F_logits removed - hardcoded to false (NOT supported)
F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
F_trload=BOOL_MAP[trait.tr_load],
F_scheck=trait.scheck,
F_seqtune=trait.seqtune,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0,
F_bn0=trait.bn0,
F_bk0=trait.bk0,
F_bn1=trait.bn1,
F_bk1=trait.bk1,
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
)
per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
F_if="if",
F_trload_cond=tr_load_cond_map[tr_load],
F_dtype_case=per_dtypes,
)
if not per_tr_load:
# empty string we add some ignore to suppress warning in api
per_tr_load += " (void)t ; (void)s ; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
@dataclass
class FmhaFwdTileSize:
F_bm0: int # tile size along q seqlen (block size)
F_bn0: int # tile size along k seqlen
F_bk0: int # tile size along qk gemm unroll
F_bn1: int # tile size along v head_dim
F_bk1: int # tile size along kv gemm unroll
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0: int # number of warps for gemm0 along q seqlen
F_rn0: int # number of warps for gemm0 along k seqlen
F_rk0: int # number of warps for gemm0 along head dim q (not used)
F_rm1: int # number of warps for gemm1 along q seqlen
F_rn1: int # number of warps for gemm1 along head dim v
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
F_wm0: int # gemm0 warp size along m
F_wn0: int # gemm0 warp size along n
F_wk0: int # gemm0 warp size along k
F_wm1: int # gemm1 warp size along m
F_wn1: int # gemm1 warp size along n
F_wk1: int # gemm1 warp size along k
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_constraint: CppConstraint = field(default_factory=CppConstraint)
@property
def name(self) -> str:
return (
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
)
@dataclass
class FmhaFwdKernel:
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
F_mode: str # value from MODE_MAP
F_tile: FmhaFwdTileSize
F_pipeline: FmhaFwdPipeline
mask_impl: str
@property
def template(self) -> str:
# kernel_body removed - unused
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
F_idx=self.F_idx,
F_hdim=self.F_hdim,
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
F_bm0=self.F_tile.F_bm0,
F_bn0=self.F_tile.F_bn0,
F_bk0=self.F_tile.F_bk0,
F_bn1=self.F_tile.F_bn1,
F_bk1=self.F_tile.F_bk1,
F_bk0max=self.F_tile.F_bk0max,
F_rm0=self.F_tile.F_rm0,
F_rn0=self.F_tile.F_rn0,
F_rk0=self.F_tile.F_rk0,
F_rm1=self.F_tile.F_rm1,
F_rn1=self.F_tile.F_rn1,
F_rk1=self.F_tile.F_rk1,
F_wm0=self.F_tile.F_wm0,
F_wn0=self.F_tile.F_wn0,
F_wk0=self.F_tile.F_wk0,
F_wm1=self.F_tile.F_wm1,
F_wn1=self.F_tile.F_wn1,
F_wk1=self.F_tile.F_wk1,
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
# F_logits removed - hardcoded to false in template (NOT supported)
F_occupancy=self.F_tile.F_occupancy,
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode=MODE_MAP[self.F_mode],
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
F_trload=BOOL_MAP[self.F_pipeline.F_trload],
F_kernel_name=self.name,
)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return (
f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
)
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
logits=self.F_pipeline.F_logits,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
tr_load=self.F_pipeline.F_trload,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
)
class KernelComponentFactory:
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype == "fp16" or dtype == "bf16":
return {
# (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
# FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128, 128): [
FmhaFwdTileSize( # fmt: skip
16,
32,
64,
128,
32,
128,
1,
1,
1,
1,
1,
1,
16,
16,
32,
16,
16,
32,
-1,
),
FmhaFwdTileSize( # fmt: skip
32,
32,
128,
128,
32,
128,
1,
1,
1,
1,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
),
FmhaFwdTileSize( # fmt: skip
128,
64,
32,
128,
16,
128,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
),
FmhaFwdTileSize( # fmt: skip
128,
128,
32,
128,
32,
128,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
),
],
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
# (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
# (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
}
else:
return None
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
@staticmethod
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# NOTE: logits soft-cap is NOT supported by Jenga sparse attention (enforced by static_assert)
pipelines = []
if dtype in ["fp16", "bf16"]:
for logits, mask in itertools.product(
["f"], # logits soft-cap NOT supported, always false
get_mask_map(mask_impl).keys(),
):
if hdim == 256 and hdim_v == 256:
# jenga fmha only supports dim <= 192 for now.
continue
pipelines.append(
FmhaFwdPipeline( # fmt: skip
"qr_async",
"row",
"t",
"f",
"t",
"t",
logits,
mask,
"f",
)
)
pipelines.append(
FmhaFwdPipeline( # fmt: skip
"qr_async",
"row",
"t",
"t",
"t",
"t",
logits,
mask,
"f",
)
)
else:
assert False
return pipelines
class CustomFactory(KernelComponentFactory):
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
if dtype == "fp16" or dtype == "bf16":
if (128, 128) in result.keys():
result[(128, 128)].insert(
0,
FmhaFwdTileSize(
64,
128,
64,
128,
64,
128,
4,
1,
1,
4,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
CppConstraint(
"get_num_blocks(128) < num_cus * min_cu_util_rate"
),
),
)
return result
def get_fwd_blobs(
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
factory = (
CustomFactory
if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
else KernelComponentFactory
)
# Only generate fp16/bf16 kernels for now.
# NOTE: Jenga sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert)
for dtype in ["fp16", "bf16"]:
d = factory.get_hdim_tile_size_dict(dtype)
if d is None:
continue
for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
for tile, pipeline in itertools.product(
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
):
if tile.F_bm0 != 128 or tile.F_bn0 != 128:
continue
if pipeline.tag != "qr_async":
continue
k = FmhaFwdKernel(
F_idx=2,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= mode == "batch"
cond &= pipeline.F_logits == "f"
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# aiter::mha_fwd C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
update_file(autogen_dir / kernel.filename, kernel.template)
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
def write_blobs(
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
def list_blobs(
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")

View File

@@ -0,0 +1,867 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass, field
import fnmatch
import itertools
import os
import os.path as path
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cpp_symbol_map import (
BOOL_MAP,
FWD_DTYPE_MAP,
LAYOUT_MAP,
MODE_MAP,
PIPELINE_ENUM_MAP,
PIPELINE_MAP,
get_mask_check_map,
get_mask_map,
)
GEN_DIR = ""
def update_file(file_path, content):
"""Update the file at file_path with the given content if it differs from the existing content.
It avoids unnecessary touching of the file which triggers rebuilds
"""
existing_content = ""
if path.exists(file_path):
with open(file_path, "r") as file:
existing_content = file.read()
if existing_content == content:
return
with open(file_path, "w") as file:
file.write(content)
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd_trek.hpp"
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp"
#include "kernel/fmha_fwd_vsa_kernel.hpp"
"""
# NOTE: VSA sparse attention kernel has the following restrictions enforced by static_assert:
# - Group mode: NOT supported (batch mode only)
# - Bias: NOT supported (NO_BIAS only)
# - LSE output: NOT supported (false only)
# - Dropout: NOT supported (false only)
# - Logits soft-cap: NOT supported (false only)
# - FP8 static quantization: NOT supported (NO_SCALE only)
# The template below hardcodes these unsupported features accordingly.
FMHA_FWD_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
false, // has_logits_soft_cap - NOT supported
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
false, // store_lse - NOT supported
false, // has_dropout - NOT supported
false, // has_randval - NOT supported
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
{F_occupancy},
false>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
{F_trload},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdVSAKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
#include <iostream>
template<>
float fmha_vsa_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_vsa_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << "{F_kernel_name}" << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_API_FILENAME = "fmha_vsa_fwd_api.cpp"
FMHA_FWD_API = """
#include <cstdio>
#include <hip/hip_runtime.h>
namespace {{
bool get_num_cus(unsigned& num_cus) {{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device");
return false;
}}
hipDeviceProp_t props{{}};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device properties");
return false;
}}
num_cus = props.multiProcessorCount;
return true;
}}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace
float fmha_vsa_fwd(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
return r;
}}
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
const bool has_load_tr = ck_tile::is_load_tr_supported();
{F_dispatch}
return r;
}}
"""
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
{F_dtype_case}
}}
"""
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
return fmha_vsa_fwd_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
bool_expr: str = None
def __str__(self):
if self.bool_expr is None:
return "true"
else:
return f"{self.bool_expr}"
def __and__(self, other):
return CppConstraint(f"({str(self)}) && ({str(other)})")
@dataclass
class FmhaFwdApiTrait:
pipeline_tag: str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim: str
dtype: str # data type
mode: str # value from MODE_MAP
bm0: int # tile size along q seqlen (block size)
bn0: int # tile size along qk seqlen
bk0: int # tile size along qk gemm unroll
bn1: int # tile size along v head_dim
bk1: int # tile size along kv gemm unroll
bk0max: int
vlayout: str
logits: str
mask: str
spad: str
skpad: str
dpad: str
dvpad: str
tr_load: str
constraint: CppConstraint
@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
)
@property
def scheck(self) -> str:
if self.mode == "group":
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
if self.spad == "t":
return "true" # always support
return "true"
@property
def seqtune(self) -> str:
if self.bm0 == 128:
return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true
else:
return f"a.seqlen_q <= {self.bm0}"
@property
def skcheck(self) -> str:
if self.mode == "group":
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
if self.skpad == "t":
return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
@property
def dcheck(self) -> str:
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == "t":
return f"a.hdim_q % {vec} == 0"
assert False
@property
def dvcheck(self) -> str:
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == "t":
return f"a.hdim_v % {vec} == 0"
assert False
@dataclass
class FmhaFwdPipeline:
tag: str
F_vlayout: str # row/col
F_spad: str # true/false
F_skpad: str #
F_dpad: str #
F_dvpad: str #
F_logits: str # t/f
F_mask: str # value from MASK_MAP
F_trload: str # true/false
F_constraint: CppConstraint = field(default_factory=CppConstraint)
@property
def name(self) -> str:
def pad_name() -> str:
n = ""
if self.F_spad == "t":
n += "s"
if self.F_skpad == "t":
n += "sk"
if self.F_dpad == "t":
n += "d"
if self.F_dvpad == "t":
n += "dv"
if n != "":
n = "p" + n
return n
pn = pad_name()
n = f"{self.tag}_v{self.F_vlayout[0]}"
if pn != "":
n += f"_{pn}"
else:
n += "_npad"
if self.F_logits == "t":
n += "_logits"
else:
n += "_nlogits"
n += "_nbias"
if self.F_mask[0:2] == "s_":
if self.F_mask == "s_mask":
n += "_mask"
else:
n += "_nmask"
else:
if self.F_mask != "no":
n += f"_m{self.F_mask[0]}"
else:
n += "_nmask"
n += "_nskip"
n += "_nsquant"
if self.F_trload == "t":
n += "_trload"
else:
n += "_ntrload"
return n
class FmhaFwdApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
hdim = trait.hdim, trait.bn1
if hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][hdim] = list()
self.pool[trait.dtype][hdim].append(copy.copy(trait))
@property
def api(self) -> str:
tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
per_tr_load = str()
for tr_load in ["t", "f"]:
per_dtypes = str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case = str()
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
traits = [
t
for t in self.pool[dtype][(hdim, hdim_v)]
if tr_load == t.tr_load
]
inners = str()
for k, trait in enumerate(traits):
if_k = "if" if k == 0 else "else if"
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
F_if=if_k,
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
# F_logits removed - hardcoded to false (NOT supported)
F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
F_trload=BOOL_MAP[trait.tr_load],
F_scheck=trait.scheck,
F_seqtune=trait.seqtune,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0,
F_bn0=trait.bn0,
F_bk0=trait.bk0,
F_bn1=trait.bn1,
F_bk1=trait.bk1,
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
)
per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
F_if="if",
F_trload_cond=tr_load_cond_map[tr_load],
F_dtype_case=per_dtypes,
)
if not per_tr_load:
# empty string we add some ignore to suppress warning in api
per_tr_load += " (void)t ; (void)s ; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
@dataclass
class FmhaFwdTileSize:
F_bm0: int # tile size along q seqlen (block size)
F_bn0: int # tile size along k seqlen
F_bk0: int # tile size along qk gemm unroll
F_bn1: int # tile size along v head_dim
F_bk1: int # tile size along kv gemm unroll
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0: int # number of warps for gemm0 along q seqlen
F_rn0: int # number of warps for gemm0 along k seqlen
F_rk0: int # number of warps for gemm0 along head dim q (not used)
F_rm1: int # number of warps for gemm1 along q seqlen
F_rn1: int # number of warps for gemm1 along head dim v
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
F_wm0: int # gemm0 warp size along m
F_wn0: int # gemm0 warp size along n
F_wk0: int # gemm0 warp size along k
F_wm1: int # gemm1 warp size along m
F_wn1: int # gemm1 warp size along n
F_wk1: int # gemm1 warp size along k
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_constraint: CppConstraint = field(default_factory=CppConstraint)
@property
def name(self) -> str:
return (
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
)
@dataclass
class FmhaFwdKernel:
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
F_mode: str # value from MODE_MAP
F_tile: FmhaFwdTileSize
F_pipeline: FmhaFwdPipeline
mask_impl: str
@property
def template(self) -> str:
# kernel_body removed - unused
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
F_idx=self.F_idx,
F_hdim=self.F_hdim,
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
F_bm0=self.F_tile.F_bm0,
F_bn0=self.F_tile.F_bn0,
F_bk0=self.F_tile.F_bk0,
F_bn1=self.F_tile.F_bn1,
F_bk1=self.F_tile.F_bk1,
F_bk0max=self.F_tile.F_bk0max,
F_rm0=self.F_tile.F_rm0,
F_rn0=self.F_tile.F_rn0,
F_rk0=self.F_tile.F_rk0,
F_rm1=self.F_tile.F_rm1,
F_rn1=self.F_tile.F_rn1,
F_rk1=self.F_tile.F_rk1,
F_wm0=self.F_tile.F_wm0,
F_wn0=self.F_tile.F_wn0,
F_wk0=self.F_tile.F_wk0,
F_wm1=self.F_tile.F_wm1,
F_wn1=self.F_tile.F_wn1,
F_wk1=self.F_tile.F_wk1,
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
# F_logits removed - hardcoded to false in template (NOT supported)
F_occupancy=self.F_tile.F_occupancy,
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode=MODE_MAP[self.F_mode],
F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
F_trload=BOOL_MAP[self.F_pipeline.F_trload],
F_kernel_name=self.name,
)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return (
f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
)
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
logits=self.F_pipeline.F_logits,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
tr_load=self.F_pipeline.F_trload,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
)
class KernelComponentFactory:
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype == "fp16" or dtype == "bf16":
return {
# (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
# FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128, 128): [
FmhaFwdTileSize( # fmt: skip
16,
32,
64,
128,
32,
128,
1,
1,
1,
1,
1,
1,
16,
16,
32,
16,
16,
32,
-1,
),
FmhaFwdTileSize( # fmt: skip
32,
32,
128,
128,
32,
128,
1,
1,
1,
1,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
),
FmhaFwdTileSize( # fmt: skip
128,
64,
32,
128,
16,
128,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
),
FmhaFwdTileSize( # fmt: skip
128,
128,
32,
128,
32,
128,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
),
],
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
# (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
# (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
}
else:
return None
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
@staticmethod
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# NOTE: logits soft-cap is NOT supported by VSA sparse attention (enforced by static_assert)
pipelines = []
if dtype in ["fp16", "bf16"]:
for logits, mask in itertools.product(
["f"], # logits soft-cap NOT supported, always false
get_mask_map(mask_impl).keys(),
):
if hdim == 256 and hdim_v == 256:
# vsa fmha only supports dim <= 192 for now.
continue
pipelines.append(
FmhaFwdPipeline(
"qr_async_vsa",
"row",
"t",
"f",
"t",
"t",
logits,
mask,
"f",
)
)
pipelines.append(
FmhaFwdPipeline(
"qr_async_vsa",
"row",
"t",
"t",
"t",
"t",
logits,
mask,
"f",
)
)
else:
assert False
return pipelines
class CustomFactory(KernelComponentFactory):
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
if dtype == "fp16" or dtype == "bf16":
if (128, 128) in result.keys():
result[(128, 128)].insert(
0,
FmhaFwdTileSize(
64,
128,
64,
128,
64,
128,
4,
1,
1,
4,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
CppConstraint(
"get_num_blocks(128) < num_cus * min_cu_util_rate"
),
),
)
return result
def get_fwd_blobs(
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
factory = (
CustomFactory
if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
else KernelComponentFactory
)
# Only generate fp16/bf16 kernels for now.
# NOTE: VSA sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert)
for dtype in ["fp16", "bf16"]:
d = factory.get_hdim_tile_size_dict(dtype)
if d is None:
continue
for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
for tile, pipeline in itertools.product(
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
):
if tile.F_bm0 != 128 or tile.F_bn0 != 128:
continue
if pipeline.tag != "qr_async_vsa":
continue
k = FmhaFwdKernel(
F_idx=1,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= mode == "batch"
cond &= pipeline.F_logits == "f"
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# aiter::mha_fwd C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
update_file(autogen_dir / kernel.filename, kernel.template)
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
def write_blobs(
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
def list_blobs(
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")

View File

@@ -0,0 +1,328 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "01_fmha/mask.hpp"
#include <type_traits>
#include <utility>
#include <variant>
namespace ck_tile {
inline bool is_load_tr_supported() { return is_gfx95_supported(); }
} // namespace ck_tile
struct FmhaSparseFwdFp16
{
};
struct FmhaSparseFwdBf16
{
};
template <typename DataType>
struct FmhaSparseFwdTypeConfig;
template <>
struct FmhaSparseFwdTypeConfig<FmhaSparseFwdFp16>
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::half_t;
// Note: The following types are required by BlockFmhaPipelineProblem but not used
// by sparse attention (bias, dropout, LSE are not supported).
using BiasDataType = ck_tile::half_t;
using RandValOutputDataType = uint8_t;
using LSEDataType = float;
};
template <>
struct FmhaSparseFwdTypeConfig<FmhaSparseFwdBf16>
{
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf16_t;
// Note: The following types are required by BlockFmhaPipelineProblem but not used
// by sparse attention (bias, dropout, LSE are not supported).
using BiasDataType = ck_tile::bf16_t;
using RandValOutputDataType = uint8_t;
using LSEDataType = float;
};
struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
// jenga
struct fmha_jenga_fwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* block_relation_onehot_ptr; // one-hot block map [B,H,Q_blk,K_blk], 1=active
void* o_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale_s;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
// Dropout is not supported for sparse attention; keep args minimal.
};
// vsa
struct fmha_vsa_fwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* lut_ptr; // delta-encoded K-block indices per Q-block, int32 [B,H,Q_blk,K_blk]
const void* valid_block_num_ptr; // valid K-block count per Q-block, int32 [B,H,Q_blk]
void* o_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale_s;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
// Dropout is not supported for sparse attention; keep args minimal.
};
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_jenga_fwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.block_relation_onehot_ptr,
args.o_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_o,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type);
dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_vsa_fwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.lut_ptr,
args.valid_block_num_ptr,
args.o_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_o,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type);
dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
bool kHasLogitsSoftCap_,
typename FmhaMask_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_,
bool kUseTrLoad_>
struct fmha_jenga_fwd_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
};
struct fmha_jenga_fwd_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_v_rowmajor;
mask_enum mask_type;
// TODO: padding check is inside this api
};
float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);
template <typename Traits_>
float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args);
float fmha_jenga_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&);
// VSA uses the same traits structure as Jenga; aliases for clarity
template <ck_tile::index_t HDim_,
typename DataType_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
bool kHasLogitsSoftCap_,
typename FmhaMask_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_,
bool kUseTrLoad_>
using fmha_vsa_fwd_traits_ = fmha_jenga_fwd_traits_<HDim_,
DataType_,
kM0_,
kN0_,
kK0_,
kN1_,
kK1_,
kK0BlockLength_,
kIsVLayoutRowMajor_,
FmhaPipelineEnum_,
kHasLogitsSoftCap_,
FmhaMask_,
kPadS_,
kPadSK_,
kPadD_,
kPadDv_,
kUseTrLoad_>;
using fmha_vsa_fwd_traits = fmha_jenga_fwd_traits;
float fmha_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&);
template <typename Traits_>
float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args);
float fmha_vsa_fwd(fmha_vsa_fwd_args, const ck_tile::stream_config&);

View File

@@ -0,0 +1,166 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import argparse
from enum import IntEnum
from pathlib import Path
import pkgutil
from typing import List, Optional
import codegen.ops
class HandlerId(IntEnum):
LIST_BLOBS = 0
WRITE_BLOBS = 1
# inspect all modules under 'codegen.ops' and register API handlers
ops = []
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
full_module_name = "%s.%s" % (codegen.ops.__name__, module_name)
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
unwanted_prefix = "fmha_"
handlers = dict(
[
(
op.__name__[len(unwanted_prefix) :]
if op.__name__.startswith(unwanted_prefix)
else op.__name__,
(op.list_blobs, op.write_blobs),
)
for op in ops
]
)
assert 0 < len(handlers)
def write_blobs(
output_dir: Optional[str],
api_list: List[str],
filters_list: List[str],
optdim_list: List[int],
receipt,
mask_impl,
) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.WRITE_BLOBS]
handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl)
# list all the files that will be generated
def list_blobs(
output_file: Optional[str],
api_list: List[str],
filters_list: List[str],
optdim_list: List[int],
receipt,
mask_impl,
) -> None:
assert output_file is not None
file_path = Path(output_file)
# create an empty file / drop its contents if it exists
open(file_path, "w").close()
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, optdim_list, mask_impl)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="generate",
description="gen API for CK fmha kernel",
)
parser.add_argument(
"-d",
"--direction", # we keep 'direction' option for backward compatibility
"-a",
"--api",
default="fwd_jenga",
required=False,
help="supply API(s) to generate (default: fwd). separated by comma.",
)
parser.add_argument(
"-o",
"--output_dir",
required=False,
help="write all the blobs into a directory",
)
parser.add_argument(
"-l", "--list_blobs", required=False, help="list all the kernels to a file"
)
# TODO: if using filter, must apply same value to output_dir and list_blobs
parser.add_argument(
"-f",
"--filter",
default="",
required=False,
help="filter out kernels that need to generate, using fnmatch module",
)
parser.add_argument(
"-m",
"--mask",
default="simplified",
required=False,
help="mask implementation, simplified/generic",
)
parser.add_argument(
"-r",
"--receipt",
default=0,
required=False,
help="codegen receipt. 0: generate only 8xhdim coverage\n"
+ " 1: generate more instance to cover all hdim\n"
+ " 2: Only generate instance for Flash attention integration\n"
+ " 4: Only generate instance for PyTorch integration\n"
+ " 100-199: Only generate instance for Aiter(mha_fwd) integration\n"
+ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n"
+ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n"
+ " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n"
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration",
)
parser.add_argument(
"--optdim",
default="-1",
required=False,
help="only optimize the hdim in the list. separated by comma. -1 is the default choice"
+ "eg. --optdim=32,64,128,256",
)
args = parser.parse_args()
api_list = args.direction.split(",")
filter_list = args.filter.split(",")
filter_list.extend([""] * (len(api_list) - len(filter_list)))
optdim_list = [int(hdim) for hdim in args.optdim.split(",")]
if args.list_blobs is not None:
list_blobs(
args.list_blobs,
api_list,
filter_list,
optdim_list,
int(args.receipt),
mask_impl=args.mask,
)
else:
write_blobs(
args.output_dir,
api_list,
filter_list,
optdim_list,
int(args.receipt),
mask_impl=args.mask,
)

View File

@@ -0,0 +1,199 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "jenga_sparse_attention.h"
#include "fmha_fwd_trek.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/device_memory.hpp"
#include <type_traits>
template <typename DataType_>
ck_tile::HostTensor<DataType_>
jenga_sparse_attention(const ck_tile::HostTensor<DataType_>& TQ,
const ck_tile::HostTensor<DataType_>& TK,
const ck_tile::HostTensor<DataType_>& TV,
const ck_tile::HostTensor<uint8_t>& Tblock_relation_onehot,
ck_tile::HostTensor<DataType_>& Y,
int batch,
int nhead,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
int hdim_v,
bool i_perm,
bool o_perm,
int max_seqlen_q,
int max_seqlen_k,
int log_level)
{
static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
std::is_same_v<DataType_, ck_tile::bf16_t>,
"Jenga sparse attention supports fp16/bf16 only.");
// Determine data type string based on template parameter
std::string data_type = "fp16";
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
{
data_type = "bf16";
}
if(max_seqlen_q == 0)
max_seqlen_q = seqlen_q;
if(max_seqlen_k == 0)
max_seqlen_k = seqlen_k;
bool is_v_rowmajor = true;
float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
std::string msk_str = "0";
mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);
const ck_tile::index_t shape_seqlen_q = seqlen_q;
const ck_tile::index_t shape_seqlen_k = seqlen_k;
ck_tile::stream_config stream_config{nullptr,
false, // time_kernel
log_level,
0,
1,
false};
// Create device memory and copy data to device
ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
ck_tile::DeviceMem block_relation_buf(Tblock_relation_onehot.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());
q_buf.ToDevice(TQ.data());
k_buf.ToDevice(TK.data());
v_buf.ToDevice(TV.data());
block_relation_buf.ToDevice(Tblock_relation_onehot.data());
const auto init_args = [&](auto& args) {
assert(nhead % nhead_k == 0);
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
}();
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
else
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
}();
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
// Use device buffer pointers instead of host tensor data pointers
args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer();
args.v_ptr = v_buf.GetDeviceBuffer();
args.block_relation_onehot_ptr = block_relation_buf.GetDeviceBuffer();
args.batch = batch;
args.seqlen_q = shape_seqlen_q; // batch mode only
args.hdim_q = hdim_q;
args.hdim_v = hdim_v;
args.nhead_q = nhead;
args.nhead_k = nhead_k;
args.stride_q = stride_q;
args.stride_k = stride_k;
args.stride_v = stride_v;
args.nhead_stride_q = nhead_stride_q;
args.nhead_stride_k = nhead_stride_k;
args.nhead_stride_v = nhead_stride_v;
args.batch_stride_q = batch_stride_q;
args.batch_stride_k = batch_stride_k;
args.batch_stride_v = batch_stride_v;
args.o_ptr = o_buf.GetDeviceBuffer();
args.seqlen_k = shape_seqlen_k; // batch mode only
args.max_seqlen_q = max_seqlen_q;
args.scale_s = scale_s;
args.stride_o = stride_o;
args.nhead_stride_o = nhead_stride_o;
args.batch_stride_o = batch_stride_o;
args.window_size_left = mask.left;
args.window_size_right = mask.right;
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
// Dropout not supported for sparse attention.
};
const auto init_traits = [&](auto& traits) {
traits.hdim_q = hdim_q;
traits.hdim_v = hdim_v;
traits.data_type = data_type;
traits.is_v_rowmajor = is_v_rowmajor;
traits.mask_type = mask.type;
};
fmha_jenga_fwd_traits fmha_traits;
init_traits(fmha_traits);
fmha_jenga_fwd_args args;
init_args(args);
fmha_jenga_fwd(fmha_traits, args, stream_config);
// Copy output back to host without changing tensor shape
o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());
return Y;
}
// Explicit template instantiations
template ck_tile::HostTensor<ck_tile::half_t>
jenga_sparse_attention<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
const ck_tile::HostTensor<ck_tile::half_t>&,
const ck_tile::HostTensor<ck_tile::half_t>&,
const ck_tile::HostTensor<uint8_t>&,
ck_tile::HostTensor<ck_tile::half_t>&,
int,
int,
int,
int,
int,
int,
int,
bool,
bool,
int,
int,
int);
template ck_tile::HostTensor<ck_tile::bf16_t>
jenga_sparse_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
const ck_tile::HostTensor<ck_tile::bf16_t>&,
const ck_tile::HostTensor<ck_tile::bf16_t>&,
const ck_tile::HostTensor<uint8_t>&,
ck_tile::HostTensor<ck_tile::bf16_t>&,
int,
int,
int,
int,
int,
int,
int,
bool,
bool,
int,
int,
int);

View File

@@ -0,0 +1,48 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <optional>
#include <cstdint>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
template <typename DataType_>
ck_tile::HostTensor<DataType_>
jenga_sparse_attention(const ck_tile::HostTensor<DataType_>& TQ,
const ck_tile::HostTensor<DataType_>& TK,
const ck_tile::HostTensor<DataType_>& TV,
const ck_tile::HostTensor<uint8_t>& Tblock_relation_onehot,
ck_tile::HostTensor<DataType_>& Y,
int batch,
int nhead,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
int hdim_v,
bool i_perm,
bool o_perm,
int max_seqlen_q,
int max_seqlen_k,
int log_level = 0);
template <typename DataType_>
ck_tile::HostTensor<DataType_> vsa_sparse_attention(
const ck_tile::HostTensor<DataType_>& TQ,
const ck_tile::HostTensor<DataType_>& TK,
const ck_tile::HostTensor<DataType_>& TV,
const ck_tile::HostTensor<int32_t>& TKV_block_idx, // LUT must be int32_t
const ck_tile::HostTensor<int32_t>& TKV_blocks, // valid_block_num must be int32_t
ck_tile::HostTensor<DataType_>& Y,
int batch,
int nhead,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
int hdim_v,
bool i_perm,
bool o_perm,
int max_seqlen_q,
int max_seqlen_k,
int log_level = 0);

View File

@@ -0,0 +1,423 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Test for jenga_sparse_attention function
#include <iostream>
#include <vector>
#include <cmath>
#include <random>
#include <string>
#include <algorithm>
#include <numeric>
#include <chrono>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "jenga_sparse_attention.h"
// ============================================================================
// Helper Functions
// ============================================================================
template <typename T>
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
ck_tile::index_t nhead,
ck_tile::index_t seqlen,
ck_tile::index_t hdim,
bool i_perm)
{
if(i_perm)
{
return ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim});
}
return ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
}
template <typename T>
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
{
auto lens = tensor.get_lengths();
ck_tile::index_t batch = lens[0];
ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1];
ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2];
ck_tile::index_t hdim = lens[3];
ck_tile::HostTensor<T> out({batch, nhead, seqlen, hdim});
for(ck_tile::index_t b = 0; b < batch; ++b)
{
for(ck_tile::index_t h = 0; h < nhead; ++h)
{
for(ck_tile::index_t s = 0; s < seqlen; ++s)
{
for(ck_tile::index_t d = 0; d < hdim; ++d)
{
out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d);
}
}
}
}
return out;
}
// Get error tolerance based on data type
template <typename T>
auto get_error_tolerance()
{
double rtol = 1e-2;
double atol = 4e-2;
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
// bf16 accumulation/rounding can be noisier in sparse patterns
atol = 2e-1;
rtol = 2e-1;
}
return ck_tile::make_tuple(rtol, atol);
}
template <typename T>
float to_float_for_compare(T value)
{
return static_cast<float>(value);
}
template <>
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
{
#if CK_TILE_USE_CUSTOM_DATA_TYPE
return static_cast<float>(value);
#else
return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
#endif
}
// ============================================================================
// Command line argument parser
// ============================================================================
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
.insert("b", "1", "batch size")
.insert("h", "4", "num of head for q")
.insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
.insert("s", "4096", "seqlen_q")
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("block_size", "128", "block size for sparse attention (BLKQ=BLKK)")
.insert("sparsity", "0.5", "sparsity ratio (0.0 = dense, 1.0 = fully sparse)")
.insert("prec", "fp16", "data type: fp16/bf16")
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
.insert("operm", "1", "permute output")
.insert("seed", "42", "random seed")
.insert("warmup", "5", "warmup iterations")
.insert("repeat", "20", "benchmark iterations")
.insert("kname", "0", "print kernel name");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// ============================================================================
// Main Test Function
// ============================================================================
template <typename T>
bool run_test(const ck_tile::ArgParser& arg_parser)
{
// Parse arguments
int do_validation = arg_parser.get_int("v");
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
ck_tile::index_t block_size = arg_parser.get_int("block_size");
float sparsity = arg_parser.get_float("sparsity");
bool i_perm = arg_parser.get_bool("iperm");
bool o_perm = arg_parser.get_bool("operm");
uint32_t seed = arg_parser.get_uint32("seed");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int kname = arg_parser.get_int("kname");
// Handle default values
if(nhead_k < 0)
nhead_k = nhead;
if(seqlen_k < 0)
seqlen_k = seqlen_q;
if(hdim_v < 0)
hdim_v = hdim_q;
ck_tile::index_t BLKQ = block_size;
ck_tile::index_t BLKK = block_size;
if(block_size != 128 || hdim_q != 128 || hdim_v != 128)
{
std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
std::cout << "Jenga kernel instances are generated for block_size=128 and hdim=128 only."
<< std::endl;
std::cout << "TEST SKIPPED" << std::endl;
return true;
}
// Calculate number of Q and K blocks
ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
std::cout << "============================================================" << std::endl;
std::cout << "[Jenga Sparse Attention Test]" << std::endl;
std::cout << "============================================================" << std::endl;
std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
<< std::endl;
std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
std::cout << " block_size: " << block_size << " (BLKQ=" << BLKQ << ", BLKK=" << BLKK << ")"
<< std::endl;
std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
<< std::endl;
std::cout << " sparsity: " << sparsity << std::endl;
std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
// Create host tensors (using BHSD layout when i_perm=true)
ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
ck_tile::HostTensor<T> output_host =
o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
// Block relation onehot: [B, H, Q_blocks, K_blocks]
ck_tile::HostTensor<uint8_t> block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks});
// Initialize tensors with random values
std::cout << "\nInitializing tensors..." << std::endl;
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
// Initialize block_relation_onehot with sparse pattern
std::mt19937 rng(seed + 100);
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
ck_tile::index_t total_blocks = 0;
ck_tile::index_t active_blocks = 0;
for(ck_tile::index_t b = 0; b < batch; ++b)
{
for(ck_tile::index_t h = 0; h < nhead; ++h)
{
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
{
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
{
total_blocks++;
bool is_diagonal = (qb == kb && qb < num_k_blocks);
bool random_active = (dist(rng) > sparsity);
if(is_diagonal || random_active)
{
block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(1);
active_blocks++;
}
else
{
block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(0);
}
}
}
}
}
float actual_sparsity =
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
<< total_blocks << " blocks active)" << std::endl;
// Run kernel
std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl;
try
{
if(kname)
{
jenga_sparse_attention<T>(q_host,
k_host,
v_host,
block_relation_onehot,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
1);
}
// Warmup
for(int i = 0; i < warmup; ++i)
{
jenga_sparse_attention<T>(q_host,
k_host,
v_host,
block_relation_onehot,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
0);
}
// Benchmark
[[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
auto start = std::chrono::high_resolution_clock::now();
for(int i = 0; i < repeat; ++i)
{
jenga_sparse_attention<T>(q_host,
k_host,
v_host,
block_relation_onehot,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
0);
}
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
auto end = std::chrono::high_resolution_clock::now();
double avg_time_ms =
std::chrono::duration<double, std::milli>(end - start).count() / repeat;
std::cout << "\n>>>> Jenga sparse attention average time: " << avg_time_ms << " ms <<<<"
<< std::endl;
}
catch(const std::exception& e)
{
std::cerr << "Error during kernel execution: " << e.what() << std::endl;
return false;
}
// Validation
bool pass = true;
if(do_validation)
{
std::cout << "\n--- Performing CPU validation ---" << std::endl;
float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
std::cout << "Computing reference output..." << std::endl;
auto q_ref = to_bhsd(q_host, i_perm);
auto k_ref = to_bhsd(k_host, i_perm);
auto v_ref = to_bhsd(v_host, i_perm);
ck_tile::reference_blocked_attention<T, uint8_t>(
q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
// Compare results
auto [rtol, atol] = get_error_tolerance<T>();
float max_diff = 0.0f;
float max_rel_diff = 0.0f;
size_t num_errors = 0;
auto output_host_bhsd = to_bhsd(output_host, o_perm);
for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
{
float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
float ref_val = to_float_for_compare(output_ref.mData[i]);
float diff = std::abs(gpu_val - ref_val);
float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
max_diff = std::max(max_diff, diff);
max_rel_diff = std::max(max_rel_diff, rel_diff);
if(diff > atol && rel_diff > rtol)
{
num_errors++;
if(num_errors <= 5)
{
std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
<< ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
}
}
}
std::cout << "\nValidation results:" << std::endl;
std::cout << " Max absolute difference: " << max_diff << std::endl;
std::cout << " Max relative difference: " << max_rel_diff << std::endl;
std::cout << " Number of mismatches: " << num_errors << " / "
<< output_host_bhsd.mData.size() << std::endl;
if(num_errors == 0)
{
std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
}
else
{
std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
pass = false;
}
}
std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
return pass;
}
// ============================================================================
// Main
// ============================================================================
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
std::cerr << "Failed to parse arguments" << std::endl;
return -1;
}
std::string prec = arg_parser.get_str("prec");
bool test_result = false;
if(prec == "fp16")
{
test_result = run_test<ck_tile::half_t>(arg_parser);
}
else if(prec == "bf16")
{
test_result = run_test<ck_tile::bf16_t>(arg_parser);
}
else
{
std::cerr << "Unsupported precision: " << prec << std::endl;
return -1;
}
return test_result ? 0 : -1;
}

View File

@@ -0,0 +1,486 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Test for vsa_sparse_attention function
// Based on the Python test: test_jenga_attention.py
#include <iostream>
#include <vector>
#include <cmath>
#include <random>
#include <string>
#include <algorithm>
#include <numeric>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "jenga_sparse_attention.h"
#include "fmha_fwd_trek.hpp"
// ============================================================================
// Helper Functions
// ============================================================================
template <typename T>
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
ck_tile::index_t nhead,
ck_tile::index_t seqlen,
ck_tile::index_t hdim,
bool i_perm)
{
if(i_perm)
{
return ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim});
}
return ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
}
template <typename T>
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
{
auto lens = tensor.get_lengths();
ck_tile::index_t batch = lens[0];
ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1];
ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2];
ck_tile::index_t hdim = lens[3];
ck_tile::HostTensor<T> out({batch, nhead, seqlen, hdim});
for(ck_tile::index_t b = 0; b < batch; ++b)
{
for(ck_tile::index_t h = 0; h < nhead; ++h)
{
for(ck_tile::index_t s = 0; s < seqlen; ++s)
{
for(ck_tile::index_t d = 0; d < hdim; ++d)
{
out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d);
}
}
}
}
return out;
}
// Convert block_relation_onehot to LUT format (similar to triton_block_map_to_lut_kernel)
template <typename T>
void block_map_to_lut(
const ck_tile::HostTensor<T>& block_map, // [B, H, Q_blocks, K_blocks]
ck_tile::HostTensor<int32_t>& lut, // [B, H, Q_blocks, K_blocks] - int32_t for kernel
ck_tile::HostTensor<int32_t>& valid_block_num, // [B, H, Q_blocks] - int32_t for kernel
ck_tile::index_t num_block_k)
{
auto lengths = block_map.get_lengths();
ck_tile::index_t B = lengths[0];
ck_tile::index_t H = lengths[1];
ck_tile::index_t Q = lengths[2];
for(ck_tile::index_t b = 0; b < B; ++b)
{
for(ck_tile::index_t h = 0; h < H; ++h)
{
for(ck_tile::index_t q = 0; q < Q; ++q)
{
int32_t valid_count = 0;
int32_t prev_block = 0;
for(ck_tile::index_t k = 0; k < num_block_k; ++k)
{
T cur_block = block_map(b, h, q, k);
if(static_cast<float>(cur_block) > 0.5f)
{ // Check if block is active
lut(b, h, q, valid_count) = static_cast<int32_t>(k - prev_block);
valid_count++;
prev_block = static_cast<int32_t>(k);
}
}
valid_block_num(b, h, q) = valid_count;
}
}
}
}
// Get error tolerance based on data type
template <typename T>
auto get_error_tolerance()
{
double rtol = 1e-2;
double atol = 4e-2;
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
// bf16 accumulation/rounding can be noisier in sparse patterns
atol = 2e-1;
rtol = 2e-1;
}
return ck_tile::make_tuple(rtol, atol);
}
template <typename T>
float to_float_for_compare(T value)
{
return static_cast<float>(value);
}
template <>
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
{
#if CK_TILE_USE_CUSTOM_DATA_TYPE
return static_cast<float>(value);
#else
return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
#endif
}
// ============================================================================
// Command line argument parser
// ============================================================================
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
.insert("b", "1", "batch size")
.insert("h", "4", "num of head for q")
.insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
.insert("s", "4096", "seqlen_q")
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("block_size", "128", "block size for sparse attention (BLKQ=BLKK)")
.insert("sparsity", "0.5", "sparsity ratio (0.0 = dense, 1.0 = fully sparse)")
.insert("prec", "fp16", "data type: fp16/bf16")
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
.insert("operm", "1", "permute output")
.insert("seed", "42", "random seed")
.insert("warmup", "5", "warmup iterations")
.insert("repeat", "20", "benchmark iterations")
.insert("kname", "0", "print kernel name");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// ============================================================================
// Main Test Function
// ============================================================================
template <typename T>
bool run_test(const ck_tile::ArgParser& arg_parser)
{
// Parse arguments
int do_validation = arg_parser.get_int("v");
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
ck_tile::index_t block_size = arg_parser.get_int("block_size");
float sparsity = arg_parser.get_float("sparsity");
bool i_perm = arg_parser.get_bool("iperm");
bool o_perm = arg_parser.get_bool("operm");
uint32_t seed = arg_parser.get_uint32("seed");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int kname = arg_parser.get_int("kname");
// Handle default values
if(nhead_k < 0)
nhead_k = nhead;
if(seqlen_k < 0)
seqlen_k = seqlen_q;
if(hdim_v < 0)
hdim_v = hdim_q;
ck_tile::index_t BLKQ = block_size;
ck_tile::index_t BLKK = block_size;
if(block_size != 128 || hdim_q != 128 || hdim_v != 128)
{
std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
std::cout << "VSA kernel instances are generated for block_size=128 and hdim=128 only."
<< std::endl;
std::cout << "TEST SKIPPED" << std::endl;
return true;
}
// Calculate number of Q and K blocks
ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
std::cout << "============================================================" << std::endl;
std::cout << "[VSA Sparse Attention Test]" << std::endl;
std::cout << "============================================================" << std::endl;
std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
<< std::endl;
std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
std::cout << " block_size: " << block_size << " (BLKQ=" << BLKQ << ", BLKK=" << BLKK << ")"
<< std::endl;
std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
<< std::endl;
std::cout << " sparsity: " << sparsity << std::endl;
std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
// Create host tensors (using BHSD layout when i_perm=true)
// Q: [B, H, S_q, D]
// K: [B, H_k, S_k, D]
// V: [B, H_k, S_k, D_v]
ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
ck_tile::HostTensor<T> output_host =
o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
// Block relation onehot: [B, H, Q_blocks, K_blocks]
ck_tile::HostTensor<uint8_t> block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks});
// LUT and valid_block_num (output of block_map_to_lut) - must be int32_t for kernel
ck_tile::HostTensor<int32_t> lut_host({batch, nhead, num_q_blocks, num_k_blocks});
ck_tile::HostTensor<int32_t> valid_block_num_host({batch, nhead, num_q_blocks});
// Initialize tensors with random values
std::cout << "\nInitializing tensors..." << std::endl;
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
// Initialize block_relation_onehot with sparse pattern
std::mt19937 rng(seed + 100);
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
ck_tile::index_t total_blocks = 0;
ck_tile::index_t active_blocks = 0;
for(ck_tile::index_t b = 0; b < batch; ++b)
{
for(ck_tile::index_t h = 0; h < nhead; ++h)
{
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
{
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
{
total_blocks++;
// Each Q block always attends to its diagonal K block (if exists)
// Plus random blocks based on sparsity
bool is_diagonal = (qb == kb && qb < num_k_blocks);
bool random_active = (dist(rng) > sparsity);
if(is_diagonal || random_active)
{
block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(1);
active_blocks++;
}
else
{
block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(0);
}
}
}
}
}
float actual_sparsity =
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
<< total_blocks << " blocks active)" << std::endl;
// Convert block_relation_onehot to LUT format
std::cout << "Converting block map to LUT format..." << std::endl;
block_map_to_lut(block_relation_onehot, lut_host, valid_block_num_host, num_k_blocks);
// vsa_sparse_attention handles device memory internally
// Run kernel
std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;
try
{
// Print kernel name once by invoking with log_level=1.
// This is separate from warmup/benchmark to avoid polluting timing.
if(kname)
{
vsa_sparse_attention<T>(q_host,
k_host,
v_host,
lut_host,
valid_block_num_host,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
1);
}
// Warmup
for(int i = 0; i < warmup; ++i)
{
vsa_sparse_attention<T>(q_host,
k_host,
v_host,
lut_host,
valid_block_num_host,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
0);
}
// Benchmark
[[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
auto start = std::chrono::high_resolution_clock::now();
for(int i = 0; i < repeat; ++i)
{
vsa_sparse_attention<T>(q_host,
k_host,
v_host,
lut_host,
valid_block_num_host,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
0);
}
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
auto end = std::chrono::high_resolution_clock::now();
double avg_time_ms =
std::chrono::duration<double, std::milli>(end - start).count() / repeat;
std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<"
<< std::endl;
}
catch(const std::exception& e)
{
std::cerr << "Error during kernel execution: " << e.what() << std::endl;
return false;
}
// Note: vsa_sparse_attention already returns output in output_host
// Validation
bool pass = true;
if(do_validation)
{
std::cout << "\n--- Performing CPU validation ---" << std::endl;
// Compute scale factor
float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
// Run reference implementation
std::cout << "Computing reference output..." << std::endl;
auto q_ref = to_bhsd(q_host, i_perm);
auto k_ref = to_bhsd(k_host, i_perm);
auto v_ref = to_bhsd(v_host, i_perm);
ck_tile::reference_blocked_attention<T, uint8_t>(
q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
// Compare results
auto [rtol, atol] = get_error_tolerance<T>();
float max_diff = 0.0f;
float max_rel_diff = 0.0f;
size_t num_errors = 0;
auto output_host_bhsd = to_bhsd(output_host, o_perm);
for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
{
float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
float ref_val = to_float_for_compare(output_ref.mData[i]);
float diff = std::abs(gpu_val - ref_val);
float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
max_diff = std::max(max_diff, diff);
max_rel_diff = std::max(max_rel_diff, rel_diff);
if(diff > atol && rel_diff > rtol)
{
num_errors++;
if(num_errors <= 5)
{
std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
<< ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
}
}
}
std::cout << "\nValidation results:" << std::endl;
std::cout << " Max absolute difference: " << max_diff << std::endl;
std::cout << " Max relative difference: " << max_rel_diff << std::endl;
std::cout << " Number of mismatches: " << num_errors << " / "
<< output_host_bhsd.mData.size() << std::endl;
if(num_errors == 0)
{
std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
}
else
{
std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
pass = false;
}
}
std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
return pass;
}
// ============================================================================
// Main
// ============================================================================
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
std::cerr << "Failed to parse arguments" << std::endl;
return -1;
}
std::string prec = arg_parser.get_str("prec");
bool test_result = false;
if(prec == "fp16")
{
test_result = run_test<ck_tile::half_t>(arg_parser);
}
else if(prec == "bf16")
{
test_result = run_test<ck_tile::bf16_t>(arg_parser);
}
else
{
std::cerr << "Unsupported precision: " << prec << std::endl;
return -1;
}
return test_result ? 0 : -1;
}

View File

@@ -0,0 +1,205 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "jenga_sparse_attention.h"
#include "fmha_fwd_trek.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/device_memory.hpp"
#include <type_traits>
template <typename DataType_>
ck_tile::HostTensor<DataType_>
vsa_sparse_attention(const ck_tile::HostTensor<DataType_>& TQ,
const ck_tile::HostTensor<DataType_>& TK,
const ck_tile::HostTensor<DataType_>& TV,
const ck_tile::HostTensor<int32_t>& TKV_block_idx,
const ck_tile::HostTensor<int32_t>& TKV_blocks,
ck_tile::HostTensor<DataType_>& Y,
int batch,
int nhead,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
int hdim_v,
bool i_perm,
bool o_perm,
int max_seqlen_q,
int max_seqlen_k,
int log_level)
{
static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
std::is_same_v<DataType_, ck_tile::bf16_t>,
"VSA sparse attention supports fp16/bf16 only.");
// Determine data type string based on template parameter
std::string data_type = "fp16";
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
{
data_type = "bf16";
}
if(max_seqlen_q == 0)
max_seqlen_q = seqlen_q;
if(max_seqlen_k == 0)
max_seqlen_k = seqlen_k;
bool is_v_rowmajor = true;
float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
std::string msk_str = "0";
mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);
const ck_tile::index_t shape_seqlen_q = seqlen_q;
const ck_tile::index_t shape_seqlen_k = seqlen_k;
ck_tile::stream_config stream_config{nullptr,
false, // time_kernel
log_level,
0,
1,
false};
// Create device memory and copy data to device
ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
ck_tile::DeviceMem lut_buf(TKV_block_idx.get_element_space_size_in_bytes());
ck_tile::DeviceMem valid_block_num_buf(TKV_blocks.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());
q_buf.ToDevice(TQ.data());
k_buf.ToDevice(TK.data());
v_buf.ToDevice(TV.data());
lut_buf.ToDevice(TKV_block_idx.data());
valid_block_num_buf.ToDevice(TKV_blocks.data());
const auto init_args = [&](auto& args) {
assert(nhead % nhead_k == 0);
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
}();
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
else
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
}();
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
// Use device buffer pointers instead of host tensor data pointers
args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer();
args.v_ptr = v_buf.GetDeviceBuffer();
args.lut_ptr = lut_buf.GetDeviceBuffer();
args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer();
args.batch = batch;
args.seqlen_q = shape_seqlen_q; // batch mode only
args.hdim_q = hdim_q;
args.hdim_v = hdim_v;
args.nhead_q = nhead;
args.nhead_k = nhead_k;
args.stride_q = stride_q;
args.stride_k = stride_k;
args.stride_v = stride_v;
args.nhead_stride_q = nhead_stride_q;
args.nhead_stride_k = nhead_stride_k;
args.nhead_stride_v = nhead_stride_v;
args.batch_stride_q = batch_stride_q;
args.batch_stride_k = batch_stride_k;
args.batch_stride_v = batch_stride_v;
args.o_ptr = o_buf.GetDeviceBuffer();
args.seqlen_k = shape_seqlen_k; // batch mode only
args.max_seqlen_q = max_seqlen_q;
args.scale_s = scale_s;
args.stride_o = stride_o;
args.nhead_stride_o = nhead_stride_o;
args.batch_stride_o = batch_stride_o;
args.window_size_left = mask.left;
args.window_size_right = mask.right;
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
// Dropout not supported for sparse attention.
};
const auto init_traits = [&](auto& traits) {
traits.hdim_q = hdim_q;
traits.hdim_v = hdim_v;
traits.data_type = data_type;
traits.is_v_rowmajor = is_v_rowmajor;
traits.mask_type = mask.type;
};
fmha_vsa_fwd_traits fmha_traits;
init_traits(fmha_traits);
fmha_vsa_fwd_args args;
init_args(args);
fmha_vsa_fwd(fmha_traits, args, stream_config);
// Copy output back to host without changing tensor shape
o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());
return Y;
}
// Explicit template instantiations
template ck_tile::HostTensor<ck_tile::half_t>
vsa_sparse_attention<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
const ck_tile::HostTensor<ck_tile::half_t>&,
const ck_tile::HostTensor<ck_tile::half_t>&,
const ck_tile::HostTensor<int32_t>&,
const ck_tile::HostTensor<int32_t>&,
ck_tile::HostTensor<ck_tile::half_t>&,
int,
int,
int,
int,
int,
int,
int,
bool,
bool,
int,
int,
int);
template ck_tile::HostTensor<ck_tile::bf16_t>
vsa_sparse_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
const ck_tile::HostTensor<ck_tile::bf16_t>&,
const ck_tile::HostTensor<ck_tile::bf16_t>&,
const ck_tile::HostTensor<int32_t>&,
const ck_tile::HostTensor<int32_t>&,
ck_tile::HostTensor<ck_tile::bf16_t>&,
int,
int,
int,
int,
int,
int,
int,
bool,
bool,
int,
int,
int);

View File

@@ -30,4 +30,5 @@ add_subdirectory(36_pooling)
add_subdirectory(38_block_scale_gemm)
add_subdirectory(40_streamk_gemm)
add_subdirectory(41_batched_contraction)
add_subdirectory(50_sparse_attn)

View File

@@ -44,6 +44,11 @@ __device__ inline int32_t amd_wave_read_first_lane(int32_t value)
return __builtin_amdgcn_readfirstlane(value);
}
__device__ inline uint32_t amd_wave_read_first_lane(uintptr_t value)
{
return __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(value));
}
template <typename Object, std::enable_if_t<std::is_trivially_copyable_v<Object>, int> = 0>
__device__ inline auto amd_wave_read_first_lane(const Object& obj)
{

View File

@@ -27,6 +27,7 @@
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"

View File

@@ -0,0 +1,156 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
#include "ck_tile/core.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <typename AccT, typename T>
CK_TILE_HOST_DEVICE constexpr AccT to_acc(T value)
{
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
#if CK_TILE_USE_CUSTOM_DATA_TYPE
return static_cast<AccT>(value);
#else
return static_cast<AccT>(
ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value)));
#endif
}
else
{
return static_cast<AccT>(value);
}
}
// Reference implementation: blocked attention (for sparse attention tests).
template <typename T, typename MaskT, typename AccT = float>
void reference_blocked_attention(
const HostTensor<T>& q, // [B, H, S_q, D]
const HostTensor<T>& k, // [B, H, S_k, D]
const HostTensor<T>& v, // [B, H, S_k, D_v]
const HostTensor<MaskT>& block_relation, // [B, H, Q_blocks, K_blocks]
HostTensor<T>& output, // [B, H, S_q, D_v]
index_t BLKQ,
index_t BLKK,
AccT scale)
{
auto q_lengths = q.get_lengths();
index_t batch = q_lengths[0];
index_t nhead = q_lengths[1];
index_t seqlen_q = q_lengths[2];
index_t hdim = q_lengths[3];
auto v_lengths = v.get_lengths();
index_t seqlen_k = v_lengths[2];
index_t hdim_v = v_lengths[3];
index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
for(index_t b = 0; b < batch; ++b)
{
for(index_t h = 0; h < nhead; ++h)
{
for(index_t qb = 0; qb < num_q_blocks; ++qb)
{
index_t q_start = qb * BLKQ;
if(q_start >= seqlen_q)
{
continue;
}
index_t q_end = std::min<index_t>(q_start + BLKQ, seqlen_q);
std::vector<index_t> relevant_k_indices;
for(index_t kb = 0; kb < num_k_blocks; ++kb)
{
// Treat block_relation as boolean; >0.5 marks an active block.
if(static_cast<float>(block_relation(b, h, qb, kb)) > 0.5f)
{
relevant_k_indices.push_back(kb);
}
}
if(relevant_k_indices.empty())
{
continue;
}
for(index_t sq = q_start; sq < q_end; ++sq)
{
std::vector<AccT> scores;
AccT max_score = -std::numeric_limits<AccT>::infinity();
for(auto kb : relevant_k_indices)
{
index_t k_start = kb * BLKK;
if(k_start >= seqlen_k)
{
continue;
}
index_t k_end = std::min<index_t>(k_start + BLKK, seqlen_k);
for(index_t sk = k_start; sk < k_end; ++sk)
{
AccT score = 0.0f;
for(index_t d = 0; d < hdim; ++d)
{
score +=
to_acc<AccT>(q(b, h, sq, d)) * to_acc<AccT>(k(b, h, sk, d));
}
score = score * scale;
scores.push_back(score);
max_score = std::max(max_score, score);
}
}
AccT sum_exp = 0.0f;
for(auto& s : scores)
{
s = std::exp(s - max_score);
sum_exp += s;
}
for(auto& s : scores)
{
s /= sum_exp;
}
for(index_t dv = 0; dv < hdim_v; ++dv)
{
AccT out_val = 0.0f;
size_t score_idx = 0;
for(auto kb : relevant_k_indices)
{
index_t k_start = kb * BLKK;
if(k_start >= seqlen_k)
{
continue;
}
index_t k_end = std::min<index_t>(k_start + BLKK, seqlen_k);
for(index_t sk = k_start; sk < k_end; ++sk)
{
out_val += scores[score_idx] * to_acc<AccT>(v(b, h, sk, dv));
score_idx++;
}
}
output(b, h, sq, dv) = static_cast<T>(out_val);
}
}
}
}
}
}
} // namespace ck_tile

View File

@@ -0,0 +1,13 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp"
#include "ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp"
#include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp"
#include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -0,0 +1,446 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdJengaKernel
{
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
using RandValOutputDataType =
ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant =
(FmhaPipeline::Problem::QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE);
static_assert(!FmhaPipeline::kIsGroupMode,
"Jenga sparse attention currently supports batch mode only.");
static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS,
"Jenga sparse attention does not support bias.");
static_assert(!kStoreLSE, "Jenga sparse attention does not support LSE output.");
static_assert(!kHasDropout, "Jenga sparse attention does not support dropout.");
static_assert(!kHasLogitsSoftCap, "Jenga sparse attention does not support logits soft-cap.");
static_assert(!kDoFp8StaticQuant,
"Jenga sparse attention does not support FP8 static quantization yet.");
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
// arg
struct FmhaFwdEmptyKargs
{
};
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct FmhaFwdCommonKargs
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* block_relation_onehot_ptr;
void* o_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_head_q;
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
// if this param is larger than 1, indicate MQA/GQA case
ck_tile::index_t nhead_ratio_qk;
float scale_s;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
};
struct FmhaFwdMaskKargs
{
ck_tile::index_t window_size_left, window_size_right;
ck_tile::GenericAttentionMaskEnum mask_type;
};
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
};
using Kargs = FmhaFwdBatchModeKargs;
struct BlockIndices
{
ck_tile::index_t batch_idx;
ck_tile::index_t qo_head_idx;
ck_tile::index_t kv_head_idx;
};
// std::variant<> can't take in a list initializer, overload for backward compatibility
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* block_relation_onehot_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
block_relation_onehot_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
scale_s,
#endif
stride_q,
stride_k,
stride_v,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_o}, // FmhaFwdCommonKargs
{}, // FmhaFwdMaskKargs or FmhaFwdEmptyKargs<1>
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_o};
if constexpr(kHasMask)
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
return kargs;
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
return dim3(nhead_,
batch_size_,
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
}
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
{
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
const index_t i_block = blockIdx.z;
const index_t i_nhead = blockIdx.x;
const index_t i_batch = blockIdx.y;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
// Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads.
__shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)];
// if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d",
// int(GetSmemSize()));
// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_o = 0;
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
// for simplicity, batch stride we just modify the pointer
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
batch_offset_q;
const KDataType* k_ptr =
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
batch_offset_k;
const VDataType* v_ptr =
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v;
// sparse mask
const bool* block_relation_onehot_ptr =
reinterpret_cast<const bool*>(kargs.block_relation_onehot_ptr) +
static_cast<long_index_t>(i_batch * kargs.num_head_q + i_nhead) *
ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0) +
i_tile_m * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});
if constexpr(FmhaPipeline::kQLoadOnce)
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
}();
const auto k_dram = [&]() {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_transposed =
transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
}
else
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.hdim_v, kargs.seqlen_k),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV_, kPadSeqLenK>{});
}
}();
auto q_dram_window = make_tile_window(
q_dram,
[&]() {
if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kSubQKHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
}(),
{i_m0, 0});
auto k_dram_window = make_tile_window(
k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
{i_n1, 0});
FmhaMask mask = [&]() {
if constexpr(kHasMask)
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
else
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();
AttentionVariant variant;
const auto variant_params = ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
auto o_acc_tile = FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
block_relation_onehot_ptr,
mask,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr);
// O DRAM and O DRAM window
auto o_dram = [&]() {
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o, 1),
number<FmhaPipeline::kAlignmentO>{},
number<1>{});
return pad_tensor_view(
o_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
}();
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,438 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdVSAKernel
{
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
using RandValOutputDataType =
ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
static constexpr bool kDoFp8StaticQuant =
(QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE);
static_assert(!FmhaPipeline::kIsGroupMode, "VSA sparse attention supports batch mode only.");
static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS,
"VSA sparse attention does not support bias.");
static_assert(!kStoreLSE, "VSA sparse attention does not support LSE output.");
static_assert(!kHasDropout, "VSA sparse attention does not support dropout.");
static_assert(!kHasLogitsSoftCap, "VSA sparse attention does not support logits soft-cap.");
static_assert(!kDoFp8StaticQuant,
"VSA sparse attention does not support FP8 static quantization yet.");
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
// arg
struct FmhaFwdEmptyKargs
{
};
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct FmhaFwdCommonKargs
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* lut_ptr;
const void* valid_block_num_ptr;
void* o_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_head_q;
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
// if this param is larger than 1, indicate MQA/GQA case
ck_tile::index_t nhead_ratio_qk;
float scale_s;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
};
struct FmhaFwdMaskKargs
{
ck_tile::index_t window_size_left, window_size_right;
ck_tile::GenericAttentionMaskEnum mask_type;
};
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
};
using Kargs = FmhaFwdBatchModeKargs;
struct BlockIndices
{
ck_tile::index_t batch_idx;
ck_tile::index_t qo_head_idx;
ck_tile::index_t kv_head_idx;
};
// std::variant<> can't take in a list initializer, overload for backward compatibility
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* lut_ptr,
const void* valid_block_num_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
lut_ptr,
valid_block_num_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
scale_s,
#endif
stride_q,
stride_k,
stride_v,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_o}, // FmhaFwdCommonKargs
{}, // FmhaFwdMaskKargs or FmhaFwdEmptyKargs<1>
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_o};
if constexpr(kHasMask)
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
return kargs;
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
nhead_,
batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
{
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
if constexpr(kHasMask)
{
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
// Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads.
__shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)];
// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_o = 0;
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
// for simplicity, batch stride we just modify the pointer
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
batch_offset_q;
const KDataType* k_ptr =
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
batch_offset_k;
const VDataType* v_ptr =
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v;
// sparse mask
const int* lut_ptr =
reinterpret_cast<const int*>(kargs.lut_ptr) +
static_cast<long_index_t>(i_batch * kargs.num_head_q + i_nhead) *
ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0) +
i_tile_m * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
const int* valid_block_num_ptr =
reinterpret_cast<const int*>(kargs.valid_block_num_ptr) +
static_cast<long_index_t>(i_batch * kargs.num_head_q + i_nhead) *
ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) +
i_tile_m;
const int valid_block_num_value = valid_block_num_ptr[0];
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});
if constexpr(FmhaPipeline::kQLoadOnce)
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
}();
const auto k_dram = [&]() {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_transposed =
transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
}
}();
auto q_dram_window = make_tile_window(
q_dram,
[&]() {
if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kSubQKHeaddim>{});
else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
}(),
{i_m0, 0});
auto k_dram_window = make_tile_window(
k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
{i_n1, 0});
FmhaMask mask = [&]() {
if constexpr(kHasMask)
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
else
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();
AttentionVariant variant;
const auto variant_params = ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
auto o_acc_tile = FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
lut_ptr,
valid_block_num_value,
mask,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr);
// O DRAM and O DRAM window
auto o_dram = [&]() {
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o, 1),
number<FmhaPipeline::kAlignmentO>{},
number<1>{});
return pad_tensor_view(
o_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
}();
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,595 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
struct BlockFmhaPipelineQRKSVSAsyncJenga
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
Problem::kPadHeadDimV == true);
static constexpr bool kPadSeqLenQ = true;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS,
"Jenga sparse attention does not support bias.");
static_assert(!kHasDropout, "Jenga sparse attention does not support dropout.");
static_assert(!kStoreLSE, "Jenga sparse attention does not support LSE output.");
static_assert(!kHasLogitsSoftCap, "Jenga sparse attention does not support logits soft-cap.");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
#if CK_TILE_FMHA_FWD_FAST_EXP2
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
#endif
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
// minimize occupancy
if constexpr(kQKHeaddim <= 32)
{
if constexpr(kPadSeqLenK && FmhaMask::IsMasking)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 64)
{
if constexpr(kPadSeqLenK)
return 2;
else
return 3;
}
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(kPadSeqLenK)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 192)
{
if constexpr(kPadSeqLenK)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
else
{
return 1;
};
}
}();
static constexpr const char* name = "qr_async";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const bool* block_relation_onehot_ptr,
FmhaMask mask,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
// K tile in LDS
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto k_lds_store = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumKVLdsBuffers>{});
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
auto k_lds_load =
make_tile_window(k_lds_Load_view,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
{0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
q_dram_window.init_raw();
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
auto q = decltype(load_tile(q_dram_window)){};
// TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
// however, q would be cleared in the constructor of static distributed tensor
// set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
{
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it
return o_acc;
}
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
}
const index_t num_block = num_total_loop;
bool* block_relation_onehot = reinterpret_cast<bool*>(smem_ptr) + GetSmemSize();
const index_t thread_offset = static_cast<index_t>(4 * threadIdx.x);
amd_direct_load_global_to_lds<bool, 4>(block_relation_onehot_ptr,
4 * threadIdx.x,
block_relation_onehot,
4 * threadIdx.x,
thread_offset < num_block,
num_block);
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
k_dram_window.init_raw();
constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = bool_constant<false>{};
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
buffer_load_fence(1);
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
if(block_relation_onehot[0])
{
// prefetch K tile
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
}
// buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
buffer_load_fence(k_dram_window.get_num_of_access());
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops);
// main loop
do
{
if(!block_relation_onehot[i_total_loops])
{
i_total_loops++;
if(i_total_loops < num_total_loop)
{
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
if(block_relation_onehot[i_total_loops])
{
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
}
move_tile_window(k_dram_window, {0, kK0});
move_tile_window(v_dram_window, {0, kN0});
continue;
}
break;
}
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_of_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
});
}
// TODO: this to fix a bug when loop smaller than 2,
// the following fence/barrier will be scheduled inside 1st loop
if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0);
async_load_fence();
__builtin_amdgcn_s_barrier();
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(
s_acc,
get_slice_tile(
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
}
__builtin_amdgcn_sched_barrier(1);
// STAGE 2, scale_s, mask, softmax (no bias/soft-cap)
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F);
// store & prefetch next v, after the max reduction
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp, v_shuffle_tmp);
if constexpr(k1_loops > 1)
{
move_tile_window(
v_dram_window,
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
__builtin_amdgcn_sched_barrier(0);
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
if constexpr(FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
#endif
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}();
#else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
});
const auto p = cast_tile<PDataType>(p_compute);
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
auto v_shuffle_tmp_next = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp_next, v_buf);
auto v_lds_window_tmp_next = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp_next, v_shuffle_tmp_next);
if constexpr(i_k1 < k1_loops - 1)
move_tile_window(v_dram_window, {0, kK1});
});
}
i_total_loops++;
if(i_total_loops < num_total_loop)
{
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
}
// tail
{
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
}
} while(i_total_loops < num_total_loop);
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
if constexpr(FmhaMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
return o_acc;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,579 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
struct BlockFmhaPipelineQRKSVSAsyncVSA
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
Problem::kPadHeadDimV == true);
static constexpr bool kPadSeqLenQ = true;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS,
"VSA sparse attention does not support bias.");
static_assert(!kHasDropout, "VSA sparse attention does not support dropout.");
static_assert(!kStoreLSE, "VSA sparse attention does not support LSE output.");
static_assert(!kHasLogitsSoftCap, "VSA sparse attention does not support logits soft-cap.");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
#if CK_TILE_FMHA_FWD_FAST_EXP2
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
#endif
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
// minimize occupancy
if constexpr(kQKHeaddim <= 32)
{
if constexpr(kPadSeqLenK && FmhaMask::IsMasking)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 64)
{
if constexpr(kPadSeqLenK)
return 2;
else
return 3;
}
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(kPadSeqLenK)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 192)
{
if constexpr(kPadSeqLenK)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
else
{
return 1;
};
}
}();
static constexpr const char* name = "qr_async";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const int* kv_block_idx_ptr,
int kv_blocks,
FmhaMask mask,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
// K tile in LDS
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto k_lds_store = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumKVLdsBuffers>{});
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
auto k_lds_load =
make_tile_window(k_lds_Load_view,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
{0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
int seqlen_k_start = kv_block_idx_ptr[0] * kM0;
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
q_dram_window.init_raw();
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
auto q = decltype(load_tile(q_dram_window)){};
// TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
// however, q would be cleared in the constructor of static distributed tensor
// set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto num_total_loop = kv_blocks;
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
{
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it
return o_acc;
}
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
k_dram_window.init_raw();
constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = bool_constant<false>{};
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
// buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
buffer_load_fence(k_dram_window.get_num_of_access());
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops);
// main loop
do
{
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_of_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
});
}
// TODO: this to fix a bug when loop smaller than 2,
// the following fence/barrier will be scheduled inside 1st loop
if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0);
async_load_fence();
__builtin_amdgcn_s_barrier();
int block_idx = kv_block_idx_ptr[i_total_loops + 1];
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(
s_acc,
get_slice_tile(
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
}
__builtin_amdgcn_sched_barrier(1);
// STAGE 2, scale_s, mask, softmax (no bias/soft-cap)
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F);
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp, v_shuffle_tmp);
}
else
{
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp, v_buf);
}
if constexpr(k1_loops > 1)
{
move_tile_window(
v_dram_window,
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
__builtin_amdgcn_sched_barrier(0);
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
if constexpr(FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
#endif
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}();
#else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
});
const auto p = [&]() {
if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pkrtz_fp16_fp32<PDataType>(p_compute);
else
return cast_tile<PDataType>(p_compute);
}();
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp, v_shuffle_tmp);
}
else
{
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp, v_buf);
}
if constexpr(i_k1 < k1_loops - 1)
move_tile_window(v_dram_window, {0, kK1});
});
}
i_total_loops++;
if(i_total_loops < num_total_loop)
{
move_tile_window(v_dram_window, {0, kN0 * (block_idx - 1)});
move_tile_window(k_dram_block_window, {kN0 * block_idx, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
}
// tail
{
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
}
} while(i_total_loops < num_total_loop);
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
if constexpr(FmhaMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
return o_acc;
}
};
} // namespace ck_tile