From 9317fc4a8508cb53ec9bd829781d4781d84ce428 Mon Sep 17 00:00:00 2001
From: Gino Lu
Date: Tue, 24 Mar 2026 05:57:54 -0400
Subject: [PATCH] Support 64x128 tile size in sparge fwd for Jenga and VSA paths

---
 example/ck_tile/50_sparse_attn/CMakeLists.txt |  116 ++-
 .../codegen/ops/sparge_fwd_jenga.py           |  799 ++++++++++++++++++
 .../codegen/ops/sparge_fwd_vsa.py             |  799 ++++++++++++++++++
 .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp  |    6 +
 .../50_sparse_attn/jenga_sparge_attention.cpp |  189 +++++
 .../50_sparse_attn/jenga_sparge_attention.h   |   27 +
 .../test_sparge_jenga_sparse_attn.cpp         |   14 +-
 .../test_sparge_vsa_sparse_attn.cpp           |   14 +-
 .../50_sparse_attn/vsa_sparge_attention.cpp   |  195 +++++
 .../50_sparse_attn/vsa_sparge_attention.h     |   28 +
 ...block_fmha_pipeline_qr_ks_vs_async_vsa.hpp |    2 +-
 11 files changed, 2167 insertions(+), 22 deletions(-)
 create mode 100644 example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py
 create mode 100644 example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py
 create mode 100644 example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp
 create mode 100644 example/ck_tile/50_sparse_attn/jenga_sparge_attention.h
 create mode 100644 example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp
 create mode 100644 example/ck_tile/50_sparse_attn/vsa_sparge_attention.h

diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt
index c916f642eb..0ac86f6aff 100644
--- a/example/ck_tile/50_sparse_attn/CMakeLists.txt
+++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt
@@ -1,8 +1,8 @@
-# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
-# SPDX-License-Identifier: MIT
-# CMakeLists.txt for sparse attention (Jenga and VSA)
+# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+# SPDX-License-Identifier: MIT
+# CMakeLists.txt for sparse attention (Jenga and VSA)
 
-# Use SUPPORTED_GPU_TARGETS directly
+# Use SUPPORTED_GPU_TARGETS directly
 set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
 set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS})
 
@@ -16,7 +16,7 @@ endif()
 
 message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}")
 
-# Code generation scripts
+# Code generation scripts
 file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
     ${CMAKE_CURRENT_LIST_DIR}/generate.py
    ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
@@ -88,11 +88,62 @@ target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE
     -Wno-float-equal
 )
 
+# ============================================================================
+# Sparge Jenga (64x128 tile)
+# ============================================================================
+set(SPARGE_JENGA_CODE_GEN_ARGS
+    ${CMAKE_CURRENT_LIST_DIR}/generate.py
+    --api sparge_fwd_jenga
+    --receipt 600
+)
+
+execute_process(
+    COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS}
+    --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt
+    RESULT_VARIABLE ret
+)
+if(ret AND NOT ret EQUAL 0)
+    message(FATAL_ERROR "Failed to generate Sparge Jenga kernel list")
+endif()
+
+file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt SPARGE_JENGA_GEN_BLOBS)
+
+add_custom_command(
+    OUTPUT ${SPARGE_JENGA_GEN_BLOBS}
+    COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS}
+    --output_dir ${CMAKE_CURRENT_BINARY_DIR}
+    DEPENDS ${CODE_GEN_SCRIPTS}
+    COMMENT "Generate CK Tile Sparge Jenga kernels"
+)
+
+message(STATUS "Sparge Jenga kernel files to be generated: ${SPARGE_JENGA_GEN_BLOBS}")
+
+set(SPARGE_JENGA_INSTANCES "tile_sparge_jenga_instances")
+
+add_library(${SPARGE_JENGA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
+    ${SPARGE_JENGA_GEN_BLOBS}
+    ${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp
+)
+target_include_directories(${SPARGE_JENGA_INSTANCES} PRIVATE
+    ${CMAKE_CURRENT_LIST_DIR}
+    ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
+)
+set_source_files_properties(${SPARGE_JENGA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
+set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp PROPERTIES LANGUAGE HIP)
+set_property(TARGET ${SPARGE_JENGA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
+
+target_compile_options(${SPARGE_JENGA_INSTANCES} PRIVATE
+    -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
+    -DCK_TILE_FMHA_FWD_FAST_EXP2
+    -Wno-undefined-func-template
+    -Wno-float-equal
+)
+
 # Sparge + Jenga Example executable
 set(EXAMPLE_SPARGE_JENGA_SPARSE_ATTN "tile_example_sparge_jenga_sparse_attn")
 message(DEBUG "adding example ${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN}")
 add_executable(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_jenga_sparse_attn.cpp)
-target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES})
+target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARGE_JENGA_INSTANCES})
 target_include_directories(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
 target_compile_options(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE
     -Wno-undefined-func-template
@@ -164,11 +215,62 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
     -Wno-float-equal
 )
 
+# ============================================================================
+# Sparge VSA (64x128 tile)
+# ============================================================================
+set(SPARGE_VSA_CODE_GEN_ARGS
+    ${CMAKE_CURRENT_LIST_DIR}/generate.py
+    --api sparge_fwd_vsa
+    --receipt 600
+)
+
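+# As in the Sparge Jenga section above, generation is two-phase: generate.py is
+# run once at configure time (execute_process below) with --list_blobs to record
+# the kernel sources it will emit, then re-run at build time (add_custom_command)
+# to actually write them whenever a script under codegen/ changes. The same
+# generation can be reproduced by hand, e.g.:
+#   python3 generate.py --api sparge_fwd_vsa --receipt 600 --output_dir <build>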
+execute_process( + COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to generate Sparge VSA kernel list") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt SPARGE_VSA_GEN_BLOBS) + +add_custom_command( + OUTPUT ${SPARGE_VSA_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} + COMMENT "Generate CK Tile Sparge VSA kernels" +) + +message(STATUS "Sparge VSA kernel files to be generated: ${SPARGE_VSA_GEN_BLOBS}") + +set(SPARGE_VSA_INSTANCES "tile_sparge_vsa_instances") + +add_library(${SPARGE_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL + ${SPARGE_VSA_GEN_BLOBS} + ${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp +) +target_include_directories(${SPARGE_VSA_INSTANCES} PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn +) +set_source_files_properties(${SPARGE_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp PROPERTIES LANGUAGE HIP) +set_property(TARGET ${SPARGE_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE + -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + -DCK_TILE_FMHA_FWD_FAST_EXP2 + -Wno-undefined-func-template + -Wno-float-equal +) + # Sparge + VSA Example executable set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn") message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}") add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp) -target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES}) +target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARGE_VSA_INSTANCES}) target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE -Wno-undefined-func-template diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py new file mode 100644 index 0000000000..872da2326e --- /dev/null +++ b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py @@ -0,0 +1,799 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass, field +import fnmatch +import itertools +import os +import os.path as path +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cpp_symbol_map import ( + BOOL_MAP, + FWD_DTYPE_MAP, + LAYOUT_MAP, + MODE_MAP, + PIPELINE_ENUM_MAP, + PIPELINE_MAP, + get_mask_check_map, + get_mask_map, +) + +GEN_DIR = "" + + +def update_file(file_path, content): + """Update the file at file_path with the given content if it differs from the existing content. 
+ + It avoids unnecessary touching of the file which triggers rebuilds + """ + + existing_content = "" + if path.exists(file_path): + with open(file_path, "r") as file: + existing_content = file.read() + if existing_content == content: + return + with open(file_path, "w") as file: + file.write(content) + + +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16} + +K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd_trek.hpp" +#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" +#include "kernel/fmha_fwd_jenga_kernel.hpp" + +""" + +# NOTE: Jenga sparse attention kernel has the following restrictions enforced by static_assert: +# - Group mode: NOT supported (batch mode only) +# - Bias: NOT supported (NO_BIAS only) +# - LSE output: NOT supported (false only) +# - Dropout: NOT supported (false only) +# - Logits soft-cap: NOT supported (false only) +# - FP8 static quantization: NOT supported (NO_SCALE only) +# The template below hardcodes these unsupported features accordingly. + +FMHA_FWD_KERNEL_BODY = """ +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum, +// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + false, // has_logits_soft_cap - NOT supported + ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported + false, // store_lse - NOT supported + false, // has_dropout - NOT supported + false, // has_randval - NOT supported + ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported + {F_occupancy}, + false>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported) + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaSparseFwdTypeConfig::QDataType, + typename FmhaSparseFwdTypeConfig::KDataType, + typename FmhaSparseFwdTypeConfig::VDataType, + typename FmhaSparseFwdTypeConfig::SaccDataType, + typename FmhaSparseFwdTypeConfig::SMPLComputeDataType, + typename FmhaSparseFwdTypeConfig::BiasDataType, + typename FmhaSparseFwdTypeConfig::RandValOutputDataType, + typename FmhaSparseFwdTypeConfig::LSEDataType, + typename FmhaSparseFwdTypeConfig::PDataType, + typename FmhaSparseFwdTypeConfig::OaccDataType, + typename FmhaSparseFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + {F_trload}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdJengaKernel; + +using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, 
{F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + +#include + +template<> +float fmha_jenga_fwd_(const ck_tile::stream_config& s, fmha_jenga_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME = "sparge_jenga_fwd_api.cpp" +FMHA_FWD_API = """ +#include + +#include + +namespace {{ +bool get_num_cus(unsigned& num_cus) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cus = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + +float sparge_jenga_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; + + [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + return fmha_jenga_fwd_(s, a); + }} +""" + + +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return "true" + else: + return f"{self.bool_expr}" + + def __and__(self, other): + return CppConstraint(f"({str(self)}) && ({str(other)})") + + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag: str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + logits: 
str + mask: str + spad: str + skpad: str + dpad: str + dvpad: str + tr_load: str + constraint: CppConstraint + + @property + def name(self) -> str: + return ( + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" + ) + + @property + def scheck(self) -> str: + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.spad == "t": + return "true" # always support + return "true" + + @property + def seqtune(self) -> str: + return "true" + + @property + def skcheck(self) -> str: + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" + + @property + def dcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + assert False + + @property + def dvcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + assert False + + +@dataclass +class FmhaFwdPipeline: + tag: str + + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_logits: str # t/f + F_mask: str # value from MASK_MAP + F_trload: str # true/false + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + def pad_name() -> str: + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n + return n + + pn = pad_name() + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" + else: + n += "_npad" + + if self.F_logits == "t": + n += "_logits" + else: + n += "_nlogits" + + n += "_nbias" + + if self.F_mask[0:2] == "s_": + if self.F_mask == "s_mask": + n += "_mask" + else: + n += "_nmask" + else: + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" + + n += "_nskip" + + n += "_nsquant" + + if self.F_trload == "t": + n += "_trload" + else: + n += "_ntrload" + + return n + + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait: FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? 
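+        # Pool layout: {dtype: {(hdim, hdim_v): [trait, ...]}}. E.g. the fp16
+        # 64x128 instance registers under self.pool["fp16"][("128", 128)]; note
+        # that hdim stays a string while bn1 (the V head-dim tile) doubles as
+        # the integer hdim_v key.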
+ if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() + + self.pool[trait.dtype][hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + # F_logits removed - hardcoded to false (NOT supported) + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + # empty string we add some ignore to suppress warning in api + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + + +@dataclass +class FmhaFwdTileSize: + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along qk gemm unroll + F_bn1: int # tile size along v head_dim + F_bk1: int # tile size along kv gemm unroll + F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0: int # number of warps for gemm0 along q seqlen + F_rn0: int # number of warps for gemm0 along k seqlen + F_rk0: int # number of warps for gemm0 along head dim q (not used) + F_rm1: int # number of warps for gemm1 along q seqlen + F_rn1: int # number of warps for gemm1 along head dim v + F_rk1: int # number of warps for gemm1 along k seqlen (not used) + F_wm0: int # gemm0 warp size along m + F_wn0: int # gemm0 warp size along n + F_wk0: int # gemm0 warp size along k + F_wm1: int # gemm1 warp size along m + F_wn1: int # gemm1 warp size along n + F_wk1: int # gemm1 warp size along k + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + return ( + 
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" + + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + ) + + +@dataclass +class FmhaFwdKernel: + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdTileSize + F_pipeline: FmhaFwdPipeline + mask_impl: str + + @property + def template(self) -> str: + # kernel_body removed - unused + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + # F_logits removed - hardcoded to false in template (NOT supported) + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_kernel_name=self.name, + ) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return ( + f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + ) + + +class KernelComponentFactory: + # TODO: design a more practical way to do it + # this is current supported tile size per hdim + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": + return { + # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (96, 128) : 
[FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128, 128): [ + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + ), + ], + # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + else: + return None + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + @staticmethod + def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # NOTE: logits soft-cap is NOT supported by Jenga sparse attention (enforced by static_assert) + pipelines = [] + if dtype in ["fp16", "bf16"]: + for logits, mask in itertools.product( + ["f"], # logits soft-cap NOT supported, always false + get_mask_map(mask_impl).keys(), + ): + if hdim == 256 and hdim_v == 256: + # jenga fmha only supports dim <= 192 for now. + continue + pipelines.append( + FmhaFwdPipeline( # fmt: skip + "qr_async", + "row", + "t", + "f", + "t", + "t", + logits, + mask, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( # fmt: skip + "qr_async", + "row", + "t", + "t", + "t", + "t", + logits, + mask, + "f", + ) + ) + else: + assert False + return pipelines + + +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == "fp16" or dtype == "bf16": + if (128, 128) in result.keys(): + result[(128, 128)].insert( + 0, + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + CppConstraint( + "get_num_blocks(128) < num_cus * min_cu_util_rate" + ), + ), + ) + return result + + +def get_fwd_blobs( + kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + factory = ( + CustomFactory + if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" + else KernelComponentFactory + ) + + # Only generate fp16/bf16 kernels for now. 
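+    # This example's CMakeLists.txt invokes this generator with --receipt 600
+    # (the aiter::mha_fwd C++ API branch below), which keeps row-major-V
+    # fp16/bf16 instances only; the other receipt branches appear to be retained
+    # for parity with the dense fmha_fwd generator.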
+ # NOTE: Jenga sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert) + for dtype in ["fp16", "bf16"]: + d = factory.get_hdim_tile_size_dict(dtype) + if d is None: + continue + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): + for tile, pipeline in itertools.product( + tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) + ): + if pipeline.tag != "qr_async": + continue + k = FmhaFwdKernel( + F_idx=2, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= mode == "batch" + cond &= pipeline.F_logits == "f" + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + update_file(autogen_dir / kernel.filename, kernel.template) + + +def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + + +def write_blobs( + output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + + +def list_blobs( + file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + with file_path.open("a") as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py new file mode 100644 index 0000000000..c9a389df3f --- /dev/null +++ b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py @@ -0,0 +1,799 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass, field +import fnmatch +import itertools +import os +import os.path as path +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cpp_symbol_map import ( + BOOL_MAP, + FWD_DTYPE_MAP, + LAYOUT_MAP, + MODE_MAP, + PIPELINE_ENUM_MAP, + PIPELINE_MAP, + get_mask_check_map, + get_mask_map, +) + +GEN_DIR = "" + + +def update_file(file_path, content): + """Update the file at file_path with the given content if it differs from the existing content. + + It avoids unnecessary touching of the file which triggers rebuilds + """ + + existing_content = "" + if path.exists(file_path): + with open(file_path, "r") as file: + existing_content = file.read() + if existing_content == content: + return + with open(file_path, "w") as file: + file.write(content) + + +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16} + +K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd_trek.hpp" +#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" +#include "kernel/fmha_fwd_vsa_kernel.hpp" + +""" + +# NOTE: VSA sparse attention kernel has the following restrictions enforced by static_assert: +# - Group mode: NOT supported (batch mode only) +# - Bias: NOT supported (NO_BIAS only) +# - LSE output: NOT supported (false only) +# - Dropout: NOT supported (false only) +# - Logits soft-cap: NOT supported (false only) +# - FP8 static quantization: NOT supported (NO_SCALE only) +# The template below hardcodes these unsupported features accordingly. 
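+# For instance, every generated TileFmhaTraits pins the unsupported knobs to
+# their "off" values (sketch of the emitted C++, abbreviated):
+#   ck_tile::TileFmhaTraits<spad, skpad, dpad, dvpad, false /*logits*/,
+#                           NO_BIAS, false /*lse*/, false /*dropout*/, ...>
+# so callers that need any of these features cannot dispatch through this path.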
+ +FMHA_FWD_KERNEL_BODY = """ +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum, +// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + false, // has_logits_soft_cap - NOT supported + ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported + false, // store_lse - NOT supported + false, // has_dropout - NOT supported + false, // has_randval - NOT supported + ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported + {F_occupancy}, + false>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported) + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaSparseFwdTypeConfig::QDataType, + typename FmhaSparseFwdTypeConfig::KDataType, + typename FmhaSparseFwdTypeConfig::VDataType, + typename FmhaSparseFwdTypeConfig::SaccDataType, + typename FmhaSparseFwdTypeConfig::SMPLComputeDataType, + typename FmhaSparseFwdTypeConfig::BiasDataType, + typename FmhaSparseFwdTypeConfig::RandValOutputDataType, + typename FmhaSparseFwdTypeConfig::LSEDataType, + typename FmhaSparseFwdTypeConfig::PDataType, + typename FmhaSparseFwdTypeConfig::OaccDataType, + typename FmhaSparseFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + {F_trload}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdVSAKernel; + +using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + +#include + +template<> +float fmha_vsa_fwd_(const ck_tile::stream_config& s, fmha_vsa_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME = "sparge_vsa_fwd_api.cpp" +FMHA_FWD_API = """ +#include + +#include + +namespace {{ +bool get_num_cus(unsigned& num_cus) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cus = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned 
batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + +float sparge_vsa_fwd(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; + + [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + return fmha_vsa_fwd_(s, a); + }} +""" + + +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return "true" + else: + return f"{self.bool_expr}" + + def __and__(self, other): + return CppConstraint(f"({str(self)}) && ({str(other)})") + + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag: str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + logits: str + mask: str + spad: str + skpad: str + dpad: str + dvpad: str + tr_load: str + constraint: CppConstraint + + @property + def name(self) -> str: + return ( + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" + ) + + @property + def scheck(self) -> str: + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.spad == "t": + return "true" # always support + return "true" + + @property + def seqtune(self) -> str: + return "true" + + @property + def skcheck(self) -> str: + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" + + @property + def dcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + assert False + + @property + def dvcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == "t": + return 
f"a.hdim_v % {vec} == 0" + assert False + + +@dataclass +class FmhaFwdPipeline: + tag: str + + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_logits: str # t/f + F_mask: str # value from MASK_MAP + F_trload: str # true/false + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + def pad_name() -> str: + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n + return n + + pn = pad_name() + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" + else: + n += "_npad" + + if self.F_logits == "t": + n += "_logits" + else: + n += "_nlogits" + + n += "_nbias" + + if self.F_mask[0:2] == "s_": + if self.F_mask == "s_mask": + n += "_mask" + else: + n += "_nmask" + else: + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" + + n += "_nskip" + + n += "_nsquant" + + if self.F_trload == "t": + n += "_trload" + else: + n += "_ntrload" + + return n + + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait: FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() + + self.pool[trait.dtype][hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + # F_logits removed - hardcoded to false (NOT supported) + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + # empty string we add some ignore to suppress warning in api + per_tr_load += " (void)t ; (void)s ; (void)a;" + return 
FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + + +@dataclass +class FmhaFwdTileSize: + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along qk gemm unroll + F_bn1: int # tile size along v head_dim + F_bk1: int # tile size along kv gemm unroll + F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0: int # number of warps for gemm0 along q seqlen + F_rn0: int # number of warps for gemm0 along k seqlen + F_rk0: int # number of warps for gemm0 along head dim q (not used) + F_rm1: int # number of warps for gemm1 along q seqlen + F_rn1: int # number of warps for gemm1 along head dim v + F_rk1: int # number of warps for gemm1 along k seqlen (not used) + F_wm0: int # gemm0 warp size along m + F_wn0: int # gemm0 warp size along n + F_wk0: int # gemm0 warp size along k + F_wm1: int # gemm1 warp size along m + F_wn1: int # gemm1 warp size along n + F_wk1: int # gemm1 warp size along k + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + return ( + f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" + + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + ) + + +@dataclass +class FmhaFwdKernel: + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdTileSize + F_pipeline: FmhaFwdPipeline + mask_impl: str + + @property + def template(self) -> str: + # kernel_body removed - unused + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + # F_logits removed - hardcoded to false in template (NOT supported) + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_kernel_name=self.name, + ) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return ( + f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return 
FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + ) + + +class KernelComponentFactory: + # TODO: design a more practical way to do it + # this is current supported tile size per hdim + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": + return { + # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128, 128): [ + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + ), + ], + # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + else: + return None + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + @staticmethod + def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # NOTE: logits soft-cap is NOT supported by VSA sparse attention (enforced by static_assert) + pipelines = [] + if dtype in ["fp16", "bf16"]: + for logits, mask in itertools.product( + ["f"], # logits soft-cap NOT supported, always false + get_mask_map(mask_impl).keys(), + ): + if hdim == 256 and hdim_v == 256: + # vsa fmha only supports dim <= 192 for now. 
+ continue + pipelines.append( + FmhaFwdPipeline( + "qr_async_vsa", + "row", + "t", + "f", + "t", + "t", + logits, + mask, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( + "qr_async_vsa", + "row", + "t", + "t", + "t", + "t", + logits, + mask, + "f", + ) + ) + else: + assert False + return pipelines + + +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == "fp16" or dtype == "bf16": + if (128, 128) in result.keys(): + result[(128, 128)].insert( + 0, + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + CppConstraint( + "get_num_blocks(128) < num_cus * min_cu_util_rate" + ), + ), + ) + return result + + +def get_fwd_blobs( + kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + factory = ( + CustomFactory + if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" + else KernelComponentFactory + ) + + # Only generate fp16/bf16 kernels for now. + # NOTE: VSA sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert) + for dtype in ["fp16", "bf16"]: + d = factory.get_hdim_tile_size_dict(dtype) + if d is None: + continue + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): + for tile, pipeline in itertools.product( + tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) + ): + if pipeline.tag != "qr_async_vsa": + continue + k = FmhaFwdKernel( + F_idx=1, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= mode == "batch" + cond &= pipeline.F_logits == "f" + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + update_file(autogen_dir / kernel.filename, kernel.template) + + +def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + + +def write_blobs( + output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + + +def list_blobs( 
+ file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + with file_path.open("a") as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 7349c3576e..25e3513d2f 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -277,6 +277,9 @@ struct fmha_jenga_fwd_traits float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); +// sparge jenga +float sparge_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); + template float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args); @@ -322,6 +325,9 @@ using fmha_vsa_fwd_traits = fmha_jenga_fwd_traits; float fmha_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); +// sparge vsa +float sparge_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); + template float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp new file mode 100644 index 0000000000..88f3e08204 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp @@ -0,0 +1,189 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "jenga_sparge_attention.h" +#include "fmha_fwd_trek.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include "ck_tile/host/device_memory.hpp" +#include + +template +ck_tile::HostTensor +jenga_sparge_attention(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + const ck_tile::HostTensor& TV, + const ck_tile::HostTensor& Tblock_relation_onehot, + ck_tile::HostTensor& Y, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k, + int log_level) +{ + static_assert(std::is_same_v || + std::is_same_v, + "Jenga sparse attention supports fp16/bf16 only."); + std::string data_type = "fp16"; + if constexpr(std::is_same_v) + { + data_type = "bf16"; + } + + if(max_seqlen_q == 0) + max_seqlen_q = seqlen_q; + if(max_seqlen_k == 0) + max_seqlen_k = seqlen_k; + bool is_v_rowmajor = true; + float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); + std::string msk_str = "0"; + mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); + + const ck_tile::index_t shape_seqlen_q = seqlen_q; + const ck_tile::index_t shape_seqlen_k = seqlen_k; + + ck_tile::stream_config stream_config{nullptr, + false, // time_kernel + log_level, + 0, + 1, + false}; + + ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_relation_buf(Tblock_relation_onehot.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes()); + + q_buf.ToDevice(TQ.data()); + k_buf.ToDevice(TK.data()); + v_buf.ToDevice(TV.data()); + block_relation_buf.ToDevice(Tblock_relation_onehot.data()); + + 
const auto init_args = [&](auto& args) { + assert(nhead % nhead_k == 0); + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + }(); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? shape_seqlen_k * hdim_v : hdim_v; + else + return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; + }(); + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; + const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); + args.block_relation_onehot_ptr = block_relation_buf.GetDeviceBuffer(); + + args.batch = batch; + args.seqlen_q = shape_seqlen_q; + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; + + args.o_ptr = o_buf.GetDeviceBuffer(); + + args.seqlen_k = shape_seqlen_k; + args.max_seqlen_q = max_seqlen_q; + + args.scale_s = scale_s; + + args.stride_o = stride_o; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_o = batch_stride_o; + + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.mask_type = static_cast(mask.type); + }; + + const auto init_traits = [&](auto& traits) { + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type; + traits.is_v_rowmajor = is_v_rowmajor; + traits.mask_type = mask.type; + }; + + fmha_jenga_fwd_traits fmha_traits; + init_traits(fmha_traits); + + fmha_jenga_fwd_args args; + init_args(args); + + sparge_jenga_fwd(fmha_traits, args, stream_config); + + o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes()); + + return Y; +} + +template ck_tile::HostTensor +jenga_sparge_attention(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + int, + bool, + bool, + int, + int, + int); + +template ck_tile::HostTensor +jenga_sparge_attention(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + int, + bool, + bool, + int, + int, + int); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparge_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.h new file mode 100644 index 0000000000..6259fcc73c --- /dev/null +++ 
@@ -0,0 +1,27 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+#include <cstdint>
+#include <type_traits>
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/host_tensor.hpp"
+
+template <typename DataType>
+ck_tile::HostTensor<DataType>
+jenga_sparge_attention(const ck_tile::HostTensor<DataType>& TQ,
+                       const ck_tile::HostTensor<DataType>& TK,
+                       const ck_tile::HostTensor<DataType>& TV,
+                       const ck_tile::HostTensor<int32_t>& Tblock_relation_onehot,
+                       ck_tile::HostTensor<DataType>& Y,
+                       int batch,
+                       int nhead,
+                       int nhead_k,
+                       int seqlen_q,
+                       int seqlen_k,
+                       int hdim_q,
+                       int hdim_v,
+                       bool i_perm,
+                       bool o_perm,
+                       int max_seqlen_q,
+                       int max_seqlen_k,
+                       int log_level = 0);
diff --git a/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp
index 0bd664adf6..590e51db14 100644
--- a/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp
+++ b/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp
@@ -16,7 +16,7 @@
 #include "ck_tile/host/reference/reference_blocked_attention.hpp"
 #include "ck_tile/core/utility/bit_cast.hpp"
 
-#include "jenga_sparse_attention.h"
+#include "jenga_sparge_attention.h"
 #include "sparge_tool.hpp"
 
 // ============================================================================
@@ -115,7 +115,7 @@ auto create_args(int argc, char* argv[])
         .insert("repeat", "20", "benchmark iterations")
         .insert("kname", "0", "print kernel name")
         // Sparge-specific
-        .insert("blkq", "128", "Sparge BLKQ")
+        .insert("blkq", "64", "Sparge BLKQ")
         .insert("blkk", "128", "Sparge BLKK")
         .insert("simthreshd1", "0.6", "Sparge sim threshold")
         .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)")
@@ -161,10 +161,10 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
     if(hdim_v < 0)
         hdim_v = hdim_q;
 
-    if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
+    if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
     {
         std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
-        std::cout << "Jenga/VSA kernel instances are generated for BLKQ=BLKK=128, "
+        std::cout << "Sparge Jenga kernel instances are generated for BLKQ=64, BLKK=128, "
                      "hdim_q=128, hdim_v=128 only."
<< std::endl; std::cout << "TEST SKIPPED" << std::endl; @@ -247,7 +247,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) { if(kname) { - jenga_sparse_attention(q_host, + jenga_sparge_attention(q_host, k_host, v_host, block_relation_onehot, @@ -268,7 +268,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < warmup; ++i) { - jenga_sparse_attention(q_host, + jenga_sparge_attention(q_host, k_host, v_host, block_relation_onehot, @@ -292,7 +292,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < repeat; ++i) { - jenga_sparse_attention(q_host, + jenga_sparge_attention(q_host, k_host, v_host, block_relation_onehot, diff --git a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp index dd1d3e60be..c0feb23e58 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp @@ -16,7 +16,7 @@ #include "ck_tile/host/reference/reference_blocked_attention.hpp" #include "ck_tile/core/utility/bit_cast.hpp" -#include "jenga_sparse_attention.h" +#include "vsa_sparge_attention.h" #include "sparge_tool.hpp" // ============================================================================ @@ -115,7 +115,7 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "20", "benchmark iterations") .insert("kname", "0", "print kernel name") // Sparge-specific - .insert("blkq", "128", "Sparge BLKQ") + .insert("blkq", "64", "Sparge BLKQ") .insert("blkk", "128", "Sparge BLKK") .insert("simthreshd1", "0.6", "Sparge sim threshold") .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") @@ -161,10 +161,10 @@ bool run_test(const ck_tile::ArgParser& arg_parser) if(hdim_v < 0) hdim_v = hdim_q; - if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128) + if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128) { std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; - std::cout << "VSA kernel instances are generated for BLKQ=BLKK=128, " + std::cout << "Sparge VSA kernel instances are generated for BLKQ=64, BLKK=128, " "hdim_q=128, hdim_v=128 only." << std::endl; std::cout << "TEST SKIPPED" << std::endl; @@ -251,7 +251,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) { if(kname) { - vsa_sparse_attention(q_host, + vsa_sparge_attention(q_host, k_host, v_host, vsa_lut.lut, @@ -273,7 +273,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < warmup; ++i) { - vsa_sparse_attention(q_host, + vsa_sparge_attention(q_host, k_host, v_host, vsa_lut.lut, @@ -298,7 +298,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < repeat; ++i) { - vsa_sparse_attention(q_host, + vsa_sparge_attention(q_host, k_host, v_host, vsa_lut.lut, diff --git a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp new file mode 100644 index 0000000000..5f9c2676dd --- /dev/null +++ b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp @@ -0,0 +1,195 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+#include "vsa_sparge_attention.h"
+#include "fmha_fwd_trek.hpp"
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/host_tensor.hpp"
+#include "ck_tile/host/device_memory.hpp"
+#include <type_traits>
+
+template <typename DataType>
+ck_tile::HostTensor<DataType>
+vsa_sparge_attention(const ck_tile::HostTensor<DataType>& TQ,
+                     const ck_tile::HostTensor<DataType>& TK,
+                     const ck_tile::HostTensor<DataType>& TV,
+                     const ck_tile::HostTensor<int32_t>& TKV_block_idx,
+                     const ck_tile::HostTensor<int32_t>& TKV_blocks,
+                     ck_tile::HostTensor<DataType>& Y,
+                     int batch,
+                     int nhead,
+                     int nhead_k,
+                     int seqlen_q,
+                     int seqlen_k,
+                     int hdim_q,
+                     int hdim_v,
+                     bool i_perm,
+                     bool o_perm,
+                     int max_seqlen_q,
+                     int max_seqlen_k,
+                     int log_level)
+{
+    static_assert(std::is_same_v<DataType, ck_tile::fp16_t> ||
+                      std::is_same_v<DataType, ck_tile::bf16_t>,
+                  "VSA sparse attention supports fp16/bf16 only.");
+    std::string data_type = "fp16";
+    if constexpr(std::is_same_v<DataType, ck_tile::bf16_t>)
+    {
+        data_type = "bf16";
+    }
+
+    if(max_seqlen_q == 0)
+        max_seqlen_q = seqlen_q;
+    if(max_seqlen_k == 0)
+        max_seqlen_k = seqlen_k;
+    bool is_v_rowmajor = true;
+    float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
+    std::string msk_str = "0";
+    mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);
+
+    const ck_tile::index_t shape_seqlen_q = seqlen_q;
+    const ck_tile::index_t shape_seqlen_k = seqlen_k;
+
+    ck_tile::stream_config stream_config{nullptr,
+                                         false, // time_kernel
+                                         log_level,
+                                         0,
+                                         1,
+                                         false};
+
+    ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem lut_buf(TKV_block_idx.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem valid_block_num_buf(TKV_blocks.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());
+
+    q_buf.ToDevice(TQ.data());
+    k_buf.ToDevice(TK.data());
+    v_buf.ToDevice(TV.data());
+    lut_buf.ToDevice(TKV_block_idx.data());
+    valid_block_num_buf.ToDevice(TKV_blocks.data());
+
+    const auto init_args = [&](auto& args) {
+        assert(nhead % nhead_k == 0);
+        const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
+        const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
+        const ck_tile::index_t stride_v = [&]() {
+            if(is_v_rowmajor)
+                return i_perm ? hdim_v : nhead_k * hdim_v;
+            else
+                return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
+        }();
+        const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
+        const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
+        const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
+        const ck_tile::index_t nhead_stride_v = [&]() {
+            if(is_v_rowmajor)
+                return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
+            else
+                return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
+        }();
+        const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
+        const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
+        const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
+        const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
+        const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
+
+        args.q_ptr = q_buf.GetDeviceBuffer();
+        args.k_ptr = k_buf.GetDeviceBuffer();
+        args.v_ptr = v_buf.GetDeviceBuffer();
+        args.lut_ptr = lut_buf.GetDeviceBuffer();
+        args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer();
+
+        args.batch = batch;
+        args.seqlen_q = shape_seqlen_q;
+        args.hdim_q = hdim_q;
+        args.hdim_v = hdim_v;
+        args.nhead_q = nhead;
+        args.nhead_k = nhead_k;
+
+        args.stride_q = stride_q;
+        args.stride_k = stride_k;
+        args.stride_v = stride_v;
+        args.nhead_stride_q = nhead_stride_q;
+        args.nhead_stride_k = nhead_stride_k;
+        args.nhead_stride_v = nhead_stride_v;
+        args.batch_stride_q = batch_stride_q;
+        args.batch_stride_k = batch_stride_k;
+        args.batch_stride_v = batch_stride_v;
+
+        args.o_ptr = o_buf.GetDeviceBuffer();
+
+        args.seqlen_k = shape_seqlen_k;
+        args.max_seqlen_q = max_seqlen_q;
+
+        args.scale_s = scale_s;
+
+        args.stride_o = stride_o;
+        args.nhead_stride_o = nhead_stride_o;
+        args.batch_stride_o = batch_stride_o;
+
+        args.window_size_left = mask.left;
+        args.window_size_right = mask.right;
+        args.mask_type = static_cast<ck_tile::index_t>(mask.type);
+    };
+
+    const auto init_traits = [&](auto& traits) {
+        traits.hdim_q = hdim_q;
+        traits.hdim_v = hdim_v;
+        traits.data_type = data_type;
+        traits.is_v_rowmajor = is_v_rowmajor;
+        traits.mask_type = mask.type;
+    };
+
+    fmha_vsa_fwd_traits fmha_traits;
+    init_traits(fmha_traits);
+
+    fmha_vsa_fwd_args args;
+    init_args(args);
+
+    sparge_vsa_fwd(fmha_traits, args, stream_config);
+
+    o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());
+
+    return Y;
+}
+
+template ck_tile::HostTensor<ck_tile::fp16_t>
+vsa_sparge_attention<ck_tile::fp16_t>(const ck_tile::HostTensor<ck_tile::fp16_t>&,
+                                      const ck_tile::HostTensor<ck_tile::fp16_t>&,
+                                      const ck_tile::HostTensor<ck_tile::fp16_t>&,
+                                      const ck_tile::HostTensor<int32_t>&,
+                                      const ck_tile::HostTensor<int32_t>&,
+                                      ck_tile::HostTensor<ck_tile::fp16_t>&,
+                                      int,
+                                      int,
+                                      int,
+                                      int,
+                                      int,
+                                      int,
+                                      int,
+                                      bool,
+                                      bool,
+                                      int,
+                                      int,
+                                      int);
+
+template ck_tile::HostTensor<ck_tile::bf16_t>
+vsa_sparge_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
+                                      const ck_tile::HostTensor<ck_tile::bf16_t>&,
+                                      const ck_tile::HostTensor<ck_tile::bf16_t>&,
+                                      const ck_tile::HostTensor<int32_t>&,
+                                      const ck_tile::HostTensor<int32_t>&,
+                                      ck_tile::HostTensor<ck_tile::bf16_t>&,
+                                      int,
+                                      int,
+                                      int,
+                                      int,
+                                      int,
+                                      int,
+                                      int,
+                                      bool,
+                                      bool,
+                                      int,
+                                      int,
+                                      int);
diff --git a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h
new file mode 100644
index 0000000000..d51a7e8c00
--- /dev/null
+++ b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h
@@ -0,0 +1,28 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+#include <cstdint>
+#include <type_traits>
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/host_tensor.hpp"
+
+template <typename DataType>
+ck_tile::HostTensor<DataType>
+vsa_sparge_attention(const ck_tile::HostTensor<DataType>& TQ,
+                     const ck_tile::HostTensor<DataType>& TK,
+                     const ck_tile::HostTensor<DataType>& TV,
+                     const ck_tile::HostTensor<int32_t>& TKV_block_idx,
+                     const ck_tile::HostTensor<int32_t>& TKV_blocks,
+                     ck_tile::HostTensor<DataType>& Y,
+                     int batch,
+                     int nhead,
+                     int nhead_k,
+                     int seqlen_q,
+                     int seqlen_k,
+                     int hdim_q,
+                     int hdim_v,
+                     bool i_perm,
+                     bool o_perm,
+                     int max_seqlen_q,
+                     int max_seqlen_k,
+                     int log_level = 0);
diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp
index 2b097ae582..578ad7e603 100644
--- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp
+++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp
@@ -200,7 +200,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA
         constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
         constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
 
-        int seqlen_k_start = kv_block_idx_ptr[0] * kM0;
+        int seqlen_k_start = kv_block_idx_ptr[0] * kN0;
 
         auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                                               q_dram_block_window_tmp.get_window_lengths(),
                                               q_dram_block_window_tmp.get_window_origin(),
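
Editor's note on the final hunk: in this pipeline `kM0` is the query-tile height and `kN0` the key/value-tile width, and `kv_block_idx_ptr` holds indices of `kN0`-wide KV blocks, so converting a block index to an element offset along seqlen_k must scale by `kN0`. The former square 128x128 tile hid the bug because `kM0 == kN0`; the new 64x128 tile exposes it. A minimal standalone sketch of the arithmetic (the 64/128 values come from the patch title; the block index is illustrative):

```cpp
#include <cassert>

int main()
{
    constexpr int kM0 = 64;  // query rows per tile in the new 64x128 config
    constexpr int kN0 = 128; // key/value columns per tile

    // Suppose the LUT selects the 4th 128-wide KV block (index 3).
    const int kv_block_idx = 3;

    // Fixed code: KV blocks tile seqlen_k in kN0-sized steps.
    const int seqlen_k_start = kv_block_idx * kN0;
    assert(seqlen_k_start == 384);

    // The old code scaled by kM0 instead, which only agreed with kN0 for a
    // square 128x128 tile; with 64x128 it would start at element 192.
    assert(kv_block_idx * kM0 == 192);
    return 0;
}
```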