diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 35afb1181e..791ce21c86 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -200,6 +200,7 @@ message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}") add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp) target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES}) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set_property(TARGET ${EXAMPLE_FMHA_FWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}") # not using add_example_executable() to add this target, since we don't want this to be included in @@ -207,6 +208,7 @@ message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}") add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp) target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES}) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set_property(TARGET ${EXAMPLE_FMHA_BWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index b20a661805..532285ce5a 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -83,6 +83,7 @@ message(DEBUG "adding example ${EXAMPLE_JENGA_SPARSE_ATTN}") add_executable(${EXAMPLE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_jenga_sparse_attn.cpp) target_link_libraries(${EXAMPLE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES}) target_include_directories(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set_property(TARGET ${EXAMPLE_JENGA_SPARSE_ATTN} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE -Wno-undefined-func-template -Wno-float-equal @@ -148,11 +149,64 @@ message(DEBUG "adding example ${EXAMPLE_VSA_SPARSE_ATTN}") add_executable(${EXAMPLE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_vsa_sparse_attn.cpp) target_link_libraries(${EXAMPLE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES}) target_include_directories(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set_property(TARGET ${EXAMPLE_VSA_SPARSE_ATTN} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE -Wno-undefined-func-template -Wno-float-equal ) +# ============================================================================ +# Sparge Sparse Attention (PV-skip enabled, derived from VSA) +# ============================================================================ +set(SPARSE_ATTN_SPARGE_CODE_GEN_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api fwd_sparge + --receipt 600 +) + +# Generate list of Sparge kernels (at configure time, only list) +execute_process( + COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_SPARGE_CODE_GEN_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to generate Sparge kernel list") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_blob_list.txt SPARSE_ATTN_SPARGE_GEN_BLOBS) + +# Generate Sparge kernel source files at build time +add_custom_command( + OUTPUT ${SPARSE_ATTN_SPARGE_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_SPARGE_CODE_GEN_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} + COMMENT "Generate CK Tile Sparge Sparse Attention kernels" +) + +message(STATUS "Sparge kernel files to be generated: ${SPARSE_ATTN_SPARGE_GEN_BLOBS}") + +# Sparge Instances +set(SPARSE_ATTN_SPARGE_INSTANCES "tile_sparse_attn_sparge_instances") + +add_library(${SPARSE_ATTN_SPARGE_INSTANCES} OBJECT EXCLUDE_FROM_ALL + ${SPARSE_ATTN_SPARGE_GEN_BLOBS} +) +target_include_directories(${SPARSE_ATTN_SPARGE_INSTANCES} PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn +) +set_source_files_properties(${SPARSE_ATTN_SPARGE_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +set_property(TARGET ${SPARSE_ATTN_SPARGE_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +target_compile_options(${SPARSE_ATTN_SPARGE_INSTANCES} PRIVATE + -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + -DCK_TILE_FMHA_FWD_FAST_EXP2 + -Wno-undefined-func-template + -Wno-float-equal +) + # ============================================================================ # Sparge BlockMap GPU Kernel (hand-written instantiation, no codegen) # ============================================================================ @@ -188,9 +242,11 @@ add_executable(${EXAMPLE_SPARGE} EXCLUDE_FROM_ALL test_sparge.cpp) target_link_libraries(${EXAMPLE_SPARGE} ${SPARSE_ATTN_JENGA_INSTANCES} ${SPARSE_ATTN_VSA_INSTANCES} + ${SPARSE_ATTN_SPARGE_INSTANCES} ${SPARGE_BLOCKMAP_INSTANCES} ) target_include_directories(${EXAMPLE_SPARGE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set_property(TARGET ${EXAMPLE_SPARGE} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) target_compile_options(${EXAMPLE_SPARGE} PRIVATE -Wno-undefined-func-template -Wno-float-equal diff --git a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py index 8614a1ff3b..0f2866cbf4 100644 --- a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py +++ b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py @@ -58,11 +58,13 @@ LAYOUT_MAP = {"row": "true", "col": "false"} PIPELINE_MAP = { "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga", "qr_async_vsa": "ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA", + "qr_async_sparge": "ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge", } PIPELINE_ENUM_MAP = { "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", "qr_async_vsa": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", + "qr_async_sparge": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", } BOOL_MAP = { diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py new file mode 100644 index 0000000000..9489d3758f --- /dev/null +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py @@ -0,0 +1,1041 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass, field +import fnmatch +import itertools +import os +import os.path as path +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cpp_symbol_map import ( + BOOL_MAP, + FWD_DTYPE_MAP, + LAYOUT_MAP, + MODE_MAP, + PIPELINE_ENUM_MAP, + PIPELINE_MAP, + get_mask_check_map, + get_mask_map, +) + +GEN_DIR = "" + + +def update_file(file_path, content): + """Update the file at file_path with the given content if it differs from the existing content. + + It avoids unnecessary touching of the file which triggers rebuilds + """ + + existing_content = "" + if path.exists(file_path): + with open(file_path, "r") as file: + existing_content = file.read() + if existing_content == content: + return + with open(file_path, "w") as file: + file.write(content) + + +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16} + +K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd_trek.hpp" +#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp" +#include "kernel/fmha_fwd_sparge_kernel.hpp" + +""" + +# NOTE: Sparge sparse attention kernel has the following restrictions enforced by static_assert: +# - Group mode: NOT supported (batch mode only) +# - Bias: NOT supported (NO_BIAS only) +# - LSE output: NOT supported (false only) +# - Dropout: NOT supported (false only) +# - Logits soft-cap: NOT supported (false only) +# - FP8 static quantization: NOT supported (NO_SCALE only) +# The template below hardcodes these unsupported features accordingly. + +FMHA_FWD_KERNEL_BODY = """ +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum, +// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + false, // has_logits_soft_cap - NOT supported + ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported + false, // store_lse - NOT supported + false, // has_dropout - NOT supported + false, // has_randval - NOT supported + ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported + {F_occupancy}, + false>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported) + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaSparseFwdTypeConfig::QDataType, + typename FmhaSparseFwdTypeConfig::KDataType, + typename FmhaSparseFwdTypeConfig::VDataType, + typename FmhaSparseFwdTypeConfig::SaccDataType, + typename FmhaSparseFwdTypeConfig::SMPLComputeDataType, + typename FmhaSparseFwdTypeConfig::BiasDataType, + typename FmhaSparseFwdTypeConfig::RandValOutputDataType, + typename FmhaSparseFwdTypeConfig::LSEDataType, + typename FmhaSparseFwdTypeConfig::PDataType, + typename FmhaSparseFwdTypeConfig::OaccDataType, + typename FmhaSparseFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + {F_trload}, + fmha_trait_{F_idx}>; + +// R25 V0: instantiate the Sparge pipeline with kEnablePVSkip = true (existing path) +// AND kEnablePVSkip = false (PV-skip AST removed at compile time, source-equivalent +// to the frozen VSA reference). Both kernels live in the same TU; the host dispatch +// in fmha_sparge_fwd_api.cpp picks one based on fmha_sparge_fwd_args::pv_skip_compile. +using fmha_pipeline_{F_idx}_pvst = ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge< + fmha_pipeline_problem_{F_idx}, + ck_tile::BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, + true>; +using fmha_pipeline_{F_idx}_pvsf = ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge< + fmha_pipeline_problem_{F_idx}, + ck_tile::BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, + false>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx}_pvst = + ck_tile::FmhaFwdSpargeKernel; +using fmha_kernel_{F_idx}_pvsf = + ck_tile::FmhaFwdSpargeKernel; + +using trait_{F_idx} = fmha_sparge_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + +#include + +template<> +float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvst; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}_pvst" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvsf; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}_pvsf" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvst; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvsf; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} +""" + +FMHA_FWD_API_FILENAME = "fmha_sparge_fwd_api.cpp" +FMHA_FWD_API = """ +#include + +#include + +namespace {{ +bool get_num_cus(unsigned& num_cus) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cus = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + +float fmha_sparge_fwd(fmha_sparge_fwd_traits t, fmha_sparge_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; + + [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_sparge_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + if(a.pv_skip_compile) + return fmha_sparge_fwd_(s, a); + else + return fmha_sparge_fwd_(s, a); + }} +""" + +FMHA_FWD_ONESHOT_API_FILENAME = "fmha_sparge_fwd_oneshot_api.cpp" +FMHA_FWD_ONESHOT_API = """ +#include "fmha_fwd_trek.hpp" +#include + +void fmha_sparge_fwd_oneshot(fmha_sparge_fwd_traits t, fmha_sparge_fwd_args a, const ck_tile::stream_config& s){{ + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + std::cerr << "fmha_sparge_fwd_oneshot: no matching dispatch (dtype=" << t.data_type + << " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v + << " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k + << " mask=" << static_cast(t.mask_type) << ")" << std::endl; +}} +""" + +FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_sparge_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + if(a.pv_skip_compile) + fmha_sparge_fwd_oneshot_(s, a); + else + fmha_sparge_fwd_oneshot_(s, a); + return; + }} +""" + + +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return "true" + else: + return f"{self.bool_expr}" + + def __and__(self, other): + return CppConstraint(f"({str(self)}) && ({str(other)})") + + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag: str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + logits: str + mask: str + spad: str + skpad: str + dpad: str + dvpad: str + tr_load: str + constraint: CppConstraint + + @property + def name(self) -> str: + return ( + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" + ) + + @property + def scheck(self) -> str: + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.spad == "t": + return "true" # always support + return "true" + + @property + def seqtune(self) -> str: + return "true" + + @property + def skcheck(self) -> str: + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" + + @property + def dcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + assert False + + @property + def dvcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + assert False + + +@dataclass +class FmhaFwdPipeline: + tag: str + + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_logits: str # t/f + F_mask: str # value from MASK_MAP + F_trload: str # true/false + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + def pad_name() -> str: + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n + return n + + pn = pad_name() + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" + else: + n += "_npad" + + if self.F_logits == "t": + n += "_logits" + else: + n += "_nlogits" + + n += "_nbias" + + if self.F_mask[0:2] == "s_": + if self.F_mask == "s_mask": + n += "_mask" + else: + n += "_nmask" + else: + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" + + n += "_nskip" + + n += "_nsquant" + + if self.F_trload == "t": + n += "_trload" + else: + n += "_ntrload" + + return n + + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait: FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() + + self.pool[trait.dtype][hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + # F_logits removed - hardcoded to false (NOT supported) + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + # empty string we add some ignore to suppress warning in api + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + + @property + def oneshot_api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load) + + +@dataclass +class FmhaFwdTileSize: + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along qk gemm unroll + F_bn1: int # tile size along v head_dim + F_bk1: int # tile size along kv gemm unroll + F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0: int # number of warps for gemm0 along q seqlen + F_rn0: int # number of warps for gemm0 along k seqlen + F_rk0: int # number of warps for gemm0 along head dim q (not used) + F_rm1: int # number of warps for gemm1 along q seqlen + F_rn1: int # number of warps for gemm1 along head dim v + F_rk1: int # number of warps for gemm1 along k seqlen (not used) + F_wm0: int # gemm0 warp size along m + F_wn0: int # gemm0 warp size along n + F_wk0: int # gemm0 warp size along k + F_wm1: int # gemm1 warp size along m + F_wn1: int # gemm1 warp size along n + F_wk1: int # gemm1 warp size along k + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + return ( + f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" + + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + ) + + +@dataclass +class FmhaFwdKernel: + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdTileSize + F_pipeline: FmhaFwdPipeline + mask_impl: str + + @property + def template(self) -> str: + # kernel_body removed - unused + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + # F_logits removed - hardcoded to false in template (NOT supported) + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_kernel_name=self.name, + ) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return ( + f"fmha_sparge_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + ) + + +class KernelComponentFactory: + # TODO: design a more practical way to do it + # this is current supported tile size per hdim + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": + return { + # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128, 128): [ + FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test) + 128, + 128, + 32, + 128, + 32, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + CppConstraint("t.bm0 == 0 || t.bm0 == 128"), + ), + FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64) + 64, + 128, + 32, + 128, + 32, + 128, + 2, + 1, + 1, + 2, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + CppConstraint("t.bm0 == 64"), + ), + FmhaFwdTileSize( # fmt: skip + 16, + 32, + 64, + 128, + 32, + 128, + 1, + 1, + 1, + 1, + 1, + 1, + 16, + 16, + 32, + 16, + 16, + 32, + -1, + ), + FmhaFwdTileSize( # fmt: skip + 32, + 32, + 128, + 128, + 32, + 128, + 1, + 1, + 1, + 1, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + ), + FmhaFwdTileSize( # fmt: skip + 128, + 64, + 32, + 128, + 16, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + ), + ], + # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + else: + return None + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + @staticmethod + def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # NOTE: logits soft-cap is NOT supported by Sparge sparse attention (enforced by static_assert) + pipelines = [] + if dtype in ["fp16", "bf16"]: + for logits, mask in itertools.product( + ["f"], # logits soft-cap NOT supported, always false + get_mask_map(mask_impl).keys(), + ): + if hdim == 256 and hdim_v == 256: + # sparge fmha only supports dim <= 192 for now. + continue + pipelines.append( + FmhaFwdPipeline( + "qr_async_sparge", + "row", + "t", + "f", + "t", + "t", + logits, + mask, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( + "qr_async_sparge", + "row", + "t", + "t", + "t", + "t", + logits, + mask, + "f", + ) + ) + else: + assert False + return pipelines + + +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == "fp16" or dtype == "bf16": + if (128, 128) in result.keys(): + result[(128, 128)].insert( + 0, + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + CppConstraint( + "get_num_blocks(128) < num_cus * min_cu_util_rate" + ), + ), + ) + return result + + +def get_fwd_blobs( + kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + factory = ( + CustomFactory + if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" + else KernelComponentFactory + ) + + # Only generate fp16/bf16 kernels for now. + # NOTE: Sparge sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert) + for dtype in ["fp16", "bf16"]: + d = factory.get_hdim_tile_size_dict(dtype) + if d is None: + continue + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): + for tile, pipeline in itertools.product( + tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) + ): + if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128: + continue + if pipeline.tag != "qr_async_sparge": + continue + k = FmhaFwdKernel( + F_idx=1, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= mode == "batch" + cond &= pipeline.F_logits == "f" + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + update_file(autogen_dir / kernel.filename, kernel.template) + + +def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + update_file(autogen_dir / FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api) + + +def write_blobs( + output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + + +def list_blobs( + file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + with file_path.open("a") as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 384f7bb56d..8339b50389 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -334,3 +334,141 @@ template void fmha_vsa_fwd_oneshot_(const ck_tile::stream_config&, fmha_vsa_fwd_args); void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); + +// sparge: same args as vsa plus a scalar PV-skip threshold (Step 1). +struct fmha_sparge_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* lut_ptr; // delta-encoded K-block indices per Q-block, int32 [B,H,Q_blk,K_blk] + const void* valid_block_num_ptr; // valid K-block count per Q-block, int32 [B,H,Q_blk] + void* o_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float pv_threshold; // SpargeAttn §4.4 PV-skip per-Q-tile threshold + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + + // R25 V0: select between kEnablePVSkip=true / =false template instantiations + // at host dispatch time. Default true preserves existing behaviour (binary + // shipped pre-R25-V0 only had the true instance). Profiler can flip this to + // false to measure the source-equivalent-to-VSA baseline (`if constexpr` + // removes the entire PV-skip AST). + bool pv_skip_compile = true; +}; + +template +auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.lut_ptr, + args.valid_block_num_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.pv_threshold, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + return ck_tile::make_tuple(kargs, grids); +} + +template +using fmha_sparge_fwd_traits_ = fmha_jenga_fwd_traits_; + +using fmha_sparge_fwd_traits = fmha_jenga_fwd_traits; + +float fmha_sparge_fwd(fmha_sparge_fwd_traits, fmha_sparge_fwd_args, const ck_tile::stream_config&); + +// R25 V0: kEnablePVSkip is now a template non-type param so the codegen can +// emit both true / false instantiations from the same source tree. The host +// dispatch (fmha_sparge_fwd_api.cpp) selects the right specialization based +// on fmha_sparge_fwd_args::pv_skip_compile at runtime. +template +float fmha_sparge_fwd_(const ck_tile::stream_config&, fmha_sparge_fwd_args); + +template +void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config&, fmha_sparge_fwd_args); + +void fmha_sparge_fwd_oneshot(fmha_sparge_fwd_traits, + fmha_sparge_fwd_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp index 0442f1de85..06be6215bc 100644 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp @@ -264,3 +264,24 @@ float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t, }, [=](const ck_tile::stream_config& s_) { fmha_vsa_fwd_oneshot(attn_t, attn_a, s_); }); } + +float sparge_sparge_fwd_combined(sparge_blockmap_traits bmap_t, + sparge_blockmap_args bmap_a, + fmha_sparge_fwd_traits attn_t, + fmha_sparge_fwd_args attn_a, + const ck_tile::stream_config& s) +{ + if(s.log_level_ > 0) + std::cout << ", sparge_kstats_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", fmha_sparge_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q + << std::flush; + + return ck_tile::launch_kernel( + s, + [=](const ck_tile::stream_config& s_) { sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); }, + [=](const ck_tile::stream_config& s_) { + sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_); + }, + [=](const ck_tile::stream_config& s_) { fmha_sparge_fwd_oneshot(attn_t, attn_a, s_); }); +} diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp index 4d0e935fc9..c0178f70de 100644 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp @@ -169,3 +169,9 @@ float sparge_vsa_fwd_combined(sparge_blockmap_traits, fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); + +float sparge_sparge_fwd_combined(sparge_blockmap_traits, + sparge_blockmap_args, + fmha_sparge_fwd_traits, + fmha_sparge_fwd_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/test_sparge.cpp b/example/ck_tile/50_sparse_attn/test_sparge.cpp index 73e58f0c24..ae0952cc41 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge.cpp @@ -5,6 +5,8 @@ #include #include #include +#include +#include #include #include #include @@ -100,7 +102,19 @@ auto create_args(int argc, char* argv[]) .insert("seed", "42", "random seed") .insert("warmup", "5", "warmup iterations") .insert("repeat", "20", "benchmark iterations") - .insert("kname", "0", "print kernel name"); + .insert("kname", "0", "print kernel name") + .insert("dump_o", + "", + "if non-empty, dump raw output buffer bytes to this path (for bit-identical " + "baseline comparison)") + .insert("pv_threshold", + "1e30", + "SpargeAttn PV-skip per-Q-tile threshold; default +1e30 disables skip") + .insert("pv_skip_compile", + "1", + "R25 V0: 1=use kEnablePVSkip=true template instance (existing path); 0=use " + "kEnablePVSkip=false instance (PV-skip AST removed at compile time, equivalent to " + "VSA baseline)"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -130,6 +144,9 @@ bool run_test(const ck_tile::ArgParser& arg_parser) int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); int kname = arg_parser.get_int("kname"); + std::string dump_o_path = arg_parser.get_str("dump_o"); + float pv_threshold = arg_parser.get_float("pv_threshold"); + int pv_skip_compile = arg_parser.get_int("pv_skip_compile"); if(nhead_k < 0) nhead_k = nhead; @@ -309,7 +326,9 @@ bool run_test(const ck_tile::ArgParser& arg_parser) } else if(pipeline == "vsa") { - fmha_vsa_fwd_traits attn_traits; + // R25: -pipeline=vsa now dispatches to the sparge pipeline family that adds + // SpargeAttn §4.4 PV-skip; pass pv_threshold (+1e30 disables skip, matches old vsa). + fmha_sparge_fwd_traits attn_traits; attn_traits.hdim_q = hdim_q; attn_traits.hdim_v = hdim_v; attn_traits.data_type = std::is_same_v ? "fp16" : "bf16"; @@ -317,7 +336,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_traits.mask_type = mask_enum::no_mask; attn_traits.bm0 = BLKQ; - fmha_vsa_fwd_args attn_args; + fmha_sparge_fwd_args attn_args; attn_args.q_ptr = q_dev.GetDeviceBuffer(); attn_args.k_ptr = k_dev.GetDeviceBuffer(); attn_args.v_ptr = v_dev.GetDeviceBuffer(); @@ -333,6 +352,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_args.nhead_q = nhead; attn_args.nhead_k = nhead_k; attn_args.scale_s = scale_s; + attn_args.pv_threshold = pv_threshold; + attn_args.pv_skip_compile = (pv_skip_compile != 0); attn_args.stride_q = q_strides[i_perm ? 2 : 1]; attn_args.stride_k = k_strides[i_perm ? 2 : 1]; attn_args.stride_v = v_strides[i_perm ? 2 : 1]; @@ -350,7 +371,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_args.mask_type = 0; avg_ms = - sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); + sparge_sparge_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); } else { @@ -374,6 +395,23 @@ bool run_test(const ck_tile::ArgParser& arg_parser) o_dev.FromDevice(output_host.data()); block_map_dev.FromDevice(block_map_host.data()); + // ---- optional raw output dump (for bit-identical baseline comparison) ---- + if(!dump_o_path.empty()) + { + std::ofstream ofs(dump_o_path, std::ios::binary | std::ios::trunc); + if(!ofs) + { + std::cerr << "\n [dump_o] failed to open " << dump_o_path << std::endl; + } + else + { + ofs.write(reinterpret_cast(output_host.data()), + static_cast(output_host.get_element_space_size_in_bytes())); + std::cout << "\n [dump_o] wrote " << output_host.get_element_space_size_in_bytes() + << " bytes to " << dump_o_path; + } + } + // ---- count active blocks ---- ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks; ck_tile::index_t active_blocks = 0; diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp new file mode 100644 index 0000000000..d600ff7075 --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp @@ -0,0 +1,442 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" + +#include +#include +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct FmhaFwdSpargeKernel +{ + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + static constexpr bool kEnablePVSkip = kEnablePVSkip_; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; + static constexpr bool kDoFp8StaticQuant = + (QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE); + static_assert(!FmhaPipeline::kIsGroupMode, "Sparge sparse attention supports batch mode only."); + static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS, + "Sparge sparse attention does not support bias."); + static_assert(!kStoreLSE, "Sparge sparse attention does not support LSE output."); + static_assert(!kHasDropout, "Sparge sparse attention does not support dropout."); + static_assert(!kHasLogitsSoftCap, "Sparge sparse attention does not support logits soft-cap."); + static_assert(!kDoFp8StaticQuant, + "Sparge sparse attention does not support FP8 static quantization yet."); + + using AttentionVariant = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct FmhaFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaFwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* lut_ptr; + const void* valid_block_num_ptr; + void* o_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + float scale_s; + float pv_threshold; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + }; + + struct FmhaFwdMaskKargs + { + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct FmhaFwdBatchModeKargs + : FmhaFwdCommonKargs, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + }; + + using Kargs = FmhaFwdBatchModeKargs; + + struct BlockIndices + { + ck_tile::index_t batch_idx; + ck_tile::index_t qo_head_idx; + ck_tile::index_t kv_head_idx; + }; + + // std::variant<> can't take in a list initializer, overload for backward compatibility + CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* lut_ptr, + const void* valid_block_num_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float pv_threshold, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lut_ptr, + valid_block_num_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + pv_threshold, + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // FmhaFwdCommonKargs + {}, // FmhaFwdMaskKargs or FmhaFwdEmptyKargs<1> + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o}; + + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + return kargs; + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + nhead_, + batch_size_); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + if constexpr(kHasMask) + { + return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_o = 0; + + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + + // sparse mask + const int* lut_ptr = + reinterpret_cast(kargs.lut_ptr) + + static_cast(i_batch * kargs.num_head_q + i_nhead) * + ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0) + + i_tile_m * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0); + const int* valid_block_num_ptr = + reinterpret_cast(kargs.valid_block_num_ptr) + + static_cast(i_batch * kargs.num_head_q + i_nhead) * + ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) + + i_tile_m; + const int valid_block_num_value = valid_block_num_ptr[0]; + + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + AttentionVariant variant; + const auto variant_params = ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + + auto o_acc_tile = FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + lut_ptr, + valid_block_num_value, + mask, + kargs.scale_s, + kargs.pv_threshold, + variant, + variant_params, + block_indices, + smem_ptr); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp new file mode 100644 index 0000000000..4bf3c9d296 --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp @@ -0,0 +1,698 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// Sparge variant of qr/ks/vs/async pipeline. Cloned from BlockFmhaPipelineQRKSVSAsyncVSA; +// adds PV-skip per Q-tile (SpargeAttn paper 4.4). Kept as a separate file so the original +// _vsa.hpp can remain frozen as an A/B baseline. +// +// QUANT-HOOK: future int8/sage variant will add QScaleEnum template arg + per-tile descale Kargs; +// _sparge_sage.hpp will live alongside this file and reuse the PV-skip path verbatim. +template +struct BlockFmhaPipelineQRKSVSAsyncSparge +{ + static constexpr bool kEnablePVSkip = kEnablePVSkip_; + + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) + // only need special care about seq_k padding (oob need set -INF of p instead of zero) + static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + Problem::kPadHeadDimV == true); + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS, + "VSA sparse attention does not support bias."); + static_assert(!kHasDropout, "VSA sparse attention does not support dropout."); + static_assert(!kStoreLSE, "VSA sparse attention does not support LSE output."); + static_assert(!kHasLogitsSoftCap, "VSA sparse attention does not support logits soft-cap."); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); + +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static constexpr auto R_LOG2E = 1.0 / log2e_v; +#endif + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + // minimize occupancy + if constexpr(kQKHeaddim <= 32) + { + if constexpr(kPadSeqLenK && FmhaMask::IsMasking) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 64) + { + if constexpr(kPadSeqLenK) + return 2; + else + return 3; + } + else if constexpr(kQKHeaddim <= 128) + { + if constexpr(kPadSeqLenK) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 192) + { + if constexpr(kPadSeqLenK) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_async"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const int* kv_block_idx_ptr, + int kv_blocks, + FmhaMask mask, + float scale_s, + float pv_threshold, // SpargeAttn PV-skip threshold; see §2 of pv_skip plan + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr) const + { + if constexpr(!kEnablePVSkip) + { + (void)pv_threshold; // silence unused-param when PV-skip is compiled out + } + // R25 Step 1 redesign D: PV-skip control is a compile-time gate + // (kEnablePVSkip). The entire PV-skip logic block below is wrapped in + // `if constexpr (kEnablePVSkip)`, so when this template parameter is + // false the AST contains no vote, no scalar gate, no extra LDS, and + // codegen converges with _vsa.hpp's FmhaFwdVSAKernel. + // + // Runtime fast-path (C3-lite): pv_threshold == +1e30 sentinel disables + // the skip at runtime via one scalar branch (sgpr); kept inside the + // `if constexpr` so the OFF instantiation pays zero cost. + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + + // K tile in LDS + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).get_lengths(), + {0, 0, 0}); + }, + number{}); + + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template MakeKLdsLoadBlockDescriptor().get_lengths(), + {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + int seqlen_k_start = kv_block_idx_ptr[0] * kN0; + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + q_dram_window.init_raw(); + + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = decltype(load_tile(q_dram_window)){}; + // TODO: start from rocm-6.2, compiler will have problem if manually set clear of q. + // however, q would be cleared in the constructor of static distributed tensor + // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch + load_tile_raw(q, q_dram_window); + __builtin_amdgcn_sched_barrier(0); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.get_window_origin(); + const auto num_total_loop = kv_blocks; + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + { + if(num_total_loop <= 0) + { + buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) + // otherwise will have compute error(maybe compiler bug?) + + // Note: here occ are all cleard, return it + return o_acc; + } + __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + k_dram_window.init_raw(); + constexpr auto k_oob_ck = bool_constant{}; + constexpr auto k_pre_np = bool_constant{}; + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // prefetch K tile + async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + + // buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer()); + buffer_load_fence(k_dram_window.get_num_of_access()); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + // main loop + do + { + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + async_load_tile_raw(k_lds_store(number{})>{}), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.get_num_of_access()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, i_k0 * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); + }); + } + + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); + + async_load_fence(); + __builtin_amdgcn_s_barrier(); + + int block_idx = kv_block_idx_ptr[i_total_loops + 1]; + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + __builtin_amdgcn_sched_barrier(0); + { // tail + gemm_0( + s_acc, + get_slice_tile( + q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); + } + __builtin_amdgcn_sched_barrier(1); + + // STAGE 2, scale_s, mask, softmax (no bias/soft-cap) +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // Ensure gemm_0's LDS reads (K tile) from all threads are completed before V store + // Only needed when K tail and V use the same LDS buffer + if constexpr(LdsSeq.at(number{}) == LdsSeq.at(number{})) + { + __builtin_amdgcn_s_barrier(); + } + // store & prefetch next v, after the max reduction. + // R25 Step 1 redesign D: V→LDS store and the next-V DRAM load are + // UNCONDITIONAL — per-warp PV-skip cannot gate them (cross-warp + // shared LDS state; see Researcher audit A.3/A.4/A.5). + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile(v_lds_window_tmp, v_shuffle_tmp); + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, v_buf); + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + } + __builtin_amdgcn_sched_barrier(0); + + // ================================================================ + // PV-SKIP per Q-tile (SpargeAttn paper §4.4) + // R25 Step 1 redesign D — per-warp arithmetic-only: + // Compile-time `if constexpr (kEnablePVSkip)` wraps the entire + // block. When kEnablePVSkip=false the AST has zero PV-skip + // artifacts → codegen converges with _vsa.hpp. + // + // When enabled, a per-warp predicate gates ONLY the per-row, + // VGPR-private work (exp2 → p_compute, rowsum, `l += rowsum_p`). + // V load / V→LDS store / gemm_1 / every `s_barrier` / + // `block_sync_lds` stay unconditional (cross-warp LDS dep — see + // Researcher audit A.7). + // + // On warp_skip, this warp's owned rows of p_compute are zeroed + // so the unconditional gemm_1 contributes 0 to o_acc (audit + // A.7 "simplest realisation"). The alpha-rescale `l *= tmp` and + // `o *= tmp` still apply. + // + // pv_threshold semantics shift: now per-warp max diff (slightly + // more aggressive than per-block at the same threshold; matches + // upstream SpargeAttn `kPerWarp` mode default). + // + // Skip iff: scale_s * (m_local - m_old) + pv_threshold <= 0 + // (where m_local/m_old are warp-uniform after block_tile_reduce_sync) + // ================================================================ + // Per-warp PV-skip predicate. Only declared when kEnablePVSkip; + // wrapped in a lambda so the false instantiation contains nothing. + auto compute_warp_skip = [&]() { + if constexpr(kEnablePVSkip) + { + // C3-lite scalar fast-path: pv_threshold == +1e30 sentinel + // disables skip; runtime cost is a single sgpr branch. + if(pv_threshold >= 1e29f) + return false; + // Per-row predicate: warp-AND over rows this warp owns. + int warp_skip_int = 1; + constexpr auto m_spans = decltype(m_local)::get_distributed_spans(); + sweep_tile_span(m_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const float diff = scale_s * (static_cast(m_local[i_idx]) - + static_cast(m_old[i_idx])); + if(!(diff + pv_threshold <= 0.0f)) + warp_skip_int = 0; + }); + // Warp-level AND reduce (wave=64 on gfx942; xor butterfly). + // No LDS, no s_barrier, no cross-warp dependency. + warp_skip_int &= __shfl_xor(warp_skip_int, 32); + warp_skip_int &= __shfl_xor(warp_skip_int, 16); + warp_skip_int &= __shfl_xor(warp_skip_int, 8); + warp_skip_int &= __shfl_xor(warp_skip_int, 4); + warp_skip_int &= __shfl_xor(warp_skip_int, 2); + warp_skip_int &= __shfl_xor(warp_skip_int, 1); + return warp_skip_int != 0; + } + else + { + return false; + } + }; + const bool warp_skip = compute_warp_skip(); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + if constexpr(FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + // exp2 → p_compute and rowsum_p. + // R25 redesign D: when kEnablePVSkip + warp_skip, we zero this + // warp's owned rows of p_compute so the unconditional gemm_1 + // contributes zero to o_acc, and skip the rowsum. + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + if constexpr(kEnablePVSkip) + { + if(warp_skip) + { + p_compute(i_j_idx) = SMPLComputeDataType{0}; + return; + } + } +#if CK_TILE_FMHA_FWD_FAST_EXP2 + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + + // l{j}, Oacc{j}: alpha rescale of l / o always runs. + // When warp_skip, rowsum_p is already 0 for this + // warp's owned rows (p_compute zeroed above), so + // `l += rowsum_p` is a no-op — no extra branch needed. + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + const auto p = [&]() { + if constexpr(std::is_same_v) + return impl::cast_tile_pkrtz_fp16_fp32(p_compute); + else + return cast_tile(p_compute); + }(); + + // STAGE 3, KV gemm — always runs (block-wide LDS dep; per-warp + // skipping has been absorbed by zeroing p_compute rows above). + { + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile(v_dram_window, + number<-1>{}, + bool_constant{}); // load next v_buf + } + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{}); + store_tile(v_lds_window_tmp, v_shuffle_tmp); + } + else + { + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{}); + store_tile(v_lds_window_tmp, v_buf); + } + if constexpr(i_k1 < k1_loops - 1) + move_tile_window(v_dram_window, {0, kK1}); + }); + } + } + i_total_loops++; + if(i_total_loops < num_total_loop) + { + // V load runs unconditionally under redesign D, so no skip + // compensation needed (same offset arithmetic as _vsa.hpp). + move_tile_window(v_dram_window, {0, kN0 * (block_idx - 1)}); + move_tile_window(k_dram_block_window, {kN0 * block_idx, 0}); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + + if constexpr(k1_loops >= 2 && + LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + } + // tail — gemm_1 runs unconditionally under redesign D. + { + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + } + } while(i_total_loops < num_total_loop); + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } +}; + +} // namespace ck_tile