From ab44b835667e29cf5ba844d2a7bdd52fc9f4cc17 Mon Sep 17 00:00:00 2001
From: Gino Lu
Date: Wed, 22 Apr 2026 13:13:37 -0400
Subject: [PATCH] refactor to combine two kernels

---
 example/ck_tile/50_sparse_attn/CMakeLists.txt | 131 +--
 .../codegen/ops/fmha_fwd_jenga.py | 141 +++-
 .../codegen/ops/fmha_fwd_vsa.py | 141 +++-
 .../codegen/ops/sparge_fwd_jenga.py | 799 ------------------
 .../codegen/ops/sparge_fwd_vsa.py | 799 ------------------
 .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 16 +-
 .../50_sparse_attn/jenga_sparge_attention.cpp | 189 -----
 .../50_sparse_attn/jenga_sparge_attention.h | 27 -
 .../50_sparse_attn/sparge_blockmap_inst.cpp | 139 +++
 .../50_sparse_attn/sparge_blockmap_trek.hpp | 13 +
 .../ck_tile/50_sparse_attn/test_sparge.cpp | 432 ++++++++++
 .../test_sparge_jenga_sparse_attn.cpp | 422 ---------
 .../test_sparge_vsa_sparse_attn.cpp | 597 -------------
 ...ock_fmha_pipeline_qr_ks_vs_async_jenga.hpp | 40 +-
 14 files changed, 896 insertions(+), 2990 deletions(-)
 delete mode 100644 example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py
 delete mode 100644 example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py
 delete mode 100644 example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp
 delete mode 100644 example/ck_tile/50_sparse_attn/jenga_sparge_attention.h
 create mode 100644 example/ck_tile/50_sparse_attn/test_sparge.cpp
 delete mode 100644 example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp
 delete mode 100644 example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp

diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt
index f234f631b6..b20a661805 100644
--- a/example/ck_tile/50_sparse_attn/CMakeLists.txt
+++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt
@@ -88,68 +88,6 @@ target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE
 -Wno-float-equal
 )
-# ============================================================================
-# Sparge Jenga (64x128 tile)
-# ============================================================================
-set(SPARGE_JENGA_CODE_GEN_ARGS
- ${CMAKE_CURRENT_LIST_DIR}/generate.py
- --api sparge_fwd_jenga
- --receipt 600
-)
-
-execute_process(
- COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS}
- --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt
- RESULT_VARIABLE ret
-)
-if(ret AND NOT ret EQUAL 0)
- message(FATAL_ERROR "Failed to generate Sparge Jenga kernel list")
-endif()
-
-file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt SPARGE_JENGA_GEN_BLOBS)
-
-add_custom_command(
- OUTPUT ${SPARGE_JENGA_GEN_BLOBS}
- COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS}
- --output_dir ${CMAKE_CURRENT_BINARY_DIR}
- DEPENDS ${CODE_GEN_SCRIPTS}
- COMMENT "Generate CK Tile Sparge Jenga kernels"
-)
-
-message(STATUS "Sparge Jenga kernel files to be generated: ${SPARGE_JENGA_GEN_BLOBS}")
-
-set(SPARGE_JENGA_INSTANCES "tile_sparge_jenga_instances")
-
-add_library(${SPARGE_JENGA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
- ${SPARGE_JENGA_GEN_BLOBS}
- ${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp
-)
-target_include_directories(${SPARGE_JENGA_INSTANCES} PRIVATE
- ${CMAKE_CURRENT_LIST_DIR}
- ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
-)
-set_source_files_properties(${SPARGE_JENGA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
-set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp PROPERTIES LANGUAGE HIP)
-set_property(TARGET ${SPARGE_JENGA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
-
-target_compile_options(${SPARGE_JENGA_INSTANCES} PRIVATE
- -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
- -DCK_TILE_FMHA_FWD_FAST_EXP2
- -Wno-undefined-func-template
- -Wno-float-equal
-)
-
-# Sparge + Jenga Example executable
-set(EXAMPLE_SPARGE_JENGA_SPARSE_ATTN "tile_example_sparge_jenga_sparse_attn")
-message(DEBUG "adding example ${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN}")
-add_executable(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_jenga_sparse_attn.cpp)
-target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARGE_JENGA_INSTANCES})
-target_include_directories(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
-target_compile_options(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE
- -Wno-undefined-func-template
- -Wno-float-equal
-)
-
 # ============================================================================
 # VSA Sparse Attention
 # ============================================================================
@@ -215,55 +153,6 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
 -Wno-float-equal
 )
-# ============================================================================
-# Sparge VSA (64x128 tile)
-# ============================================================================
-set(SPARGE_VSA_CODE_GEN_ARGS
- ${CMAKE_CURRENT_LIST_DIR}/generate.py
- --api sparge_fwd_vsa
- --receipt 600
-)
-
-execute_process(
- COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS}
- --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt
- RESULT_VARIABLE ret
-)
-if(ret AND NOT ret EQUAL 0)
- message(FATAL_ERROR "Failed to generate Sparge VSA kernel list")
-endif()
-
-file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt SPARGE_VSA_GEN_BLOBS)
-
-add_custom_command(
- OUTPUT ${SPARGE_VSA_GEN_BLOBS}
- COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS}
- --output_dir ${CMAKE_CURRENT_BINARY_DIR}
- DEPENDS ${CODE_GEN_SCRIPTS}
- COMMENT "Generate CK Tile Sparge VSA kernels"
-)
-
-message(STATUS "Sparge VSA kernel files to be generated: ${SPARGE_VSA_GEN_BLOBS}")
-
-set(SPARGE_VSA_INSTANCES "tile_sparge_vsa_instances")
-
-add_library(${SPARGE_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
- ${SPARGE_VSA_GEN_BLOBS}
-)
-target_include_directories(${SPARGE_VSA_INSTANCES} PRIVATE
- ${CMAKE_CURRENT_LIST_DIR}
- ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
-)
-set_source_files_properties(${SPARGE_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
-set_property(TARGET ${SPARGE_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
-
-target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE
- -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
- -DCK_TILE_FMHA_FWD_FAST_EXP2
- -Wno-undefined-func-template
- -Wno-float-equal
-)
-
 # ============================================================================
 # Sparge BlockMap GPU Kernel (hand-written instantiation, no codegen)
 # ============================================================================
@@ -289,16 +178,20 @@ target_compile_options(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
 -Wno-float-equal
 )
-# Sparge + VSA Example executable (now links blockmap kernel too)
-set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn")
-message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}")
-add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp)
-target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}
- ${SPARGE_VSA_INSTANCES}
+# ----------------------------------------------------------------------------
+# Build unified Sparge test: combines 
blockmap, Jenga, and VSA attention +# for end-to-end evaluation and timing in a single executable. +# ---------------------------------------------------------------------------- +set(EXAMPLE_SPARGE "tile_example_sparge") +message(DEBUG "adding example ${EXAMPLE_SPARGE}") +add_executable(${EXAMPLE_SPARGE} EXCLUDE_FROM_ALL test_sparge.cpp) +target_link_libraries(${EXAMPLE_SPARGE} + ${SPARSE_ATTN_JENGA_INSTANCES} + ${SPARSE_ATTN_VSA_INSTANCES} ${SPARGE_BLOCKMAP_INSTANCES} ) -target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE +target_include_directories(${EXAMPLE_SPARGE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_compile_options(${EXAMPLE_SPARGE} PRIVATE -Wno-undefined-func-template -Wno-float-equal ) diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index a3d32652a9..1f0a78048d 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -141,6 +141,17 @@ float fmha_jenga_fwd_(const ck_tile::stream_config& s, fmha_jenga constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} + +template<> +void fmha_jenga_fwd_oneshot_(const ck_tile::stream_config& s, fmha_jenga_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} """ FMHA_FWD_API_FILENAME = "fmha_jenga_fwd_api.cpp" @@ -219,6 +230,45 @@ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayo }} """ +FMHA_FWD_ONESHOT_API_FILENAME = "fmha_jenga_fwd_oneshot_api.cpp" +FMHA_FWD_ONESHOT_API = """ +#include "fmha_fwd_trek.hpp" +#include + +void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{ + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + std::cerr << "fmha_jenga_fwd_oneshot: no matching dispatch (dtype=" << t.data_type + << " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v + << " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k + << " mask=" << static_cast(t.mask_type) << ")" << std::endl; +}} +""" + +FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + fmha_jenga_fwd_oneshot_(s, a); + return; + }} +""" + @dataclass class CppConstraint: @@ -274,10 +324,7 @@ class FmhaFwdApiTrait: @property def seqtune(self) -> str: - if self.bm0 == 128: - return "true/*fall back to largest tile*/" # group mode only 
generate spad/skpad == true - else: - return f"a.seqlen_q <= {self.bm0}" + return "true" @property def skcheck(self) -> str: @@ -447,6 +494,67 @@ class FmhaFwdApiPool: per_tr_load += " (void)t ; (void)s ; (void)a;" return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + @property + def oneshot_api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load) + @dataclass class FmhaFwdTileSize: @@ -582,6 +690,27 @@ class KernelComponentFactory: # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128): [ + FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128 + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + ), FmhaFwdTileSize( # fmt: skip 16, 32, @@ -780,7 +909,7 @@ def get_fwd_blobs( for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if tile.F_bm0 != 128 or tile.F_bn0 != 128: + if tile.F_bm0 != 64 or tile.F_bn0 != 128: continue if pipeline.tag != "qr_async": continue @@ -846,6 +975,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + update_file(autogen_dir / FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api) def write_blobs( @@ -865,3 +995,4 @@ def list_blobs( for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") f.write((file_path.parent / GEN_DIR / 
FMHA_FWD_API_FILENAME).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 038738de24..217cfcfe2a 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -141,6 +141,17 @@ float fmha_vsa_fwd_(const ck_tile::stream_config& s, fmha_vsa_fwd constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} + +template<> +void fmha_vsa_fwd_oneshot_(const ck_tile::stream_config& s, fmha_vsa_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} """ FMHA_FWD_API_FILENAME = "fmha_vsa_fwd_api.cpp" @@ -219,6 +230,45 @@ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayo }} """ +FMHA_FWD_ONESHOT_API_FILENAME = "fmha_vsa_fwd_oneshot_api.cpp" +FMHA_FWD_ONESHOT_API = """ +#include "fmha_fwd_trek.hpp" +#include + +void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{ + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + std::cerr << "fmha_vsa_fwd_oneshot: no matching dispatch (dtype=" << t.data_type + << " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v + << " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k + << " mask=" << static_cast(t.mask_type) << ")" << std::endl; +}} +""" + +FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + fmha_vsa_fwd_oneshot_(s, a); + return; + }} +""" + @dataclass class CppConstraint: @@ -274,10 +324,7 @@ class FmhaFwdApiTrait: @property def seqtune(self) -> str: - if self.bm0 == 128: - return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true - else: - return f"a.seqlen_q <= {self.bm0}" + return "true" @property def skcheck(self) -> str: @@ -447,6 +494,67 @@ class FmhaFwdApiPool: per_tr_load += " (void)t ; (void)s ; (void)a;" return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + @property + def oneshot_api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, 
trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load) + @dataclass class FmhaFwdTileSize: @@ -582,6 +690,27 @@ class KernelComponentFactory: # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128): [ + FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128 + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + ), FmhaFwdTileSize( # fmt: skip 16, 32, @@ -780,7 +909,7 @@ def get_fwd_blobs( for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if tile.F_bm0 != 128 or tile.F_bn0 != 128: + if tile.F_bm0 != 64 or tile.F_bn0 != 128: continue if pipeline.tag != "qr_async_vsa": continue @@ -846,6 +975,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + update_file(autogen_dir / FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api) def write_blobs( @@ -865,3 +995,4 @@ def list_blobs( for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py deleted file mode 100644 index 872da2326e..0000000000 --- a/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py +++ /dev/null @@ -1,799 +0,0 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-# SPDX-License-Identifier: MIT -# generate kernel instances to speed up compilation - -import copy -from dataclasses import dataclass, field -import fnmatch -import itertools -import os -import os.path as path -from pathlib import Path -from typing import List, Optional, Tuple - -from codegen.cpp_symbol_map import ( - BOOL_MAP, - FWD_DTYPE_MAP, - LAYOUT_MAP, - MODE_MAP, - PIPELINE_ENUM_MAP, - PIPELINE_MAP, - get_mask_check_map, - get_mask_map, -) - -GEN_DIR = "" - - -def update_file(file_path, content): - """Update the file at file_path with the given content if it differs from the existing content. - - It avoids unnecessary touching of the file which triggers rebuilds - """ - - existing_content = "" - if path.exists(file_path): - with open(file_path, "r") as file: - existing_content = file.read() - if existing_content == content: - return - with open(file_path, "w") as file: - file.write(content) - - -DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16} - -K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} - -FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n -// auto generated by generate.py -#include "ck_tile/ops/fmha/block/variants.hpp" -#include "fmha_fwd_trek.hpp" -#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" -#include "kernel/fmha_fwd_jenga_kernel.hpp" - -""" - -# NOTE: Jenga sparse attention kernel has the following restrictions enforced by static_assert: -# - Group mode: NOT supported (batch mode only) -# - Bias: NOT supported (NO_BIAS only) -# - LSE output: NOT supported (false only) -# - Dropout: NOT supported (false only) -# - Logits soft-cap: NOT supported (false only) -# - FP8 static quantization: NOT supported (NO_SCALE only) -# The template below hardcodes these unsupported features accordingly. 
- -FMHA_FWD_KERNEL_BODY = """ -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; - -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, - ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, - ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, - {F_vlayout}>; - -// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum, -// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - false, // has_logits_soft_cap - NOT supported - ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported - false, // store_lse - NOT supported - false, // has_dropout - NOT supported - false, // has_randval - NOT supported - ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported - {F_occupancy}, - false>; - -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported) - -using fmha_mask_{F_idx} = {F_mask}; - -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaSparseFwdTypeConfig::QDataType, - typename FmhaSparseFwdTypeConfig::KDataType, - typename FmhaSparseFwdTypeConfig::VDataType, - typename FmhaSparseFwdTypeConfig::SaccDataType, - typename FmhaSparseFwdTypeConfig::SMPLComputeDataType, - typename FmhaSparseFwdTypeConfig::BiasDataType, - typename FmhaSparseFwdTypeConfig::RandValOutputDataType, - typename FmhaSparseFwdTypeConfig::LSEDataType, - typename FmhaSparseFwdTypeConfig::PDataType, - typename FmhaSparseFwdTypeConfig::OaccDataType, - typename FmhaSparseFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, - {F_mode}, - fmha_variant_{F_idx}, - fmha_mask_{F_idx}, - {F_trload}, - fmha_trait_{F_idx}>; - -using fmha_pipeline_{F_idx} = {F_pipeline}< - fmha_pipeline_problem_{F_idx}>; - -using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; - -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdJengaKernel; - -using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; - -#include - -template<> -float fmha_jenga_fwd_(const ck_tile::stream_config& s, fmha_jenga_fwd_args a) -{{ - using k_ = fmha_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << "{F_kernel_name}" << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} -""" - -FMHA_FWD_API_FILENAME = "sparge_jenga_fwd_api.cpp" -FMHA_FWD_API = """ -#include - -#include - -namespace {{ -bool get_num_cus(unsigned& num_cus) {{ - int device; - auto status = hipGetDevice(&device); - if(status != hipSuccess) {{ - fprintf(stderr, "failed to get device"); - return false; - }} - - hipDeviceProp_t props{{}}; - status = hipGetDeviceProperties(&props, device); - if(status != hipSuccess) {{ - fprintf(stderr, "failed to get device properties"); - return false; - }} - - num_cus = props.multiProcessorCount; - return true; -}} - -unsigned get_num_thread_blocks(unsigned batch, unsigned 
nheads, unsigned max_seqlen_q, unsigned kM0) {{ - const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; - const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 - - return batch * nheads * num_m_blocks * num_n_blocks; -}} -}} // namespace - -float sparge_jenga_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{ - float r = -1; - - [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate - - unsigned num_cus; - if (!get_num_cus(num_cus)) {{ - return r; - }} - - [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ - return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); - }}; - - const bool has_load_tr = ck_tile::is_load_tr_supported(); - -{F_dispatch} - return r; -}} -""" - -FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ -{F_dtype_case} - }} -""" - -FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} - }} -""" -FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ -{F_inner_dispatch} - }} -""" - -FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && - ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; - return fmha_jenga_fwd_(s, a); - }} -""" - - -@dataclass -class CppConstraint: - bool_expr: str = None - - def __str__(self): - if self.bool_expr is None: - return "true" - else: - return f"{self.bool_expr}" - - def __and__(self, other): - return CppConstraint(f"({str(self)}) && ({str(other)})") - - -@dataclass -class FmhaFwdApiTrait: - pipeline_tag: str - # sync with fmha_fwd_traits<>, to generate fallback calls - hdim: str - dtype: str # data type - mode: str # value from MODE_MAP - bm0: int # tile size along q seqlen (block size) - bn0: int # tile size along qk seqlen - bk0: int # tile size along qk gemm unroll - bn1: int # tile size along v head_dim - bk1: int # tile size along kv gemm unroll - bk0max: int - vlayout: str - logits: str - mask: str - spad: str - skpad: str - dpad: str - dvpad: str - tr_load: str - constraint: CppConstraint - - @property - def name(self) -> str: - return ( - f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" - ) - - @property - def scheck(self) -> str: - if self.mode == "group": - return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true - if self.spad == "t": - return "true" # always support - return "true" - - @property - def seqtune(self) -> str: - return "true" - - @property - def skcheck(self) -> str: - if self.mode == "group": - return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true - if self.skpad == "t": - return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" - return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" - - @property - def dcheck(self) -> str: - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == "t": - return f"a.hdim_q % {vec} == 0" - assert False - - @property - def dvcheck(self) -> str: - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == "t": - return 
f"a.hdim_v % {vec} == 0" - assert False - - -@dataclass -class FmhaFwdPipeline: - tag: str - - F_vlayout: str # row/col - F_spad: str # true/false - F_skpad: str # - F_dpad: str # - F_dvpad: str # - F_logits: str # t/f - F_mask: str # value from MASK_MAP - F_trload: str # true/false - F_constraint: CppConstraint = field(default_factory=CppConstraint) - - @property - def name(self) -> str: - def pad_name() -> str: - n = "" - if self.F_spad == "t": - n += "s" - if self.F_skpad == "t": - n += "sk" - if self.F_dpad == "t": - n += "d" - if self.F_dvpad == "t": - n += "dv" - if n != "": - n = "p" + n - return n - - pn = pad_name() - n = f"{self.tag}_v{self.F_vlayout[0]}" - if pn != "": - n += f"_{pn}" - else: - n += "_npad" - - if self.F_logits == "t": - n += "_logits" - else: - n += "_nlogits" - - n += "_nbias" - - if self.F_mask[0:2] == "s_": - if self.F_mask == "s_mask": - n += "_mask" - else: - n += "_nmask" - else: - if self.F_mask != "no": - n += f"_m{self.F_mask[0]}" - else: - n += "_nmask" - - n += "_nskip" - - n += "_nsquant" - - if self.F_trload == "t": - n += "_trload" - else: - n += "_ntrload" - - return n - - -class FmhaFwdApiPool: - def __init__(self, mask_impl): - self.pool = dict() - self.mask_impl = mask_impl - - def register_traits(self, trait: FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - hdim = trait.hdim, trait.bn1 - if hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][hdim] = list() - - self.pool[trait.dtype][hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - tr_load_cond_map = {"t": "has_load_tr", "f": "true"} - - per_tr_load = str() - for tr_load in ["t", "f"]: - per_dtypes = str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case = str() - for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits = [ - t - for t in self.pool[dtype][(hdim, hdim_v)] - if tr_load == t.tr_load - ] - inners = str() - for k, trait in enumerate(traits): - if_k = "if" if k == 0 else "else if" - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( - F_if=if_k, - F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], - # F_logits removed - hardcoded to false (NOT supported) - F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], - F_trload=BOOL_MAP[trait.tr_load], - F_scheck=trait.scheck, - F_seqtune=trait.seqtune, - F_skcheck=trait.skcheck, - F_dcheck=trait.dcheck, - F_dvcheck=trait.dvcheck, - F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], - F_skpad=BOOL_MAP[trait.skpad], - F_dpad=BOOL_MAP[trait.dpad], - F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, - F_bn0=trait.bn0, - F_bk0=trait.bk0, - F_bn1=trait.bn1, - F_bk1=trait.bk1, - F_bk0max=trait.bk0max, - F_hdim=hdim, - F_dtype=FWD_DTYPE_MAP[dtype], - ) - if_j = "if" if j == 0 else "else if" - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( - F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners - ) - if_i = "if" if i == 0 else "else if" - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( - F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case - ) - per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( - F_if="if", - F_trload_cond=tr_load_cond_map[tr_load], - F_dtype_case=per_dtypes, - ) - if not per_tr_load: - # empty string we add some ignore to suppress warning in api - per_tr_load += " (void)t ; (void)s ; (void)a;" - return 
FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) - - -@dataclass -class FmhaFwdTileSize: - F_bm0: int # tile size along q seqlen (block size) - F_bn0: int # tile size along k seqlen - F_bk0: int # tile size along qk gemm unroll - F_bn1: int # tile size along v head_dim - F_bk1: int # tile size along kv gemm unroll - F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm0: int # number of warps for gemm0 along q seqlen - F_rn0: int # number of warps for gemm0 along k seqlen - F_rk0: int # number of warps for gemm0 along head dim q (not used) - F_rm1: int # number of warps for gemm1 along q seqlen - F_rn1: int # number of warps for gemm1 along head dim v - F_rk1: int # number of warps for gemm1 along k seqlen (not used) - F_wm0: int # gemm0 warp size along m - F_wn0: int # gemm0 warp size along n - F_wk0: int # gemm0 warp size along k - F_wm1: int # gemm1 warp size along m - F_wn1: int # gemm1 warp size along n - F_wk1: int # gemm1 warp size along k - F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - F_constraint: CppConstraint = field(default_factory=CppConstraint) - - @property - def name(self) -> str: - return ( - f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" - + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" - + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" - + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") - ) - - -@dataclass -class FmhaFwdKernel: - F_idx: int # this is not a tunable, but a counter to differentiate symbol - F_hdim: int # hdim - F_dtype: str # data type - F_mode: str # value from MODE_MAP - F_tile: FmhaFwdTileSize - F_pipeline: FmhaFwdPipeline - mask_impl: str - - @property - def template(self) -> str: - # kernel_body removed - unused - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( - F_idx=self.F_idx, - F_hdim=self.F_hdim, - F_dtype=FWD_DTYPE_MAP[self.F_dtype], - F_bm0=self.F_tile.F_bm0, - F_bn0=self.F_tile.F_bn0, - F_bk0=self.F_tile.F_bk0, - F_bn1=self.F_tile.F_bn1, - F_bk1=self.F_tile.F_bk1, - F_bk0max=self.F_tile.F_bk0max, - F_rm0=self.F_tile.F_rm0, - F_rn0=self.F_tile.F_rn0, - F_rk0=self.F_tile.F_rk0, - F_rm1=self.F_tile.F_rm1, - F_rn1=self.F_tile.F_rn1, - F_rk1=self.F_tile.F_rk1, - F_wm0=self.F_tile.F_wm0, - F_wn0=self.F_tile.F_wn0, - F_wk0=self.F_tile.F_wk0, - F_wm1=self.F_tile.F_wm1, - F_wn1=self.F_tile.F_wn1, - F_wk1=self.F_tile.F_wk1, - F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad=BOOL_MAP[self.F_pipeline.F_spad], - F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], - # F_logits removed - hardcoded to false in template (NOT supported) - F_occupancy=self.F_tile.F_occupancy, - F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode=MODE_MAP[self.F_mode], - F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], - F_trload=BOOL_MAP[self.F_pipeline.F_trload], - F_kernel_name=self.name, - ) - - @property - def name(self) -> str: - # TODO: we don't encode idx here - return ( - f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" - + self.F_tile.name - + "_" - + self.F_pipeline.name - ) - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaFwdApiTrait: - return 
FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - logits=self.F_pipeline.F_logits, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - tr_load=self.F_pipeline.F_trload, - constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, - ) - - -class KernelComponentFactory: - # TODO: design a more practical way to do it - # this is current supported tile size per hdim - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if dtype == "fp16" or dtype == "bf16": - return { - # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (128, 128): [ - FmhaFwdTileSize( - 64, - 128, - 64, - 128, - 64, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - ), - ], - # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - } - else: - return None - - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # NOTE: logits soft-cap is NOT supported by Jenga sparse attention (enforced by static_assert) - pipelines = [] - if dtype in ["fp16", "bf16"]: - for logits, mask in itertools.product( - ["f"], # logits soft-cap NOT supported, always false - get_mask_map(mask_impl).keys(), - ): - if hdim == 256 and hdim_v == 256: - # jenga fmha only supports dim <= 192 for now. 
- continue - pipelines.append( - FmhaFwdPipeline( # fmt: skip - "qr_async", - "row", - "t", - "f", - "t", - "t", - logits, - mask, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( # fmt: skip - "qr_async", - "row", - "t", - "t", - "t", - "t", - logits, - mask, - "f", - ) - ) - else: - assert False - return pipelines - - -class CustomFactory(KernelComponentFactory): - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) - if dtype == "fp16" or dtype == "bf16": - if (128, 128) in result.keys(): - result[(128, 128)].insert( - 0, - FmhaFwdTileSize( - 64, - 128, - 64, - 128, - 64, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - CppConstraint( - "get_num_blocks(128) < num_cus * min_cu_util_rate" - ), - ), - ) - return result - - -def get_fwd_blobs( - kernel_filter: Optional[str], receipt, optdim_list, mask_impl -) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: - gen = list() - api_pool = FmhaFwdApiPool(mask_impl) - - factory = ( - CustomFactory - if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" - else KernelComponentFactory - ) - - # Only generate fp16/bf16 kernels for now. - # NOTE: Jenga sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert) - for dtype in ["fp16", "bf16"]: - d = factory.get_hdim_tile_size_dict(dtype) - if d is None: - continue - for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): - for tile, pipeline in itertools.product( - tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) - ): - if pipeline.tag != "qr_async": - continue - k = FmhaFwdKernel( - F_idx=2, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl, - ) - if kernel_filter != "": - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= mode == "batch" - cond &= pipeline.F_logits == "f" - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ["fp16", "bf16"] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # Aiter(mha_varlen_fwd) integration - elif receipt == 200: - cond = dtype in ["fp16", "bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # aiter::mha_fwd C++ api integration - elif receipt == 600: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - - api_pool.register_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - - -def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - update_file(autogen_dir / kernel.filename, kernel.template) - - -def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: - update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) - - -def write_blobs( - output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl -) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) - for kernel in kernels: - write_single_fwd_kernel(kernel, output_dir) - write_fwd_api(api_pool, output_dir) - - 
-def list_blobs( - file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl -) -> None: - with file_path.open("a") as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) - for kernel in kernels: - f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") - f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py deleted file mode 100644 index c9a389df3f..0000000000 --- a/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py +++ /dev/null @@ -1,799 +0,0 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT -# generate kernel instances to speed up compilation - -import copy -from dataclasses import dataclass, field -import fnmatch -import itertools -import os -import os.path as path -from pathlib import Path -from typing import List, Optional, Tuple - -from codegen.cpp_symbol_map import ( - BOOL_MAP, - FWD_DTYPE_MAP, - LAYOUT_MAP, - MODE_MAP, - PIPELINE_ENUM_MAP, - PIPELINE_MAP, - get_mask_check_map, - get_mask_map, -) - -GEN_DIR = "" - - -def update_file(file_path, content): - """Update the file at file_path with the given content if it differs from the existing content. - - It avoids unnecessary touching of the file which triggers rebuilds - """ - - existing_content = "" - if path.exists(file_path): - with open(file_path, "r") as file: - existing_content = file.read() - if existing_content == content: - return - with open(file_path, "w") as file: - file.write(content) - - -DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16} - -K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} - -FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n -// auto generated by generate.py -#include "ck_tile/ops/fmha/block/variants.hpp" -#include "fmha_fwd_trek.hpp" -#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" -#include "kernel/fmha_fwd_vsa_kernel.hpp" - -""" - -# NOTE: VSA sparse attention kernel has the following restrictions enforced by static_assert: -# - Group mode: NOT supported (batch mode only) -# - Bias: NOT supported (NO_BIAS only) -# - LSE output: NOT supported (false only) -# - Dropout: NOT supported (false only) -# - Logits soft-cap: NOT supported (false only) -# - FP8 static quantization: NOT supported (NO_SCALE only) -# The template below hardcodes these unsupported features accordingly. 
- -FMHA_FWD_KERNEL_BODY = """ -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; - -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, - ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, - ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, - {F_vlayout}>; - -// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum, -// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - false, // has_logits_soft_cap - NOT supported - ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported - false, // store_lse - NOT supported - false, // has_dropout - NOT supported - false, // has_randval - NOT supported - ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported - {F_occupancy}, - false>; - -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported) - -using fmha_mask_{F_idx} = {F_mask}; - -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaSparseFwdTypeConfig::QDataType, - typename FmhaSparseFwdTypeConfig::KDataType, - typename FmhaSparseFwdTypeConfig::VDataType, - typename FmhaSparseFwdTypeConfig::SaccDataType, - typename FmhaSparseFwdTypeConfig::SMPLComputeDataType, - typename FmhaSparseFwdTypeConfig::BiasDataType, - typename FmhaSparseFwdTypeConfig::RandValOutputDataType, - typename FmhaSparseFwdTypeConfig::LSEDataType, - typename FmhaSparseFwdTypeConfig::PDataType, - typename FmhaSparseFwdTypeConfig::OaccDataType, - typename FmhaSparseFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, - {F_mode}, - fmha_variant_{F_idx}, - fmha_mask_{F_idx}, - {F_trload}, - fmha_trait_{F_idx}>; - -using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA< - fmha_pipeline_problem_{F_idx}>; - -using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; - -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdVSAKernel; - -using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; - -#include - -template<> -float fmha_vsa_fwd_(const ck_tile::stream_config& s, fmha_vsa_fwd_args a) -{{ - using k_ = fmha_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << "{F_kernel_name}" << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} -""" - -FMHA_FWD_API_FILENAME = "sparge_vsa_fwd_api.cpp" -FMHA_FWD_API = """ -#include - -#include - -namespace {{ -bool get_num_cus(unsigned& num_cus) {{ - int device; - auto status = hipGetDevice(&device); - if(status != hipSuccess) {{ - fprintf(stderr, "failed to get device"); - return false; - }} - - hipDeviceProp_t props{{}}; - status = hipGetDeviceProperties(&props, device); - if(status != hipSuccess) {{ - fprintf(stderr, "failed to get device properties"); - return false; - }} - - num_cus = props.multiProcessorCount; - return true; -}} - -unsigned get_num_thread_blocks(unsigned 
batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ - const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; - const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 - - return batch * nheads * num_m_blocks * num_n_blocks; -}} -}} // namespace - -float sparge_vsa_fwd(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{ - float r = -1; - - [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate - - unsigned num_cus; - if (!get_num_cus(num_cus)) {{ - return r; - }} - - [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ - return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); - }}; - - const bool has_load_tr = ck_tile::is_load_tr_supported(); - -{F_dispatch} - return r; -}} -""" - -FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ -{F_dtype_case} - }} -""" - -FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} - }} -""" -FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ -{F_inner_dispatch} - }} -""" - -FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && - ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; - return fmha_vsa_fwd_(s, a); - }} -""" - - -@dataclass -class CppConstraint: - bool_expr: str = None - - def __str__(self): - if self.bool_expr is None: - return "true" - else: - return f"{self.bool_expr}" - - def __and__(self, other): - return CppConstraint(f"({str(self)}) && ({str(other)})") - - -@dataclass -class FmhaFwdApiTrait: - pipeline_tag: str - # sync with fmha_fwd_traits<>, to generate fallback calls - hdim: str - dtype: str # data type - mode: str # value from MODE_MAP - bm0: int # tile size along q seqlen (block size) - bn0: int # tile size along qk seqlen - bk0: int # tile size along qk gemm unroll - bn1: int # tile size along v head_dim - bk1: int # tile size along kv gemm unroll - bk0max: int - vlayout: str - logits: str - mask: str - spad: str - skpad: str - dpad: str - dvpad: str - tr_load: str - constraint: CppConstraint - - @property - def name(self) -> str: - return ( - f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" - ) - - @property - def scheck(self) -> str: - if self.mode == "group": - return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true - if self.spad == "t": - return "true" # always support - return "true" - - @property - def seqtune(self) -> str: - return "true" - - @property - def skcheck(self) -> str: - if self.mode == "group": - return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true - if self.skpad == "t": - return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" - return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" - - @property - def dcheck(self) -> str: - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == "t": - return f"a.hdim_q % {vec} == 0" - assert False - - @property - def dvcheck(self) -> str: - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == "t": - return 
f"a.hdim_v % {vec} == 0" - assert False - - -@dataclass -class FmhaFwdPipeline: - tag: str - - F_vlayout: str # row/col - F_spad: str # true/false - F_skpad: str # - F_dpad: str # - F_dvpad: str # - F_logits: str # t/f - F_mask: str # value from MASK_MAP - F_trload: str # true/false - F_constraint: CppConstraint = field(default_factory=CppConstraint) - - @property - def name(self) -> str: - def pad_name() -> str: - n = "" - if self.F_spad == "t": - n += "s" - if self.F_skpad == "t": - n += "sk" - if self.F_dpad == "t": - n += "d" - if self.F_dvpad == "t": - n += "dv" - if n != "": - n = "p" + n - return n - - pn = pad_name() - n = f"{self.tag}_v{self.F_vlayout[0]}" - if pn != "": - n += f"_{pn}" - else: - n += "_npad" - - if self.F_logits == "t": - n += "_logits" - else: - n += "_nlogits" - - n += "_nbias" - - if self.F_mask[0:2] == "s_": - if self.F_mask == "s_mask": - n += "_mask" - else: - n += "_nmask" - else: - if self.F_mask != "no": - n += f"_m{self.F_mask[0]}" - else: - n += "_nmask" - - n += "_nskip" - - n += "_nsquant" - - if self.F_trload == "t": - n += "_trload" - else: - n += "_ntrload" - - return n - - -class FmhaFwdApiPool: - def __init__(self, mask_impl): - self.pool = dict() - self.mask_impl = mask_impl - - def register_traits(self, trait: FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - hdim = trait.hdim, trait.bn1 - if hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][hdim] = list() - - self.pool[trait.dtype][hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - tr_load_cond_map = {"t": "has_load_tr", "f": "true"} - - per_tr_load = str() - for tr_load in ["t", "f"]: - per_dtypes = str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case = str() - for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits = [ - t - for t in self.pool[dtype][(hdim, hdim_v)] - if tr_load == t.tr_load - ] - inners = str() - for k, trait in enumerate(traits): - if_k = "if" if k == 0 else "else if" - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( - F_if=if_k, - F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], - # F_logits removed - hardcoded to false (NOT supported) - F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], - F_trload=BOOL_MAP[trait.tr_load], - F_scheck=trait.scheck, - F_seqtune=trait.seqtune, - F_skcheck=trait.skcheck, - F_dcheck=trait.dcheck, - F_dvcheck=trait.dvcheck, - F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], - F_skpad=BOOL_MAP[trait.skpad], - F_dpad=BOOL_MAP[trait.dpad], - F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, - F_bn0=trait.bn0, - F_bk0=trait.bk0, - F_bn1=trait.bn1, - F_bk1=trait.bk1, - F_bk0max=trait.bk0max, - F_hdim=hdim, - F_dtype=FWD_DTYPE_MAP[dtype], - ) - if_j = "if" if j == 0 else "else if" - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( - F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners - ) - if_i = "if" if i == 0 else "else if" - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( - F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case - ) - per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( - F_if="if", - F_trload_cond=tr_load_cond_map[tr_load], - F_dtype_case=per_dtypes, - ) - if not per_tr_load: - # empty string we add some ignore to suppress warning in api - per_tr_load += " (void)t ; (void)s ; (void)a;" - return 
FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) - - -@dataclass -class FmhaFwdTileSize: - F_bm0: int # tile size along q seqlen (block size) - F_bn0: int # tile size along k seqlen - F_bk0: int # tile size along qk gemm unroll - F_bn1: int # tile size along v head_dim - F_bk1: int # tile size along kv gemm unroll - F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm0: int # number of warps for gemm0 along q seqlen - F_rn0: int # number of warps for gemm0 along k seqlen - F_rk0: int # number of warps for gemm0 along head dim q (not used) - F_rm1: int # number of warps for gemm1 along q seqlen - F_rn1: int # number of warps for gemm1 along head dim v - F_rk1: int # number of warps for gemm1 along k seqlen (not used) - F_wm0: int # gemm0 warp size along m - F_wn0: int # gemm0 warp size along n - F_wk0: int # gemm0 warp size along k - F_wm1: int # gemm1 warp size along m - F_wn1: int # gemm1 warp size along n - F_wk1: int # gemm1 warp size along k - F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - F_constraint: CppConstraint = field(default_factory=CppConstraint) - - @property - def name(self) -> str: - return ( - f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" - + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" - + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" - + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") - ) - - -@dataclass -class FmhaFwdKernel: - F_idx: int # this is not a tunable, but a counter to differentiate symbol - F_hdim: int # hdim - F_dtype: str # data type - F_mode: str # value from MODE_MAP - F_tile: FmhaFwdTileSize - F_pipeline: FmhaFwdPipeline - mask_impl: str - - @property - def template(self) -> str: - # kernel_body removed - unused - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( - F_idx=self.F_idx, - F_hdim=self.F_hdim, - F_dtype=FWD_DTYPE_MAP[self.F_dtype], - F_bm0=self.F_tile.F_bm0, - F_bn0=self.F_tile.F_bn0, - F_bk0=self.F_tile.F_bk0, - F_bn1=self.F_tile.F_bn1, - F_bk1=self.F_tile.F_bk1, - F_bk0max=self.F_tile.F_bk0max, - F_rm0=self.F_tile.F_rm0, - F_rn0=self.F_tile.F_rn0, - F_rk0=self.F_tile.F_rk0, - F_rm1=self.F_tile.F_rm1, - F_rn1=self.F_tile.F_rn1, - F_rk1=self.F_tile.F_rk1, - F_wm0=self.F_tile.F_wm0, - F_wn0=self.F_tile.F_wn0, - F_wk0=self.F_tile.F_wk0, - F_wm1=self.F_tile.F_wm1, - F_wn1=self.F_tile.F_wn1, - F_wk1=self.F_tile.F_wk1, - F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad=BOOL_MAP[self.F_pipeline.F_spad], - F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], - # F_logits removed - hardcoded to false in template (NOT supported) - F_occupancy=self.F_tile.F_occupancy, - F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode=MODE_MAP[self.F_mode], - F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], - F_trload=BOOL_MAP[self.F_pipeline.F_trload], - F_kernel_name=self.name, - ) - - @property - def name(self) -> str: - # TODO: we don't encode idx here - return ( - f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" - + self.F_tile.name - + "_" - + self.F_pipeline.name - ) - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaFwdApiTrait: - return 
FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - logits=self.F_pipeline.F_logits, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - tr_load=self.F_pipeline.F_trload, - constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, - ) - - -class KernelComponentFactory: - # TODO: design a more practical way to do it - # this is current supported tile size per hdim - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if dtype == "fp16" or dtype == "bf16": - return { - # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (128, 128): [ - FmhaFwdTileSize( - 64, - 128, - 64, - 128, - 64, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - ), - ], - # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - } - else: - return None - - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # NOTE: logits soft-cap is NOT supported by VSA sparse attention (enforced by static_assert) - pipelines = [] - if dtype in ["fp16", "bf16"]: - for logits, mask in itertools.product( - ["f"], # logits soft-cap NOT supported, always false - get_mask_map(mask_impl).keys(), - ): - if hdim == 256 and hdim_v == 256: - # vsa fmha only supports dim <= 192 for now. 
- continue - pipelines.append( - FmhaFwdPipeline( - "qr_async_vsa", - "row", - "t", - "f", - "t", - "t", - logits, - mask, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr_async_vsa", - "row", - "t", - "t", - "t", - "t", - logits, - mask, - "f", - ) - ) - else: - assert False - return pipelines - - -class CustomFactory(KernelComponentFactory): - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) - if dtype == "fp16" or dtype == "bf16": - if (128, 128) in result.keys(): - result[(128, 128)].insert( - 0, - FmhaFwdTileSize( - 64, - 128, - 64, - 128, - 64, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 16, - 16, - 16, - 16, - 16, - 16, - -1, - CppConstraint( - "get_num_blocks(128) < num_cus * min_cu_util_rate" - ), - ), - ) - return result - - -def get_fwd_blobs( - kernel_filter: Optional[str], receipt, optdim_list, mask_impl -) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: - gen = list() - api_pool = FmhaFwdApiPool(mask_impl) - - factory = ( - CustomFactory - if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" - else KernelComponentFactory - ) - - # Only generate fp16/bf16 kernels for now. - # NOTE: VSA sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert) - for dtype in ["fp16", "bf16"]: - d = factory.get_hdim_tile_size_dict(dtype) - if d is None: - continue - for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): - for tile, pipeline in itertools.product( - tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) - ): - if pipeline.tag != "qr_async_vsa": - continue - k = FmhaFwdKernel( - F_idx=1, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl, - ) - if kernel_filter != "": - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= mode == "batch" - cond &= pipeline.F_logits == "f" - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ["fp16", "bf16"] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # Aiter(mha_varlen_fwd) integration - elif receipt == 200: - cond = dtype in ["fp16", "bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # aiter::mha_fwd C++ api integration - elif receipt == 600: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - - api_pool.register_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - - -def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - update_file(autogen_dir / kernel.filename, kernel.template) - - -def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: - update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) - - -def write_blobs( - output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl -) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) - for kernel in kernels: - write_single_fwd_kernel(kernel, output_dir) - write_fwd_api(api_pool, output_dir) - - -def list_blobs( 
- file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl -) -> None: - with file_path.open("a") as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) - for kernel in kernels: - f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") - f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 25e3513d2f..350d1803f6 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -277,13 +277,13 @@ struct fmha_jenga_fwd_traits float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); -// sparge jenga -float sparge_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); - template float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args); -float fmha_jenga_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&); +template +void fmha_jenga_fwd_oneshot_(const ck_tile::stream_config&, fmha_jenga_fwd_args); + +void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); // VSA uses the same traits structure as Jenga; aliases for clarity template float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args); -float fmha_vsa_fwd(fmha_vsa_fwd_args, const ck_tile::stream_config&); +template +void fmha_vsa_fwd_oneshot_(const ck_tile::stream_config&, fmha_vsa_fwd_args); + +void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp deleted file mode 100644 index 88f3e08204..0000000000 --- a/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
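The *_oneshot entry points declared in fmha_fwd_trek.hpp exist because timing now happens at the combined level: the block-map kernel and the attention kernel run back-to-back inside one timed region, so each stage has to be launchable exactly once, with no warmup/repeat loop of its own. The sketch below shows the shape of that pattern on the host side only; the function names and the plain std::chrono timing are illustrative stand-ins, not the ck_tile::launch_kernel machinery used in sparge_blockmap_inst.cpp, where the warmup/repeat counts instead come from the ck_tile::stream_config (cold_niters_ / nrepeat_) passed through by the test.

#include <chrono>

// Hypothetical stand-ins for the two "oneshot" launchers: each enqueues its
// kernel exactly once and does no timing of its own.
void blockmap_oneshot()  { /* enqueue block-map kernel */ }
void attention_oneshot() { /* enqueue sparse-attention kernel */ }

// Time both stages together, the way sparge_jenga_fwd / sparge_vsa_fwd_combined
// wrap the two oneshot launches in a single measurement.
template <typename Stage1, typename Stage2>
float time_combined_ms(Stage1&& stage1, Stage2&& stage2, int repeat = 20)
{
    const auto t0 = std::chrono::high_resolution_clock::now();
    for(int i = 0; i < repeat; ++i)
    {
        stage1(); // block-map generation
        stage2(); // sparse attention consuming the freshly written map
    }
    const auto t1 = std::chrono::high_resolution_clock::now();
    return std::chrono::duration<float, std::milli>(t1 - t0).count() / repeat;
}

// usage: float avg_ms = time_combined_ms(blockmap_oneshot, attention_oneshot);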
-// SPDX-License-Identifier: MIT -#include "jenga_sparge_attention.h" -#include "fmha_fwd_trek.hpp" -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" -#include "ck_tile/host/device_memory.hpp" -#include - -template -ck_tile::HostTensor -jenga_sparge_attention(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - const ck_tile::HostTensor& TV, - const ck_tile::HostTensor& Tblock_relation_onehot, - ck_tile::HostTensor& Y, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - bool i_perm, - bool o_perm, - int max_seqlen_q, - int max_seqlen_k, - int log_level) -{ - static_assert(std::is_same_v || - std::is_same_v, - "Jenga sparse attention supports fp16/bf16 only."); - std::string data_type = "fp16"; - if constexpr(std::is_same_v) - { - data_type = "bf16"; - } - - if(max_seqlen_q == 0) - max_seqlen_q = seqlen_q; - if(max_seqlen_k == 0) - max_seqlen_k = seqlen_k; - bool is_v_rowmajor = true; - float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); - std::string msk_str = "0"; - mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); - - const ck_tile::index_t shape_seqlen_q = seqlen_q; - const ck_tile::index_t shape_seqlen_k = seqlen_k; - - ck_tile::stream_config stream_config{nullptr, - false, // time_kernel - log_level, - 0, - 1, - false}; - - ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); - ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes()); - ck_tile::DeviceMem block_relation_buf(Tblock_relation_onehot.get_element_space_size_in_bytes()); - ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes()); - - q_buf.ToDevice(TQ.data()); - k_buf.ToDevice(TK.data()); - v_buf.ToDevice(TV.data()); - block_relation_buf.ToDevice(Tblock_relation_onehot.data()); - - const auto init_args = [&](auto& args) { - assert(nhead % nhead_k == 0); - const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_v = [&]() { - if(is_v_rowmajor) - return i_perm ? hdim_v : nhead_k * hdim_v; - else - return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); - }(); - const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); - const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; - const ck_tile::index_t nhead_stride_v = [&]() { - if(is_v_rowmajor) - return i_perm ? shape_seqlen_k * hdim_v : hdim_v; - else - return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; - }(); - const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; - const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); - - args.q_ptr = q_buf.GetDeviceBuffer(); - args.k_ptr = k_buf.GetDeviceBuffer(); - args.v_ptr = v_buf.GetDeviceBuffer(); - args.block_relation_onehot_ptr = block_relation_buf.GetDeviceBuffer(); - - args.batch = batch; - args.seqlen_q = shape_seqlen_q; - args.hdim_q = hdim_q; - args.hdim_v = hdim_v; - args.nhead_q = nhead; - args.nhead_k = nhead_k; - - args.stride_q = stride_q; - args.stride_k = stride_k; - args.stride_v = stride_v; - args.nhead_stride_q = nhead_stride_q; - args.nhead_stride_k = nhead_stride_k; - args.nhead_stride_v = nhead_stride_v; - args.batch_stride_q = batch_stride_q; - args.batch_stride_k = batch_stride_k; - args.batch_stride_v = batch_stride_v; - - args.o_ptr = o_buf.GetDeviceBuffer(); - - args.seqlen_k = shape_seqlen_k; - args.max_seqlen_q = max_seqlen_q; - - args.scale_s = scale_s; - - args.stride_o = stride_o; - args.nhead_stride_o = nhead_stride_o; - args.batch_stride_o = batch_stride_o; - - args.window_size_left = mask.left; - args.window_size_right = mask.right; - args.mask_type = static_cast(mask.type); - }; - - const auto init_traits = [&](auto& traits) { - traits.hdim_q = hdim_q; - traits.hdim_v = hdim_v; - traits.data_type = data_type; - traits.is_v_rowmajor = is_v_rowmajor; - traits.mask_type = mask.type; - }; - - fmha_jenga_fwd_traits fmha_traits; - init_traits(fmha_traits); - - fmha_jenga_fwd_args args; - init_args(args); - - sparge_jenga_fwd(fmha_traits, args, stream_config); - - o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes()); - - return Y; -} - -template ck_tile::HostTensor -jenga_sparge_attention(const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - int, - bool, - bool, - int, - int, - int); - -template ck_tile::HostTensor -jenga_sparge_attention(const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - int, - bool, - bool, - int, - int, - int); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparge_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.h deleted file mode 100644 index 6259fcc73c..0000000000 --- a/example/ck_tile/50_sparse_attn/jenga_sparge_attention.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT -#pragma once -#include -#include -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" - -template -ck_tile::HostTensor -jenga_sparge_attention(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - const ck_tile::HostTensor& TV, - const ck_tile::HostTensor& Tblock_relation_onehot, - ck_tile::HostTensor& Y, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - bool i_perm, - bool o_perm, - int max_seqlen_q, - int max_seqlen_k, - int log_level = 0); diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp index fbd18b9ff2..a2df5bac56 100644 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp @@ -61,6 +61,57 @@ using bmap_fp16_problem = ck_tile::BlockFmhaPipelineProblem; using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel; +// ============================================================================ +// bf16: D=128, kM0=64, kN0=128 +// ============================================================================ + +using bmap_bf16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>; + +using bmap_bf16_shape = + ck_tile::TileFmhaShape, + ck_tile::sequence<16, 16, 16>, + ck_tile::sequence<4, 1, 1>, + ck_tile::sequence<16, 16, 16>, + true>; + +using bmap_bf16_trait = ck_tile::TileFmhaTraits; + +using bmap_bf16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; +using bmap_bf16_mask = ck_tile::GenericAttentionMask; + +using bmap_bf16_problem = ck_tile::BlockFmhaPipelineProblem; + +using bmap_bf16_pipeline = ck_tile::SpargeBlockMapPipeline; +using bmap_bf16_kernel = ck_tile::SpargeBlockMapKernel; + // ============================================================================ // Dispatch // ============================================================================ @@ -81,8 +132,96 @@ float sparge_blockmap_fwd(sparge_blockmap_traits traits, s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); } + if(traits.data_type == "bf16" && traits.hdim_q == 128) + { + using k_ = bmap_bf16_kernel; + if(s.log_level_ > 0) + std::cout << ", sparge_blockmap_bf16_d128" << std::flush; + auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids(args); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); + } + if(s.log_level_ > 0) std::cerr << "sparge_blockmap_fwd: unsupported config (data_type=" << traits.data_type << ", hdim_q=" << traits.hdim_q << ")" << std::endl; return -1.f; } + +// ============================================================================ +// Oneshot version: launches kernel without timing wrapper +// ============================================================================ + +void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& s) +{ + if(traits.data_type == "fp16" && traits.hdim_q == 128) + { + using k_ = bmap_fp16_kernel; + auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids(args); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); + return; + } + + if(traits.data_type == "bf16" && traits.hdim_q == 128) + { + using k_ = bmap_bf16_kernel; + 
auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids(args); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); + return; + } + + std::cerr << "sparge_blockmap_fwd_oneshot: unsupported config (data_type=" << traits.data_type + << ", hdim_q=" << traits.hdim_q << ")" << std::endl; +} + +// ============================================================================ +// Combined functions: blockmap + attention timed together via launch_kernel +// ============================================================================ + +float sparge_jenga_fwd(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a, + fmha_jenga_fwd_traits attn_t, fmha_jenga_fwd_args attn_a, + const ck_tile::stream_config& s) +{ + if(s.log_level_ > 0) + std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", fmha_jenga_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q + << std::flush; + + return ck_tile::launch_kernel( + s, + [=](const ck_tile::stream_config& s_) { + sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, s_); + }, + [=](const ck_tile::stream_config& s_) { + fmha_jenga_fwd_oneshot(attn_t, attn_a, s_); + }); +} + +float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a, + fmha_vsa_fwd_traits attn_t, fmha_vsa_fwd_args attn_a, + const ck_tile::stream_config& s) +{ + if(s.log_level_ > 0) + std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", fmha_vsa_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q + << std::flush; + + return ck_tile::launch_kernel( + s, + [=](const ck_tile::stream_config& s_) { + sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, s_); + }, + [=](const ck_tile::stream_config& s_) { + fmha_vsa_fwd_oneshot(attn_t, attn_a, s_); + }); +} diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp index 1e7e33248a..6eaeb9ea77 100644 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp @@ -91,3 +91,16 @@ auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args) float sparge_blockmap_fwd(sparge_blockmap_traits traits, sparge_blockmap_args args, const ck_tile::stream_config& stream_config); + +void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& stream_config); + +// Combined functions: blockmap + attention with unified timing +float sparge_jenga_fwd(sparge_blockmap_traits, sparge_blockmap_args, + fmha_jenga_fwd_traits, fmha_jenga_fwd_args, + const ck_tile::stream_config&); + +float sparge_vsa_fwd_combined(sparge_blockmap_traits, sparge_blockmap_args, + fmha_vsa_fwd_traits, fmha_vsa_fwd_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/test_sparge.cpp b/example/ck_tile/50_sparse_attn/test_sparge.cpp new file mode 100644 index 0000000000..7c30a10b06 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/test_sparge.cpp @@ -0,0 +1,432 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// Unified test for Sparge pipeline: blockmap generation + sparse attention (Jenga/VSA). 
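Before the unified test below: for the VSA path the block-map kernel writes, per (batch, head, q-block), not only the one-hot block map but also a compacted lookup table and a valid-block count (the lut_ptr / valid_block_num_ptr arguments), while the Jenga path needs only the one-hot map and leaves lut_ptr null. The exact device-side encoding is owned by SpargeBlockMapKernel (the earlier standalone test calls it a delta LUT), so the sketch below only illustrates the simplest possible compaction, absolute K-block indices plus a count; treat the layout as an assumption, not the kernel's actual format.

#include <cstdint>
#include <vector>

// Illustrative host-side compaction of one (batch, head, q-block) row of the
// one-hot block map into a LUT row plus a valid-block count. Assumes absolute
// K-block indices; the real SpargeBlockMapKernel may store deltas instead.
void compact_row(const std::vector<std::uint8_t>& block_map_row, // num_k_blocks 0/1 flags
                 std::vector<std::int32_t>& lut_row,             // num_k_blocks entries, padded
                 std::int32_t& valid_block_num)
{
    valid_block_num = 0;
    for(std::size_t kb = 0; kb < block_map_row.size(); ++kb)
        if(block_map_row[kb] != 0)
            lut_row[valid_block_num++] = static_cast<std::int32_t>(kb);
    // trailing LUT entries stay unused: the attention kernel only walks the
    // first valid_block_num entries of each row
}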
+ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/reference/reference_blocked_attention.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" + +#include "fmha_fwd_trek.hpp" +#include "sparge_blockmap_trek.hpp" +#include "sparge_tool.hpp" + +// ============================================================================ +// Helpers +// ============================================================================ + +template +ck_tile::HostTensor +make_qkv_tensor(ck_tile::index_t batch, ck_tile::index_t nhead, ck_tile::index_t seqlen, ck_tile::index_t hdim, bool i_perm) +{ + if(i_perm) + return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); + return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); +} + +template +ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) +{ + auto lens = tensor.get_lengths(); + ck_tile::index_t batch = lens[0]; + ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; + ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; + ck_tile::index_t hdim = lens[3]; + + ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); + for(ck_tile::index_t b = 0; b < batch; ++b) + for(ck_tile::index_t h = 0; h < nhead; ++h) + for(ck_tile::index_t s = 0; s < seqlen; ++s) + for(ck_tile::index_t d = 0; d < hdim; ++d) + out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d); + return out; +} + +template +auto get_error_tolerance() +{ + double rtol = 1e-2; + double atol = 4e-2; + if constexpr(std::is_same_v) + { + atol = 2e-1; + rtol = 2e-1; + } + return ck_tile::make_tuple(rtol, atol); +} + +template +float to_float_for_compare(T value) +{ + return static_cast(value); +} + +template <> +float to_float_for_compare(ck_tile::bf16_t value) +{ +#if CK_TILE_USE_CUSTOM_DATA_TYPE + return static_cast(value); +#else + return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); +#endif +} + +// ============================================================================ +// Arg parser +// ============================================================================ +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser + .insert("v", "1", "0:no validation, 1:cpu validation") + .insert("pipeline", "jenga", "attention pipeline: jenga / vsa") + .insert("b", "1", "batch size") + .insert("h", "4", "num of head for q") + .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") + .insert("s", "4096", "seqlen_q") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("topk", "0.3", "topk ratio for blockmap (fraction of K-blocks to keep)") + .insert("cdfthreshd", "-1", "CDF threshold for blockmap (overrides topk if >= 0)") + .insert("simthreshd1", "0.6", "similarity threshold for blockmap") + .insert("prec", "fp16", "data type: fp16/bf16") + .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") + .insert("operm", "1", "permute output") + .insert("seed", "42", "random seed") + .insert("warmup", "5", "warmup iterations") + .insert("repeat", "20", "benchmark iterations") + .insert("kname", "0", "print kernel name"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// ============================================================================ +// Main test +// ============================================================================ 
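run_test below takes every stride from HostTensor::get_strides() and selects by index: with iperm=1 the tensors are B*H*S*D, so the per-sequence stride is strides[2] and the per-head stride is strides[1]; with iperm=0 (B*S*H*D) the two indices swap, which is exactly what the q_strides[i_perm ? 2 : 1] / q_strides[i_perm ? 1 : 2] selections encode. For contiguous tensors this reproduces the explicit formulas the removed standalone tests spelled out; a small self-contained check, with the shapes taken from the parser defaults:

#include <cassert>
#include <cstdint>

// Row-major strides for the two QKV layouts used by this example. Mirrors the
// explicit formulas in the removed standalone tests:
//   seq stride   = i_perm ? d : h * d
//   head stride  = i_perm ? s * d : d
//   batch stride = h * s * d            (identical in both layouts)
struct QkvStrides { std::int64_t seq, head, batch; };

QkvStrides make_strides(bool i_perm, std::int64_t h, std::int64_t s, std::int64_t d)
{
    return i_perm ? QkvStrides{d, s * d, h * s * d}   // B x H x S x D
                  : QkvStrides{h * d, d, h * s * d};  // B x S x H x D
}

int main()
{
    // parser defaults: h=4, s=4096, d=128
    const auto bhsd = make_strides(true, 4, 4096, 128);
    assert(bhsd.seq == 128 && bhsd.head == 4096 * 128);
    const auto bshd = make_strides(false, 4, 4096, 128);
    assert(bshd.seq == 4 * 128 && bshd.head == 128);
    return 0;
}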
+template +bool run_test(const ck_tile::ArgParser& arg_parser) +{ + int do_validation = arg_parser.get_int("v"); + std::string pipeline = arg_parser.get_str("pipeline"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + float topk = arg_parser.get_float("topk"); + float cdfthreshd = arg_parser.get_float("cdfthreshd"); + float simthreshd1 = arg_parser.get_float("simthreshd1"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int kname = arg_parser.get_int("kname"); + + if(nhead_k < 0) nhead_k = nhead; + if(seqlen_k < 0) seqlen_k = seqlen_q; + if(hdim_v < 0) hdim_v = hdim_q; + + // If cdfthreshd >= 0, use CDF mode; otherwise use topk mode + if(cdfthreshd >= 0.0f) + topk = -1.0f; + + constexpr ck_tile::index_t BLKQ = 64; + constexpr ck_tile::index_t BLKK = 128; + + if(hdim_q != 128 || hdim_v != 128) + { + std::cout << "\n>>> TEST SKIPPED <<<\n" + << "Kernel instances are generated for hdim=128 only.\n"; + return true; + } + + ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; + + std::string prec_str = std::is_same_v ? "fp16" : "bf16"; + std::cout << "[" << pipeline << "|" << prec_str + << "] b=" << batch << " h=" << nhead << " s=" << seqlen_q + << " d=" << hdim_q << " topk=" << topk + << " sim1=" << simthreshd1 << std::flush; + + // ---- allocate host tensors ---- + auto q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); + auto k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); + auto v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); + auto output_host = o_perm ? 
ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) + : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); + + ck_tile::HostTensor block_map_host({batch, nhead, num_q_blocks, num_k_blocks}); + ck_tile::HostTensor lut_host({batch, nhead, num_q_blocks, num_k_blocks}); + ck_tile::HostTensor valid_block_num_host({batch, nhead, num_q_blocks}); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); + + // ---- device tensors ---- + ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_dev(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_dev(output_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_map_dev(block_map_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lut_dev(lut_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem valid_bn_dev(valid_block_num_host.get_element_space_size_in_bytes()); + + q_dev.ToDevice(q_host.data()); + k_dev.ToDevice(k_host.data()); + v_dev.ToDevice(v_host.data()); + o_dev.SetZero(); + block_map_dev.SetZero(); + lut_dev.SetZero(); + valid_bn_dev.SetZero(); + + // ---- strides (BHSD when i_perm=true) ---- + auto q_strides = q_host.get_strides(); + auto k_strides = k_host.get_strides(); + auto v_strides = v_host.get_strides(); + auto o_strides = output_host.get_strides(); + + float scale_s = 1.0f / std::sqrt(static_cast(hdim_q)); + + // ---- build blockmap args ---- + sparge_blockmap_traits bmap_traits; + bmap_traits.data_type = std::is_same_v ? "fp16" : "bf16"; + bmap_traits.hdim_q = hdim_q; + + sparge_blockmap_args bmap_args; + bmap_args.q_ptr = q_dev.GetDeviceBuffer(); + bmap_args.k_ptr = k_dev.GetDeviceBuffer(); + bmap_args.batch = batch; + bmap_args.seqlen_q = seqlen_q; + bmap_args.seqlen_k = seqlen_k; + bmap_args.hdim_q = hdim_q; + bmap_args.nhead_q = nhead; + bmap_args.nhead_k = nhead_k; + bmap_args.stride_q = q_strides[i_perm ? 2 : 1]; + bmap_args.stride_k = k_strides[i_perm ? 2 : 1]; + bmap_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + bmap_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + bmap_args.batch_stride_q = q_strides[0]; + bmap_args.batch_stride_k = k_strides[0]; + bmap_args.simthreshd1 = simthreshd1; + bmap_args.cdfthreshd = (topk < 0.0f) ? cdfthreshd : -1.0f; + bmap_args.topk = topk; + bmap_args.scale = scale_s; + bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer(); + bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr; + bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr; + + // ---- build attention args ---- + ck_tile::stream_config stream_cfg; + stream_cfg.stream_id_ = nullptr; + stream_cfg.time_kernel_ = true; + stream_cfg.log_level_ = kname; + stream_cfg.cold_niters_ = warmup; + stream_cfg.nrepeat_ = repeat; + + float avg_ms = -1.0f; + + if(pipeline == "jenga") + { + fmha_jenga_fwd_traits attn_traits; + attn_traits.hdim_q = hdim_q; + attn_traits.hdim_v = hdim_v; + attn_traits.data_type = std::is_same_v ? 
"fp16" : "bf16"; + attn_traits.is_v_rowmajor = true; + attn_traits.mask_type = mask_enum::no_mask; + + fmha_jenga_fwd_args attn_args; + attn_args.q_ptr = q_dev.GetDeviceBuffer(); + attn_args.k_ptr = k_dev.GetDeviceBuffer(); + attn_args.v_ptr = v_dev.GetDeviceBuffer(); + attn_args.block_relation_onehot_ptr = block_map_dev.GetDeviceBuffer(); + attn_args.o_ptr = o_dev.GetDeviceBuffer(); + attn_args.seqlen_q = seqlen_q; + attn_args.seqlen_k = seqlen_k; + attn_args.batch = batch; + attn_args.max_seqlen_q = seqlen_q; + attn_args.hdim_q = hdim_q; + attn_args.hdim_v = hdim_v; + attn_args.nhead_q = nhead; + attn_args.nhead_k = nhead_k; + attn_args.scale_s = scale_s; + attn_args.stride_q = q_strides[i_perm ? 2 : 1]; + attn_args.stride_k = k_strides[i_perm ? 2 : 1]; + attn_args.stride_v = v_strides[i_perm ? 2 : 1]; + attn_args.stride_o = o_strides[o_perm ? 2 : 1]; + attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; + attn_args.batch_stride_q = q_strides[0]; + attn_args.batch_stride_k = k_strides[0]; + attn_args.batch_stride_v = v_strides[0]; + attn_args.batch_stride_o = o_strides[0]; + attn_args.window_size_left = -1; + attn_args.window_size_right = -1; + attn_args.mask_type = 0; + + avg_ms = sparge_jenga_fwd(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); + } + else if(pipeline == "vsa") + { + fmha_vsa_fwd_traits attn_traits; + attn_traits.hdim_q = hdim_q; + attn_traits.hdim_v = hdim_v; + attn_traits.data_type = std::is_same_v ? "fp16" : "bf16"; + attn_traits.is_v_rowmajor = true; + attn_traits.mask_type = mask_enum::no_mask; + + fmha_vsa_fwd_args attn_args; + attn_args.q_ptr = q_dev.GetDeviceBuffer(); + attn_args.k_ptr = k_dev.GetDeviceBuffer(); + attn_args.v_ptr = v_dev.GetDeviceBuffer(); + attn_args.lut_ptr = lut_dev.GetDeviceBuffer(); + attn_args.valid_block_num_ptr = valid_bn_dev.GetDeviceBuffer(); + attn_args.o_ptr = o_dev.GetDeviceBuffer(); + attn_args.seqlen_q = seqlen_q; + attn_args.seqlen_k = seqlen_k; + attn_args.batch = batch; + attn_args.max_seqlen_q = seqlen_q; + attn_args.hdim_q = hdim_q; + attn_args.hdim_v = hdim_v; + attn_args.nhead_q = nhead; + attn_args.nhead_k = nhead_k; + attn_args.scale_s = scale_s; + attn_args.stride_q = q_strides[i_perm ? 2 : 1]; + attn_args.stride_k = k_strides[i_perm ? 2 : 1]; + attn_args.stride_v = v_strides[i_perm ? 2 : 1]; + attn_args.stride_o = o_strides[o_perm ? 2 : 1]; + attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; + attn_args.batch_stride_q = q_strides[0]; + attn_args.batch_stride_k = k_strides[0]; + attn_args.batch_stride_v = v_strides[0]; + attn_args.batch_stride_o = o_strides[0]; + attn_args.window_size_left = -1; + attn_args.window_size_right = -1; + attn_args.mask_type = 0; + + avg_ms = sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); + } + else + { + std::cerr << "Unknown pipeline: " << pipeline << " (use jenga or vsa)\n"; + return false; + } + + // ---- TFLOPS calculation (dense FMHA formula, so sparsity gains show as higher TFLOPS) ---- + std::size_t flop = static_cast(batch) * nhead * + (static_cast(2) * seqlen_q * seqlen_k * hdim_q + + static_cast(2) * seqlen_q * seqlen_k * hdim_v); + float tflops = (avg_ms > 0.f) ? 
static_cast(flop) / 1.E9f / avg_ms : 0.f; + + if(avg_ms > 0.f) + { + std::cout << std::fixed << ", " << std::setprecision(3) << avg_ms << " ms, " + << std::setprecision(2) << tflops << " TFlops" << std::flush; + } + + // ---- copy results back ---- + o_dev.FromDevice(output_host.data()); + block_map_dev.FromDevice(block_map_host.data()); + + // ---- count active blocks ---- + ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks; + ck_tile::index_t active_blocks = 0; + for(size_t i = 0; i < block_map_host.mData.size(); ++i) + if(block_map_host.mData[i]) + active_blocks++; + float actual_sparsity = 1.0f - static_cast(active_blocks) / static_cast(total_blocks); + std::cout << ", sparsity=" << std::setprecision(2) << actual_sparsity + << "(" << active_blocks << "/" << total_blocks << ")" << std::flush; + + // ---- validation ---- + bool pass = true; + if(do_validation) + { + auto q_ref = to_bhsd(q_host, i_perm); + auto k_ref = to_bhsd(k_host, i_perm); + auto v_ref = to_bhsd(v_host, i_perm); + + ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); + ck_tile::reference_blocked_attention( + q_ref, k_ref, v_ref, block_map_host, output_ref, BLKQ, BLKK, scale_s); + + auto [rtol, atol] = get_error_tolerance(); + + float max_diff = 0.0f; + size_t num_errors = 0; + + auto output_host_bhsd = to_bhsd(output_host, o_perm); + for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i) + { + float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); + float ref_val = to_float_for_compare(output_ref.mData[i]); + float diff = std::abs(gpu_val - ref_val); + float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff; + + max_diff = std::max(max_diff, diff); + + if(diff > atol && rel_diff > rtol) + num_errors++; + } + + pass = (num_errors == 0); + std::cout << ", " << (pass ? "PASS" : "FAIL") + << "(err=" << num_errors << "/" << output_host_bhsd.mData.size() + << " maxdiff=" << max_diff << ")"; + } + + std::cout << std::endl; + return pass; +} + +// ============================================================================ +// Main +// ============================================================================ +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + std::cerr << "Failed to parse arguments\n"; + return -1; + } + + std::string prec = arg_parser.get_str("prec"); + + bool test_result = false; + if(prec == "fp16") + { + test_result = run_test(arg_parser); + } + else if(prec == "bf16") + { + test_result = run_test(arg_parser); + } + else + { + std::cerr << "Unsupported precision: " << prec << "\n"; + return -1; + } + + return test_result ? 0 : -1; +} diff --git a/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp deleted file mode 100644 index 590e51db14..0000000000 --- a/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp +++ /dev/null @@ -1,422 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
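One point worth spelling out about the TFLOPS figure reported by test_sparge.cpp above: the FLOP count uses the dense-FMHA formula, so blocks skipped by the sparse kernels are still counted and sparsity shows up as a higher reported TFLOPS rather than a smaller FLOP total. For the parser defaults (b=1, h=4, s=4096, d=d_v=128) the count works out as below; an average time of 1.0 ms would then print roughly 34.4 TFLOPS.

#include <cstdint>

// Dense FMHA FLOP count for the default test shape: two GEMMs of
// 2 * seqlen_q * seqlen_k * hdim each, summed over batch * nhead.
constexpr std::uint64_t b = 1, h = 4, s = 4096, d = 128;
constexpr std::uint64_t flop = b * h * (2 * s * s * d + 2 * s * s * d);
static_assert(flop == 34'359'738'368ull, "about 34.36 GFLOP per forward pass");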
-// SPDX-License-Identifier: MIT -// Demo: Sparge block-map -> Jenga sparse attention - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ck_tile/host.hpp" -#include "ck_tile/core.hpp" -#include "ck_tile/host/reference/reference_blocked_attention.hpp" -#include "ck_tile/core/utility/bit_cast.hpp" - -#include "jenga_sparge_attention.h" -#include "sparge_tool.hpp" - -// ============================================================================ -// Helper Functions -// ============================================================================ - -template -ck_tile::HostTensor make_qkv_tensor(ck_tile::index_t batch, - ck_tile::index_t nhead, - ck_tile::index_t seqlen, - ck_tile::index_t hdim, - bool i_perm) -{ - if(i_perm) - { - return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); - } - return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); -} - -template -ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) -{ - auto lens = tensor.get_lengths(); - ck_tile::index_t batch = lens[0]; - ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; - ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; - ck_tile::index_t hdim = lens[3]; - - ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); - for(ck_tile::index_t b = 0; b < batch; ++b) - { - for(ck_tile::index_t h = 0; h < nhead; ++h) - { - for(ck_tile::index_t s = 0; s < seqlen; ++s) - { - for(ck_tile::index_t d = 0; d < hdim; ++d) - { - out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d); - } - } - } - } - return out; -} - -template -auto get_error_tolerance() -{ - double rtol = 1e-2; - double atol = 4e-2; - if constexpr(std::is_same_v) - { - atol = 2e-1; - rtol = 2e-1; - } - return ck_tile::make_tuple(rtol, atol); -} - -template -float to_float_for_compare(T value) -{ - return static_cast(value); -} - -template <> -float to_float_for_compare(ck_tile::bf16_t value) -{ -#if CK_TILE_USE_CUSTOM_DATA_TYPE - return static_cast(value); -#else - return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); -#endif -} - -// ============================================================================ -// Command line argument parser -// ============================================================================ - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") - .insert("b", "1", "batch size") - .insert("h", "4", "num of head for q") - .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") - .insert("s", "4096", "seqlen_q") - .insert("s_k", "-1", "seqlen_k, -1 means equal to s") - .insert("d", "128", "head dim for q, k") - .insert("d_v", "-1", "head dim for v, -1 means equal to d") - .insert("prec", "fp16", "data type: fp16/bf16") - .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") - .insert("operm", "1", "permute output") - .insert("seed", "42", "random seed") - .insert("warmup", "5", "warmup iterations") - .insert("repeat", "20", "benchmark iterations") - .insert("kname", "0", "print kernel name") - // Sparge-specific - .insert("blkq", "64", "Sparge BLKQ") - .insert("blkk", "128", "Sparge BLKK") - .insert("simthreshd1", "0.6", "Sparge sim threshold") - .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") - .insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -// 
============================================================================ -// Main Test Function -// ============================================================================ - -template -bool run_test(const ck_tile::ArgParser& arg_parser) -{ - int do_validation = arg_parser.get_int("v"); - ck_tile::index_t batch = arg_parser.get_int("b"); - ck_tile::index_t nhead = arg_parser.get_int("h"); - ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - ck_tile::index_t hdim_q = arg_parser.get_int("d"); - ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - bool i_perm = arg_parser.get_bool("iperm"); - bool o_perm = arg_parser.get_bool("operm"); - uint32_t seed = arg_parser.get_uint32("seed"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - int kname = arg_parser.get_int("kname"); - - // Sparge params - ck_tile::index_t blkq = arg_parser.get_int("blkq"); - ck_tile::index_t blkk = arg_parser.get_int("blkk"); - float simthreshd1 = arg_parser.get_float("simthreshd1"); - float cdfthreshd = arg_parser.get_float("cdfthreshd"); - float topk = arg_parser.get_float("topk"); - - if(nhead_k < 0) - nhead_k = nhead; - if(seqlen_k < 0) - seqlen_k = seqlen_q; - if(hdim_v < 0) - hdim_v = hdim_q; - - if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128) - { - std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; - std::cout << "Sparge Jenga kernel instances are generated for BLKQ=64, BLKK=128, " - "hdim_q=128, hdim_v=128 only." - << std::endl; - std::cout << "TEST SKIPPED" << std::endl; - return true; - } - - ck_tile::index_t BLKQ = blkq; - ck_tile::index_t BLKK = blkk; - - ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; - ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; - - std::cout << "============================================================" << std::endl; - std::cout << "[Sparge -> Jenga Sparse Attention Demo]" << std::endl; - std::cout << "============================================================" << std::endl; - std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k - << std::endl; - std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl; - std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl; - std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl; - std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks - << std::endl; - std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd - << ", topk=" << topk << ")" << std::endl; - std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; - - // Create host tensors - ck_tile::HostTensor q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); - ck_tile::HostTensor k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); - ck_tile::HostTensor v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); - ck_tile::HostTensor output_host = - o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) - : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); - ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); - - std::cout << "\nInitializing tensors..." 
<< std::endl; - ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); - ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); - ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); - - // Build block map using Sparge tool - std::cout << "Building Sparge block map..." << std::endl; - sparge::SpargeParams p; - p.BLKQ = static_cast(BLKQ); - p.BLKK = static_cast(BLKK); - p.simthreshd1 = simthreshd1; - p.cdfthreshd = cdfthreshd; - p.topk = topk; - p.i_perm = i_perm; - - ck_tile::HostTensor block_relation_onehot = - sparge::build_block_map_meansim(q_host, k_host, p); - - // Print actual sparsity - std::size_t total_blocks = 0; - std::size_t active_blocks = 0; - for(ck_tile::index_t b = 0; b < batch; ++b) - { - for(ck_tile::index_t h = 0; h < nhead; ++h) - { - for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) - { - for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) - { - total_blocks++; - if(block_relation_onehot(b, h, qb, kb) != 0) - active_blocks++; - } - } - } - } - float actual_sparsity = - 1.0f - static_cast(active_blocks) / static_cast(total_blocks); - std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" - << total_blocks << " blocks active)" << std::endl; - - std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl; - - try - { - if(kname) - { - jenga_sparge_attention(q_host, - k_host, - v_host, - block_relation_onehot, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 1); - } - - for(int i = 0; i < warmup; ++i) - { - jenga_sparge_attention(q_host, - k_host, - v_host, - block_relation_onehot, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 0); - } - - [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize(); - auto start = std::chrono::high_resolution_clock::now(); - - for(int i = 0; i < repeat; ++i) - { - jenga_sparge_attention(q_host, - k_host, - v_host, - block_relation_onehot, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 0); - } - - [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); - auto end = std::chrono::high_resolution_clock::now(); - double avg_time_ms = - std::chrono::duration(end - start).count() / repeat; - - std::cout << "\n>>>> Jenga sparse attention average time: " << avg_time_ms << " ms <<<<" - << std::endl; - } - catch(const std::exception& e) - { - std::cerr << "Error during kernel execution: " << e.what() << std::endl; - return false; - } - - bool pass = true; - if(do_validation) - { - std::cout << "\n--- Performing CPU validation ---" << std::endl; - float scale = 1.0f / std::sqrt(static_cast(hdim_q)); - - std::cout << "Computing reference output..." 
<< std::endl; - auto q_ref = to_bhsd(q_host, i_perm); - auto k_ref = to_bhsd(k_host, i_perm); - auto v_ref = to_bhsd(v_host, i_perm); - - ck_tile::reference_blocked_attention( - q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale); - - auto [rtol, atol] = get_error_tolerance(); - - float max_diff = 0.0f; - float max_rel_diff = 0.0f; - std::size_t num_errors = 0; - - auto output_host_bhsd = to_bhsd(output_host, o_perm); - for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i) - { - float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); - float ref_val = to_float_for_compare(output_ref.mData[i]); - float diff = std::abs(gpu_val - ref_val); - float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff; - - max_diff = std::max(max_diff, diff); - max_rel_diff = std::max(max_rel_diff, rel_diff); - - if(diff > atol && rel_diff > rtol) - { - num_errors++; - if(num_errors <= 5) - { - std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val - << ", Ref=" << ref_val << ", Diff=" << diff << std::endl; - } - } - } - - std::cout << "\nValidation results:" << std::endl; - std::cout << " Max absolute difference: " << max_diff << std::endl; - std::cout << " Max relative difference: " << max_rel_diff << std::endl; - std::cout << " Number of mismatches: " << num_errors << " / " - << output_host_bhsd.mData.size() << std::endl; - - if(num_errors == 0) - { - std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl; - } - else - { - std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl; - pass = false; - } - } - - std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl; - return pass; -} - -// ============================================================================ -// Main -// ============================================================================ - -int main(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - { - std::cerr << "Failed to parse arguments" << std::endl; - return -1; - } - - std::string prec = arg_parser.get_str("prec"); - - bool test_result = false; - if(prec == "fp16") - { - test_result = run_test(arg_parser); - } - else if(prec == "bf16") - { - test_result = run_test(arg_parser); - } - else - { - std::cerr << "Unsupported precision: " << prec << std::endl; - return -1; - } - - return test_result ? 0 : -1; -} diff --git a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp deleted file mode 100644 index 572b708f9e..0000000000 --- a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp +++ /dev/null @@ -1,597 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT -// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention (all-in-device) - -#include -#include -#include -#include "ck_tile/host.hpp" -#include "ck_tile/core.hpp" -#include "ck_tile/host/reference/reference_blocked_attention.hpp" -#include "ck_tile/core/utility/bit_cast.hpp" - -#include "sparge_blockmap_trek.hpp" -#include "fmha_fwd_trek.hpp" -#include "sparge_tool.hpp" - -// ============================================================================ -// Helper Functions -// ============================================================================ - -template -ck_tile::HostTensor make_qkv_tensor(ck_tile::index_t batch, - ck_tile::index_t nhead, - ck_tile::index_t seqlen, - ck_tile::index_t hdim, - bool i_perm) -{ - if(i_perm) - { - return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); - } - return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); -} - -template -ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) -{ - auto lens = tensor.get_lengths(); - ck_tile::index_t batch = lens[0]; - ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; - ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; - ck_tile::index_t hdim = lens[3]; - - ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); - for(ck_tile::index_t b = 0; b < batch; ++b) - { - for(ck_tile::index_t h = 0; h < nhead; ++h) - { - for(ck_tile::index_t s = 0; s < seqlen; ++s) - { - for(ck_tile::index_t d = 0; d < hdim; ++d) - { - out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d); - } - } - } - } - return out; -} - -template -auto get_error_tolerance() -{ - double rtol = 1e-2; - double atol = 4e-2; - if constexpr(std::is_same_v) - { - atol = 2e-1; - rtol = 2e-1; - } - return ck_tile::make_tuple(rtol, atol); -} - -template -float to_float_for_compare(T value) -{ - return static_cast(value); -} - -template <> -float to_float_for_compare(ck_tile::bf16_t value) -{ -#if CK_TILE_USE_CUSTOM_DATA_TYPE - return static_cast(value); -#else - return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); -#endif -} - -// ============================================================================ -// Command line argument parser -// ============================================================================ - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") - .insert("b", "1", "batch size") - .insert("h", "4", "num of head for q") - .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") - .insert("s", "4096", "seqlen_q") - .insert("s_k", "-1", "seqlen_k, -1 means equal to s") - .insert("d", "128", "head dim for q, k") - .insert("d_v", "-1", "head dim for v, -1 means equal to d") - .insert("prec", "fp16", "data type: fp16/bf16") - .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") - .insert("operm", "1", "permute output") - .insert("seed", "42", "random seed") - .insert("warmup", "5", "warmup iterations") - .insert("repeat", "20", "benchmark iterations") - .insert("kname", "0", "print kernel name") - // Sparge-specific - .insert("blkq", "64", "Sparge BLKQ") - .insert("blkk", "128", "Sparge BLKK") - .insert("simthreshd1", "0.6", "Sparge sim threshold") - .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") - .insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -// 
============================================================================
-// Main Test Function
-// ============================================================================
-
-template <typename DataType>
-bool run_test(const ck_tile::ArgParser& arg_parser)
-{
-    int do_validation = arg_parser.get_int("v");
-    ck_tile::index_t batch = arg_parser.get_int("b");
-    ck_tile::index_t nhead = arg_parser.get_int("h");
-    ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
-    ck_tile::index_t seqlen_q = arg_parser.get_int("s");
-    ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
-    ck_tile::index_t hdim_q = arg_parser.get_int("d");
-    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
-    bool i_perm = arg_parser.get_bool("iperm");
-    bool o_perm = arg_parser.get_bool("operm");
-    uint32_t seed = arg_parser.get_uint32("seed");
-    int warmup = arg_parser.get_int("warmup");
-    int repeat = arg_parser.get_int("repeat");
-    int kname = arg_parser.get_int("kname");
-
-    // Sparge params
-    ck_tile::index_t blkq = arg_parser.get_int("blkq");
-    ck_tile::index_t blkk = arg_parser.get_int("blkk");
-    float simthreshd1 = arg_parser.get_float("simthreshd1");
-    float cdfthreshd = arg_parser.get_float("cdfthreshd");
-    float topk = arg_parser.get_float("topk");
-
-    if(nhead_k < 0)
-        nhead_k = nhead;
-    if(seqlen_k < 0)
-        seqlen_k = seqlen_q;
-    if(hdim_v < 0)
-        hdim_v = hdim_q;
-
-    if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
-    {
-        std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
-        std::cout << "Sparge VSA kernel instances are generated for BLKQ=64, BLKK=128, "
-                     "hdim_q=128, hdim_v=128 only."
-                  << std::endl;
-        std::cout << "TEST SKIPPED" << std::endl;
-        return true;
-    }
-
-    ck_tile::index_t BLKQ = blkq;
-    ck_tile::index_t BLKK = blkk;
-
-    ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
-    ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
-
-    std::cout << "============================================================" << std::endl;
-    std::cout << "[Sparge -> VSA Sparse Attention Demo]" << std::endl;
-    std::cout << "============================================================" << std::endl;
-    std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
-              << std::endl;
-    std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
-    std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
-    std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl;
-    std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
-              << std::endl;
-    std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd
-              << ", topk=" << topk << ")" << std::endl;
-    std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
-
-    // Create host tensors and fill with random data
-    ck_tile::HostTensor<DataType> q_host =
-        make_qkv_tensor<DataType>(batch, nhead, seqlen_q, hdim_q, i_perm);
-    ck_tile::HostTensor<DataType> k_host =
-        make_qkv_tensor<DataType>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
-    ck_tile::HostTensor<DataType> v_host =
-        make_qkv_tensor<DataType>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
-    ck_tile::HostTensor<DataType> output_host =
-        o_perm ? ck_tile::HostTensor<DataType>({batch, nhead, seqlen_q, hdim_v})
-               : ck_tile::HostTensor<DataType>({batch, seqlen_q, nhead, hdim_v});
-
-    std::cout << "\nInitializing tensors..." << std::endl;
-    ck_tile::FillUniformDistribution<DataType>{-0.5f, 0.5f, seed}(q_host);
-    ck_tile::FillUniformDistribution<DataType>{-0.5f, 0.5f, seed + 1}(k_host);
-    ck_tile::FillUniformDistribution<DataType>{-0.5f, 0.5f, seed + 2}(v_host);
-
-    // ==================================================================
-    // Allocate device memory once, HtoD once
-    // ==================================================================
-    ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem o_buf(output_host.get_element_space_size_in_bytes());
-
-    q_buf.ToDevice(q_host.data());
-    k_buf.ToDevice(k_host.data());
-    v_buf.ToDevice(v_host.data());
-
-    const std::size_t bmap_bytes =
-        static_cast<std::size_t>(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(uint8_t);
-    const std::size_t lut_bytes =
-        static_cast<std::size_t>(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(int32_t);
-    const std::size_t valid_bytes =
-        static_cast<std::size_t>(batch) * nhead * num_q_blocks * sizeof(int32_t);
-
-    ck_tile::DeviceMem bmap_buf(bmap_bytes);
-    ck_tile::DeviceMem lut_buf(lut_bytes);
-    ck_tile::DeviceMem valid_buf(valid_bytes);
-    bmap_buf.SetZero();
-    lut_buf.SetZero();
-    valid_buf.SetZero();
-
-    // ==================================================================
-    // Common stride calculations
-    // ==================================================================
-    assert(nhead % nhead_k == 0);
-    const float scale_s = 1.0f / std::sqrt(static_cast<float>(hdim_q));
-
-    const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead * hdim_q;
-    const ck_tile::index_t stride_k = i_perm ? hdim_q : nhead_k * hdim_q;
-    const ck_tile::index_t stride_v = i_perm ? hdim_v : nhead_k * hdim_v;
-    const ck_tile::index_t stride_o = o_perm ? hdim_v : nhead * hdim_v;
-    const ck_tile::index_t nhead_stride_q = i_perm ? seqlen_q * hdim_q : hdim_q;
-    const ck_tile::index_t nhead_stride_k = i_perm ? seqlen_k * hdim_q : hdim_q;
-    const ck_tile::index_t nhead_stride_v = i_perm ? seqlen_k * hdim_v : hdim_v;
-    const ck_tile::index_t nhead_stride_o = o_perm ? seqlen_q * hdim_v : hdim_v;
-    const ck_tile::index_t batch_stride_q = nhead * seqlen_q * hdim_q;
-    const ck_tile::index_t batch_stride_k = nhead_k * seqlen_k * hdim_q;
-    const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * seqlen_k;
-    const ck_tile::index_t batch_stride_o = nhead * seqlen_q * hdim_v;
-
-    std::string data_type = "fp16";
-    if constexpr(std::is_same_v<DataType, ck_tile::bf16_t>)
-        data_type = "bf16";
-
-    std::string msk_str = "0";
-    mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);
-
-    // ==================================================================
-    // GPU: Build block map + VSA LUT (always run, device-only)
-    // ==================================================================
-    std::cout << "Building Sparge block map + VSA LUT (GPU)..." << std::endl;
-    {
-        sparge_blockmap_args args;
-        args.q_ptr = q_buf.GetDeviceBuffer();
-        args.k_ptr = k_buf.GetDeviceBuffer();
-        args.batch = batch;
-        args.seqlen_q = seqlen_q;
-        args.seqlen_k = seqlen_k;
-        args.hdim_q = hdim_q;
-        args.nhead_q = nhead;
-        args.nhead_k = nhead_k;
-        args.stride_q = stride_q;
-        args.stride_k = stride_k;
-        args.nhead_stride_q = nhead_stride_q;
-        args.nhead_stride_k = nhead_stride_k;
-        args.batch_stride_q = batch_stride_q;
-        args.batch_stride_k = batch_stride_k;
-        args.simthreshd1 = simthreshd1;
-        args.cdfthreshd = cdfthreshd;
-        args.topk = topk;
-        args.scale = scale_s;
-        args.block_map_ptr = bmap_buf.GetDeviceBuffer();
-        args.lut_ptr = lut_buf.GetDeviceBuffer();
-        args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
-
-        sparge_blockmap_traits traits;
-        traits.data_type = data_type;
-        traits.hdim_q = hdim_q;
-
-        sparge_blockmap_fwd(traits, args, ck_tile::stream_config{});
-    }
-
-    // ==================================================================
-    // VSA sparse attention kernel (always run, LUT stays on device)
-    // ==================================================================
-    std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;
-
-    fmha_vsa_fwd_args fmha_args;
-    fmha_args.q_ptr = q_buf.GetDeviceBuffer();
-    fmha_args.k_ptr = k_buf.GetDeviceBuffer();
-    fmha_args.v_ptr = v_buf.GetDeviceBuffer();
-    fmha_args.lut_ptr = lut_buf.GetDeviceBuffer();
-    fmha_args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
-    fmha_args.o_ptr = o_buf.GetDeviceBuffer();
-    fmha_args.batch = batch;
-    fmha_args.seqlen_q = seqlen_q;
-    fmha_args.seqlen_k = seqlen_k;
-    fmha_args.max_seqlen_q = seqlen_q;
-    fmha_args.hdim_q = hdim_q;
-    fmha_args.hdim_v = hdim_v;
-    fmha_args.nhead_q = nhead;
-    fmha_args.nhead_k = nhead_k;
-    fmha_args.scale_s = scale_s;
-    fmha_args.stride_q = stride_q;
-    fmha_args.stride_k = stride_k;
-    fmha_args.stride_v = stride_v;
-    fmha_args.stride_o = stride_o;
-    fmha_args.nhead_stride_q = nhead_stride_q;
-    fmha_args.nhead_stride_k = nhead_stride_k;
-    fmha_args.nhead_stride_v = nhead_stride_v;
-    fmha_args.nhead_stride_o = nhead_stride_o;
-    fmha_args.batch_stride_q = batch_stride_q;
-    fmha_args.batch_stride_k = batch_stride_k;
-    fmha_args.batch_stride_v = batch_stride_v;
-    fmha_args.batch_stride_o = batch_stride_o;
-    fmha_args.window_size_left = mask.left;
-    fmha_args.window_size_right = mask.right;
-    fmha_args.mask_type = static_cast<ck_tile::index_t>(mask.type);
-
-    fmha_vsa_fwd_traits fmha_traits;
-    fmha_traits.hdim_q = hdim_q;
-    fmha_traits.hdim_v = hdim_v;
-    fmha_traits.data_type = data_type;
-    fmha_traits.is_v_rowmajor = true;
-    fmha_traits.mask_type = mask.type;
-
-    ck_tile::stream_config stream_config{nullptr,
-                                         true,
-                                         /* log_level = */ kname ? 1 : 0,
-                                         warmup,
-                                         repeat,
-                                         false};
-
-    float avg_time_ms = sparge_vsa_fwd(fmha_traits, fmha_args, stream_config);
-
-    std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<"
-              << std::endl;
-
-    // DtoH: attention output (always needed)
-    o_buf.FromDevice(output_host.data(), output_host.get_element_space_size_in_bytes());
-
-    // DtoH: block_map (needed for sparsity stats and validation)
-    ck_tile::HostTensor<uint8_t> block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks});
-    bmap_buf.FromDevice(block_map_gpu.data(), bmap_bytes);
-
-    // ==================================================================
-    // Sparsity statistics (pure CPU, reads block_map HostTensor)
-    // ==================================================================
-    std::size_t total_blocks = 0;
-    std::size_t active_blocks = 0;
-    for(ck_tile::index_t b = 0; b < batch; ++b)
-    {
-        for(ck_tile::index_t h = 0; h < nhead; ++h)
-        {
-            for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
-            {
-                for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
-                {
-                    total_blocks++;
-                    if(block_map_gpu(b, h, qb, kb) != 0)
-                        active_blocks++;
-                }
-            }
-        }
-    }
-    float actual_sparsity =
-        1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
-    std::cout << "\n Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
-              << total_blocks << " blocks active)" << std::endl;
-
-    // ==================================================================
-    // Validation (only when -v=1)
-    // ==================================================================
-    bool pass = true;
-    if(do_validation)
-    {
-        std::cout << "\n--- Performing CPU validation ---" << std::endl;
-
-        // CPU golden: block map + VSA LUT
-        std::cout << "Building Sparge block map (CPU golden)..." << std::endl;
-        sparge::SpargeParams p;
-        p.BLKQ = static_cast<int>(BLKQ);
-        p.BLKK = static_cast<int>(BLKK);
-        p.simthreshd1 = simthreshd1;
-        p.cdfthreshd = cdfthreshd;
-        p.topk = topk;
-        p.i_perm = i_perm;
-
-        ck_tile::HostTensor<uint8_t> block_relation_onehot =
-            sparge::build_block_map_meansim(q_host, k_host, p);
-
-        std::cout << "Converting block map to VSA LUT (delta, CPU)..." << std::endl;
-        auto vsa_lut_cpu = sparge::block_map_to_vsa_lut_delta(block_relation_onehot);
-
-        // DtoH: LUT + valid_block_num (only for validation)
-        sparge::VSALut vsa_lut_gpu{
-            ck_tile::HostTensor<int32_t>({batch, nhead, num_q_blocks, num_k_blocks}),
-            ck_tile::HostTensor<int32_t>({batch, nhead, num_q_blocks}),
-        };
-        lut_buf.FromDevice(vsa_lut_gpu.lut.data(), lut_bytes);
-        valid_buf.FromDevice(vsa_lut_gpu.valid_block_num.data(), valid_bytes);
-
-        // Validate block map
-        std::cout << "\n--- Validating GPU block map vs CPU golden ---" << std::endl;
-        {
-            std::size_t bmap_mismatches = 0;
-            for(ck_tile::index_t b = 0; b < batch; ++b)
-            {
-                for(ck_tile::index_t h = 0; h < nhead; ++h)
-                {
-                    for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
-                    {
-                        for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
-                        {
-                            if(block_map_gpu(b, h, qb, kb) != block_relation_onehot(b, h, qb, kb))
-                            {
-                                bmap_mismatches++;
-                                if(bmap_mismatches <= 10)
-                                {
-                                    std::cout
-                                        << " block_map mismatch at [" << b << "," << h << "," << qb
-                                        << "," << kb << "]: GPU="
-                                        << static_cast<int>(block_map_gpu(b, h, qb, kb)) << " CPU="
-                                        << static_cast<int>(block_relation_onehot(b, h, qb, kb))
-                                        << std::endl;
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-            std::cout << " Block map mismatches: " << bmap_mismatches << " / "
-                      << (batch * nhead * num_q_blocks * num_k_blocks) << std::endl;
-            if(bmap_mismatches > 0)
-            {
-                std::cout << ">>> GPU BLOCK MAP VALIDATION FAILED <<<" << std::endl;
-                pass = false;
-            }
-            else
-            {
-                std::cout << ">>> GPU BLOCK MAP VALIDATION PASSED <<<" << std::endl;
-            }
-        }
-
-        // Validate VSA LUT
-        std::cout << "\n--- Validating GPU VSA LUT vs CPU golden ---" << std::endl;
-        {
-            std::size_t lut_mismatches = 0;
-            std::size_t valid_mismatches = 0;
-            for(ck_tile::index_t b = 0; b < batch; ++b)
-            {
-                for(ck_tile::index_t h = 0; h < nhead; ++h)
-                {
-                    for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
-                    {
-                        if(vsa_lut_gpu.valid_block_num(b, h, qb) !=
-                           vsa_lut_cpu.valid_block_num(b, h, qb))
-                        {
-                            valid_mismatches++;
-                            if(valid_mismatches <= 5)
-                            {
-                                std::cout << " valid_block_num mismatch at [" << b << "," << h
-                                          << "," << qb
-                                          << "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb)
-                                          << " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb)
-                                          << std::endl;
-                            }
-                        }
-                        for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
-                        {
-                            if(vsa_lut_gpu.lut(b, h, qb, kb) != vsa_lut_cpu.lut(b, h, qb, kb))
-                            {
-                                lut_mismatches++;
-                                if(lut_mismatches <= 10)
-                                {
-                                    std::cout
-                                        << " LUT mismatch at [" << b << "," << h << "," << qb
-                                        << "," << kb << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb)
-                                        << " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb) << std::endl;
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-            std::cout << " LUT mismatches: " << lut_mismatches << std::endl;
-            std::cout << " valid_block_num mismatches: " << valid_mismatches << std::endl;
-            if(lut_mismatches == 0 && valid_mismatches == 0)
-            {
-                std::cout << ">>> GPU VSA LUT VALIDATION PASSED <<<" << std::endl;
-            }
-            else
-            {
-                std::cout << ">>> GPU VSA LUT VALIDATION FAILED <<<" << std::endl;
-                pass = false;
-            }
-        }
-
-        // Validate attention output
-        float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
-
-        std::cout << "\nComputing reference attention output..." << std::endl;
-        auto q_ref = to_bhsd(q_host, i_perm);
-        auto k_ref = to_bhsd(k_host, i_perm);
-        auto v_ref = to_bhsd(v_host, i_perm);
-
-        ck_tile::HostTensor<DataType> output_ref({batch, nhead, seqlen_q, hdim_v});
-        ck_tile::reference_blocked_attention(
-            q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
-
-        auto [rtol, atol] = get_error_tolerance();
-
-        float max_diff = 0.0f;
-        float max_rel_diff = 0.0f;
-        std::size_t num_errors = 0;
-
-        auto output_host_bhsd = to_bhsd(output_host, o_perm);
-        for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
-        {
-            float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
-            float ref_val = to_float_for_compare(output_ref.mData[i]);
-            float diff = std::abs(gpu_val - ref_val);
-            float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
-
-            max_diff = std::max(max_diff, diff);
-            max_rel_diff = std::max(max_rel_diff, rel_diff);
-
-            if(diff > atol && rel_diff > rtol)
-            {
-                num_errors++;
-                if(num_errors <= 5)
-                {
-                    std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
-                              << ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
-                }
-            }
-        }
-
-        std::cout << "\nAttention validation results:" << std::endl;
-        std::cout << " Max absolute difference: " << max_diff << std::endl;
-        std::cout << " Max relative difference: " << max_rel_diff << std::endl;
-        std::cout << " Number of mismatches: " << num_errors << " / "
-                  << output_host_bhsd.mData.size() << std::endl;
-
-        if(num_errors == 0)
-        {
-            std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
-        }
-        else
-        {
-            std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
-            pass = false;
-        }
-    }
-
-    std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
-    return pass;
-}
-
-// ============================================================================
-// Main
-// ============================================================================
-
-int main(int argc, char* argv[])
-{
-    auto [result, arg_parser] = create_args(argc, argv);
-    if(!result)
-    {
-        std::cerr << "Failed to parse arguments" << std::endl;
-        return -1;
-    }
-
-    std::string prec = arg_parser.get_str("prec");
-
-    bool test_result = false;
-    if(prec == "fp16")
-    {
-        test_result = run_test<ck_tile::fp16_t>(arg_parser);
-    }
-    else if(prec == "bf16")
-    {
-        test_result = run_test<ck_tile::bf16_t>(arg_parser);
-    }
-    else
-    {
-        std::cerr << "Unsupported precision: " << prec << std::endl;
-        return -1;
-    }
-
-    return test_result ? 0 : -1;
-}
diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp
index 67936c4353..9fe8b365b0 100644
--- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp
+++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp
@@ -318,26 +318,26 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga
         {
             if(!block_relation_onehot[i_total_loops])
             {
-                i_total_loops++;
-                if(i_total_loops < num_total_loop)
-                {
-                    // move K tile windows
-                    move_tile_window(k_dram_block_window, {kN0, 0});
-                    k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
-
-                    if(block_relation_onehot[i_total_loops])
-                    {
-                        async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
-                                            k_dram_window,
-                                            number<-1>{},
-                                            k_oob_ck,
-                                            k_pre_np);
-                    }
-                    move_tile_window(k_dram_window, {0, kK0});
-                    move_tile_window(v_dram_window, {0, kN0});
-                    continue;
-                }
-                break;
+                // scan-ahead: find the next active block in one shot
+                index_t next = i_total_loops + 1;
+                while(next < num_total_loop && !block_relation_onehot[next])
+                    next++;
+                if(next >= num_total_loop)
+                    break;
+                const index_t delta = next - i_total_loops;
+                i_total_loops = next;
+                // jump K/V windows to the next active block
+                move_tile_window(k_dram_block_window, {kN0 * delta, 0});
+                k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
+                move_tile_window(v_dram_window, {0, kN0 * delta});
+                // immediately prefetch the active K tile
+                async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
+                                    k_dram_window,
+                                    number<-1>{},
+                                    k_oob_ck,
+                                    k_pre_np);
+                move_tile_window(k_dram_window, {0, kK0});
+                continue;
             }
 
             // STAGE 1, QK gemm