From 2ec8ebff31f0c18600be604e9adb04a8fa8fa91a Mon Sep 17 00:00:00 2001
From: Anton Gorenko
Date: Wed, 10 Sep 2025 09:06:14 +0600
Subject: [PATCH] [CK_TILE] Add gtests for FMHA (#2744)

* Improve random number generation
  * use a different seed for each input (Q, K, V...);
  * use deterministic generation of:
    * seqstart_q/k (for group mode, sketched below);
    * block_table (for paged-kvcache);
    * cache_batch_idx (for kvcache);
* Extract arg_parser-related code from the run functions so they can be used as tests
* Split examples into main programs and fmha runners, build instances separately
* Add dummy tests that use the instances and runners
* Fix a missed corner case of f32->f8 conversion
  When a value is < the min f8 denormal but > the min f8 denormal / 2, it must
  be rounded to the min f8 denormal (i.e. 0b1), not to 0 (sketched below).
* Fix incorrect fp8 scales for P and O in the validation code
  DataTypeConfig was incorrectly compared with fp8_t.
* Add host generation of dropout random values and use it for validation
  Previously host validation (reference_batched_dropout) used the random
  numbers generated by BlockDropout of the kernel, meaning that incorrect
  generation on device (bad distribution, repeated numbers, too many zeros,
  etc.) would not trigger any validation errors (sketched below).
* Implement tests from smoke_test_bwd.sh
* Return the result as an enum to distinguish failure from a missing instance
* Add tests for bwd features: bias, alibi, dropout
* Implement tests from smoke_test_fwd.sh
* Pass seqlen_q/k as vectors to the fwd and bwd runners
* Add tests for fwd features: bias, alibi, dropout
* Add tests for pagedkv and splitkv
* Fix the conditions for when to use the splitkv and pagedkv kernels
  splitkv was executed only when use_kvcache, which equals
  (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size).
  In the SplitKV tests: the regular fwd kernel was executed if
  use_cache_batch_idx was not requested, even when num_splitkv > 1.
  In the AppendKV tests: the pagedkv kernel was executed but it often failed
  to find an instance.
* Add tests for appendkv
* Use is_v_rowmajor = true because there are no instances with column layout anymore
* Split public and private compile options for instances
  Tests and examples need to know only about CK_TILE_FMHA_FWD_*_API.
* Improve parsing validation in bias and mask
* Pass bias as a string for consistency with mask
* Catch parsing and other exceptions
* Add a bwd test for the deterministic flag
* Initialize fp8 tensors (-init=ufq) similarly to uf
* Fix the splitkv/pagedkv invocation: use padded sk when seqlen_k_ptr is not null
  seqlen_k cannot be used to determine padding when seqlen_k_ptr is provided;
  the actual seqlen_k is taken from seqlen_k_ptr[b]. Even seqlen_k values
  (% bn0 == 0) use the padded seqlen_k, while seqlen_k_ptr may contain
  arbitrary values (sketched below). In the example or tests this produces
  incorrect results with appendkv (for example,
  -d=32 -s=1 -s_k=64 -s_knew=7 -vlayout=c -b=8).
* Fix the use_pagedkv value when kvcache = true but page_block_size = 0
  In this case block_table_ptr is nullptr, which is accessed in the kernel.
* Clean up bwd tests
* Unify the fwd tests for f16/bf16 and fp8
* Use a better explicit instantiation declaration for fmha_bwd<2>
* Use the same seed for all tests, allow overriding it with an env variable
* Undo clang-format of one irrelevant file
  For some reason my local clang-format-18 and the one in CI work differently.
* Do not build instances and tests on unsupported archs
* Build the instance libraries as OBJECT libraries
* CI: Enable sccache for HIP
  There are source files with LANGUAGE HIP; they need
  -DCMAKE_HIP_COMPILER_LAUNCHER=sccache
* Add the tests to REGRESSION_TESTS
* Fix OOB accesses in deterministic bwd due to an incorrectly assumed kN0
  The runner assumes kN0 = (hdim_q <= 128) ? 128 : 64, but there are smaller
  tiles (for tr_load or fp32). This can make dq_acc_buf too small (sketched
  below).
* Pass CK_TILE_FMHA_FWD_*_API as INTERFACE compile options
  The instances don't actually depend on them; only examples and tests do.
  Passing these definitions as INTERFACE allows changing FMHA_FWD_ENABLE_APIS
  without recompiling instances that are already in ccache.
* Fix formatting and names

[ROCm/composable_kernel commit: ec006bb8e008caaf5cb95c1ca5037dc6ac026bde]
---
 Jenkinsfile                                   |   2 +-
 README.md                                     |   2 +-
 example/ck_tile/01_fmha/CMakeLists.txt        | 116 ++-
 example/ck_tile/01_fmha/bias.hpp              |  54 +-
 .../01_fmha/codegen/ops/fmha_fwd_splitkv.py   |   6 +-
 .../codegen/ops/fmha_pagedkv_prefill.py       |   6 +-
 example/ck_tile/01_fmha/example_fmha_bwd.cpp  | 183 +++++
 example/ck_tile/01_fmha/example_fmha_fwd.cpp  | 253 ++++++
 .../{fmha_bwd.cpp => fmha_bwd_runner.hpp}     | 404 ++++-----
 .../{fmha_fwd.cpp => fmha_fwd_runner.hpp}     | 774 ++++++++----------
 example/ck_tile/01_fmha/mask.hpp              |  77 +-
 example/ck_tile/01_fmha/utils.hpp             | 140 ++--
 include/ck_tile/core/numeric/float8.hpp       |   4 +-
 include/ck_tile/host.hpp                      |   1 +
 .../reference_batched_dropout_randval.hpp     |  70 ++
 include/ck_tile/utility/json_dump.hpp         |   4 +-
 test/CMakeLists.txt                           |   5 +
 test/ck_tile/CMakeLists.txt                   |   1 +
 test/ck_tile/data_type/test_fp8.cpp           |  16 +-
 test/ck_tile/fmha/CMakeLists.txt              |  31 +
 test/ck_tile/fmha/test_fmha_bwd.inc           | 344 ++++++
 test/ck_tile/fmha/test_fmha_bwd_bf16.cpp      |  21 +
 test/ck_tile/fmha/test_fmha_bwd_fp16.cpp      |  21 +
 test/ck_tile/fmha/test_fmha_fwd.inc           | 628 ++++++++++++
 test/ck_tile/fmha/test_fmha_fwd_bf16.cpp      |  44 +
 test/ck_tile/fmha/test_fmha_fwd_fp16.cpp      |  44 +
 test/ck_tile/fmha/test_fmha_fwd_fp8.cpp       |  43 +
 27 files changed, 2429 insertions(+), 865 deletions(-)
 create mode 100644 example/ck_tile/01_fmha/example_fmha_bwd.cpp
 create mode 100644 example/ck_tile/01_fmha/example_fmha_fwd.cpp
 rename example/ck_tile/01_fmha/{fmha_bwd.cpp => fmha_bwd_runner.hpp} (76%)
 rename example/ck_tile/01_fmha/{fmha_fwd.cpp => fmha_fwd_runner.hpp} (72%)
 create mode 100644 include/ck_tile/host/reference/reference_batched_dropout_randval.hpp
 create mode 100644 test/ck_tile/fmha/CMakeLists.txt
 create mode 100644 test/ck_tile/fmha/test_fmha_bwd.inc
 create mode 100644 test/ck_tile/fmha/test_fmha_bwd_bf16.cpp
 create mode 100644 test/ck_tile/fmha/test_fmha_bwd_fp16.cpp
 create mode 100644 test/ck_tile/fmha/test_fmha_fwd.inc
 create mode 100644 test/ck_tile/fmha/test_fmha_fwd_bf16.cpp
 create mode 100644 test/ck_tile/fmha/test_fmha_fwd_fp16.cpp
 create mode 100644 test/ck_tile/fmha/test_fmha_fwd_fp8.cpp

diff --git a/Jenkinsfile b/Jenkinsfile
index c3436ec3b8..2b9ea200f0 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -321,7 +321,7 @@ def cmake_build(Map conf=[:]){
                 ${redis_pre_setup_cmd}
             """)
             sh cmd1
-            setup_args = " -DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache " + setup_args
+            setup_args = " -DCMAKE_HIP_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache " + setup_args
         }
         catch(Exception err){
             echo "could not connect to redis server: ${err.getMessage()}. will not use sccache."
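The group-mode runners derive seqstart_q/k from the per-batch seqlen vectors via to_seqstarts. Assuming it has the usual exclusive-prefix-sum semantics (a sketch, not the ck_tile source):

```cpp
// Sketch of the assumed to_seqstarts semantics: exclusive prefix sum over
// per-batch sequence lengths, yielding group-mode start offsets.
#include <vector>

std::vector<int> to_seqstarts(const std::vector<int>& seqlens)
{
    std::vector<int> seqstarts{0};
    for(int len : seqlens)
        seqstarts.push_back(seqstarts.back() + len);
    return seqstarts; // seqstarts.size() == seqlens.size() + 1
}
```

With seqlens {3, 1, 4} this yields {0, 3, 4, 8}, so batch b spans [seqstarts[b], seqstarts[b+1]).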
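The f32->f8 conversion fix is a round-to-nearest question on the subnormal grid. A minimal sketch of the rule, assuming an e4m3-style format whose smallest positive subnormal is 2^-9; the helper is illustrative, not the actual converter in float8.hpp:

```cpp
// Round a small non-negative f32 onto the f8 subnormal grid (RNE).
// Values in (min_subnormal / 2, min_subnormal) must become 0b1, not 0.
#include <cmath>
#include <cstdint>
#include <cstdio>

constexpr float kF8MinSubnormal = 0x1p-9f; // assumed e4m3 minimum subnormal

uint8_t small_f32_to_f8_bits(float v) // v in [0, smallest normal)
{
    return static_cast<uint8_t>(std::nearbyint(v / kF8MinSubnormal));
}

int main()
{
    const float v = 0.6f * kF8MinSubnormal; // > min/2 but < min
    std::printf("0x%x\n", small_f32_to_f8_bits(v)); // prints 0x1, not 0x0
}
```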
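For the host-side dropout validation, the semantics being checked reduce to a keep-threshold test on a random byte plus rescaling. A sketch assuming a "keep when randval <= p_undrop_in_uint8_t" convention (names follow the runner; the Philox generation that has to match BlockDropout is omitted):

```cpp
// Dropout as validated on the host (sketch): keep an element when its
// random byte is at or below the threshold, then rescale by 1/(1 - p_drop).
#include <cstddef>
#include <cstdint>
#include <vector>

void apply_dropout(std::vector<float>& p,
                   const std::vector<uint8_t>& randval,
                   uint8_t p_undrop_in_uint8_t, // keep threshold, ~(1 - p_drop) * 255
                   float rp_undrop)             // 1 / (1 - p_drop)
{
    for(std::size_t i = 0; i < p.size(); ++i)
        p[i] = (randval[i] <= p_undrop_in_uint8_t) ? p[i] * rp_undrop : 0.f;
}
```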
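The seqlen_k_ptr padding fix shows up in the generated dispatch predicates (see the fmha_fwd_splitkv.py and fmha_pagedkv_prefill.py hunks below): once seqlen_k_ptr is set, the per-batch lengths are unknown at dispatch time, so only the padded variant is safe. A hypothetical standalone form of the new condition:

```cpp
// Hypothetical standalone form of the generated dispatch condition for the
// qr/qs pipelines; bn0 is the tile size along seqlen_k.
bool can_use_unpadded_seqlen_k(const int* seqlen_k_ptr, int seqlen_k, int bn0)
{
    // seqlen_k_ptr[b] may hold arbitrary per-batch values, so the unpadded
    // kernel is only valid when the pointer is absent and seqlen_k divides
    // evenly into bn0-sized tiles.
    return seqlen_k_ptr == nullptr && seqlen_k % bn0 == 0;
}
```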
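The deterministic-bwd OOB fix concerns how the dq_acc workspace is sized from an assumed kN0 (see the fmha_bwd_runner.hpp hunk below). A sketch of the sizing after the fix, pulled into a standalone function for illustration:

```cpp
// dq_acc sizing for deterministic bwd (sketch). If the assumed kN0 were
// larger than the smallest bn0 tile an instance may pick, nsplits would be
// underestimated and dq_acc_buf would be too small -- the OOB bug above.
#include <cstddef>

std::size_t dq_acc_elements(int max_seqlen_k, int shape_batch, int nhead,
                            int shape_seqlen_q, int hdim_q, bool deterministic)
{
    const int kN0     = 16; // kept <= the minimal bn0 across all bwd tiles
    const int nsplits = deterministic ? (max_seqlen_k + kN0 - 1) / kN0 : 1;
    return std::size_t(nsplits) * shape_batch * nhead * shape_seqlen_q * hdim_q;
}
```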
diff --git a/README.md b/README.md index 459e17d9a3..32688b6574 100644 --- a/README.md +++ b/README.md @@ -184,7 +184,7 @@ hours to 1-2 minutes. In order to invoke sccache, you need to run: then add the following flags to the cmake command line: ```bash - -DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache + -DCMAKE_HIP_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache ``` You may need to clean up the build folder and repeat the cmake and make steps in order to take diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 5f495c76d8..b1e2373657 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,7 +1,19 @@ +set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) +# Currently only gfx9 archs are supported by FMHA +list(FILTER INST_TARGETS INCLUDE REGEX "gfx9") +if(NOT INST_TARGETS) + message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + return() +endif() + # validate user-specified fmha_fwd API list set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill") set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") +if(BUILD_TESTING) + # Build instances of all APIs for tests + set(FMHA_FWD_ENABLE_APIS "all") +endif() if(FMHA_FWD_ENABLE_APIS STREQUAL "all") set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS}) endif() @@ -77,72 +89,100 @@ add_custom_command( DEPENDS ${CODE_GEN_SCRIPTS} ) -set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") -# not using add_example_executable() to add this target, since we don't want this to have -# to be included in "make all/install/check" -message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}") -add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) -target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) +set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") +set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") -set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") -# not using add_example_executable() to add this target, since we don't want this to have -# to be included in "make all/install/check" -message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}") -add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp) -target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) +message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}") +add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL) +target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS}) +set_source_files_properties(${FMHA_FWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +set_property(TARGET ${FMHA_FWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +message(DEBUG "adding instances ${FMHA_BWD_INSTANCES}") +add_library(${FMHA_BWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL) +target_include_directories(${FMHA_BWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${FMHA_BWD_INSTANCES} PRIVATE ${FMHA_BWD_GEN_BLOBS}) +set_source_files_properties(${FMHA_BWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +set_property(TARGET ${FMHA_BWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + 
+set(FMHA_FWD_PRIVATE_COMPILE_OPTIONS) +set(FMHA_BWD_PRIVATE_COMPILE_OPTIONS) +set(FMHA_FWD_INTERFACE_COMPILE_OPTIONS) +set(FMHA_BWD_INTERFACE_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +# ... because they are auto-generated +list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) +list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) + +# Allow comparing floating points directly in order to check sentinel values +list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) +list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) # NOTE: this is dangerous since will change the whole kernel to flush denormals # WIP with compiler team for an exp2 intrinsic..., then remove this if(NOT DEFINED FMHA_FWD_FAST_EXP2) - set(FMHA_FWD_FAST_EXP2 true) + set(FMHA_FWD_FAST_EXP2 ON) endif() -set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) -set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS) - -# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -# ... because they are auto-generated if(FMHA_FWD_FAST_EXP2) - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) + list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) else() - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) + list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=0) endif() -list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero) +list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero) -# conditionally enable call to the fwd_splitkv API in fmha_fwd example +# conditionally enable call to the fwd_splitkv API in fmha_fwd example and tests if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS) - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1) + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1) else() - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0) + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0) endif() -# conditionally enable call to the fwd_appendkv API in fmha_fwd example +# conditionally enable call to the fwd_appendkv API in fmha_fwd example and tests if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS) - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1) + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1) else() - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) endif() -# conditionally enable call to the pagedkv_prefill API in fmha_fwd example +# conditionally enable call to the pagedkv_prefill API in fmha_fwd example and tests if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS) - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1) + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1) else() - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0) + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0) endif() # conditionally specify the use of OCP_FP8 if(CK_USE_OCP_FP8) - list(APPEND 
EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) + list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -# Allow comparing floating points directly in order to check sentinel values -list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) -list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) +target_compile_options(${FMHA_FWD_INSTANCES} + PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS} + INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS}) +target_compile_options(${FMHA_BWD_INSTANCES} + PRIVATE ${FMHA_BWD_PRIVATE_COMPILE_OPTIONS} + INTERFACE ${FMHA_BWD_INTERFACE_COMPILE_OPTIONS}) -target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) -target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS}) +set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") +set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") + +message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}") +# not using add_example_executable() to add this target, since we don't want this to be included in +# "make all/install/check" +add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp) +target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES}) +target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + +message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}") +# not using add_example_executable() to add this target, since we don't want this to be included in +# "make all/install/check" +add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp) +target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES}) +target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) # add fmha_fwd_v3 example set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3") diff --git a/example/ck_tile/01_fmha/bias.hpp b/example/ck_tile/01_fmha/bias.hpp index f9dc656f63..c07232a13a 100644 --- a/example/ck_tile/01_fmha/bias.hpp +++ b/example/ck_tile/01_fmha/bias.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -63,31 +63,45 @@ struct bias_info static bias_info decode(std::string str) { bias_info info{bias_enum::no_bias, 0}; - if(str == "0" || str == "n") + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + if(t == "e" || t == "elementwise") + { + info.type = bias_enum::elementwise_bias; + info.rank_info = std::stoi(v); + if(info.rank_info < 0 || info.rank_info > 2) + throw std::invalid_argument("invalid bias rank: " + str); + } + else if(t == "a" || t == "alibi") + { + info.type = bias_enum::alibi; + info.rank_info = std::stoi(v); + if(info.rank_info < 0 || info.rank_info > 1) + throw std::invalid_argument("invalid bias rank: " + str); + } + else + { + throw std::invalid_argument("invalid bias value: " + str); + } + } + else if(str == "0" || str == "n") { info.type = bias_enum::no_bias; } - else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 || - str.compare(0, 11, "elementwise") == 0) + else if(str == "1" || str == "e" || str == "elementwise") { - info.type = bias_enum::elementwise_bias; - auto found_0 = str.find(':'); - if(found_0 != std::string::npos) - { - std::string e = str.substr(found_0 + 1); - info.rank_info = atoi(e.c_str()); - } + info.type = bias_enum::elementwise_bias; } - else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 || - str.compare(0, 5, "alibi") == 0) + else if(str == "2" || str == "a" || str == "alibi") { - info.type = bias_enum::alibi; - auto found_0 = str.find(':'); - if(found_0 != std::string::npos) - { - std::string e = str.substr(found_0 + 1); - info.rank_info = atoi(e.c_str()); - } + info.type = bias_enum::alibi; + } + else + { + throw std::invalid_argument("invalid bias value: " + str); } return info; } diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 1dd8f0e3c6..3b48b3d005 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation import copy @@ -347,8 +347,8 @@ class FmhaFwdSplitKVApiTrait: if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: - if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k % {self.bn0} == 0' + if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0' else: assert False @property diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index e468e82ed5..7b93e9654c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
# generate kernel instances to speed up compilation import copy @@ -189,8 +189,8 @@ class FmhaFwdApiTrait: if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' elif self.pipeline_tag in ['qr_pagedkv', 'qs']: - if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k % {self.bn0} == 0' + if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0' else: assert False @property diff --git a/example/ck_tile/01_fmha/example_fmha_bwd.cpp b/example/ck_tile/01_fmha/example_fmha_bwd.cpp new file mode 100644 index 0000000000..e0e1fba668 --- /dev/null +++ b/example/ck_tile/01_fmha/example_fmha_bwd.cpp @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "fmha_bwd.hpp" +#include "fmha_bwd_runner.hpp" + +#include + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "whether do CPU validation or not") + .insert("mode", "0", "kernel mode. 0:batch, 1:group") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "-1", + "num of head, for k/v, -1 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert("s", + "3328", + "seqlen_q. if group-mode, means the average value of seqlen_q\n" + "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n" + "also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch " + "(group mode)") + .insert("s_k", + "-1", + "seqlen_k, -1 means equal to s\n" + "also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch " + "(group mode)") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)") + .insert("iperm", + "1", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "1", "permute output") + .insert("bias", + "n", + "n or 0, no bias\n" + "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" + "a(libi) or 2, alibi with 1*h. a:1, b*h") + .insert("dbias", "0", "output bias gradient or not") + .insert("prec", "fp16", "data type. fp16 or bf16") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, window_size negative is " + "causal, positive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, window_size negative is " + "causal, positive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " + "now)") + .insert("kname", "0", "if set to 1 will print kernel name") + .insert("init", + "uf", + "init method:\n ui or 0 - uniform random int\n uf or 1 - uniform random float" + "\n tf or 2 - trig float") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 
0 for " + "non-deterministic seed") + .insert("p_drop", "0", "0~1 probability of dropout") + .insert("drop_seed", "1", "seed for dropout random number generator") + .insert("drop_offset", "0", "offset for dropout random number generator") + .insert( + "drop_prefs", + "0", + "whether dropout seed and offset values are present on GPU; 0 - host, 1 - device/GPU") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel") + .insert("deterministic", + "0", + "if set to 1 will use multi-buffer reduction strategy for dq, atomic operation " + "will not be used") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "fmha_bwd.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +auto run(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + mode_enum mode = static_cast(arg_parser.get_uint32("mode")); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + auto seqlen_qs = arg_parser.get_int_vec("s"); + auto seqlen_ks = arg_parser.get_int_vec("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + float scale = arg_parser.get_float("scale"); + std::string bias_str = arg_parser.get_str("bias"); + bool use_dbias = arg_parser.get_bool("dbias"); + float p_drop = arg_parser.get_float("p_drop"); + uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); + uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + bool drop_prefs = arg_parser.get_bool("drop_prefs"); + std::string mask_str = arg_parser.get_str("mask"); + bool deterministic = arg_parser.get_bool("deterministic"); + std::string init_method = arg_parser.get_str("init"); + uint32_t seed = arg_parser.get_uint32("seed"); + + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0), + arg_parser.get_int("warmup"), + arg_parser.get_int("repeat"), + arg_parser.get_str("timer") == std::string("gpu")}; + + auto json = arg_parser.get_int("json") == 1 + ? std::optional{arg_parser.get_str("jsonfile")} + : std::nullopt; + + return fmha_bwd_run(mode, + batch, + nhead, + nhead_k, + seqlen_qs, + seqlen_ks, + hdim_q, + hdim_v, + i_perm, + o_perm, + scale, + bias_str, + use_dbias, + p_drop, + drop_seed, + drop_offset, + drop_prefs, + mask_str, + deterministic, + init_method, + seed, + do_validation, + stream_config, + json); +} + +int main(int argc, char* argv[]) +{ + try + { + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) == bwd_result::success ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) == bwd_result::success ? 
0 : -2; + } + std::cerr << "Unsupported precision: " << data_type << std::endl; + return -1; + } + catch(const std::invalid_argument& e) + { + std::cerr << "Invalid argument: " << e.what() << std::endl; + return -1; + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << std::endl; + return -2; + } +} diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp new file mode 100644 index 0000000000..c3bbb7a558 --- /dev/null +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "fmha_fwd.hpp" +#include "fmha_fwd_runner.hpp" + +#include + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "0:no validation, 2:cpu validation, 2:gpu validation(experimental)") + .insert("mode", "0", "kernel mode. 0:batch, 1:group") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "-1", + "num of head, for k/v, -1 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert("s", + "3328", + "seqlen_q. if group-mode, means the average value of seqlen_q\n" + "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n" + "also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch " + "(group mode)") + .insert("s_k", + "-1", + "seqlen_k (including new key/value), -1 means equal to s\n" + "also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch " + "(group mode)") + .insert("s_knew", + "0", + "seqlen_k for new key/value, 0 means not to use this at all; " + "-1 to choose s_knew in [1, s] randomly.") + .insert("s_kpad", + "-1", + "seqlen_k stride between 2 batches, currently used in group-mode only\n" + "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n" + "along seqlen, instead of packed, same as xformer kv_padding,\n" + "must be greater than or equal to s_k") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("scale_s", + "0", + "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" + "note when squant=1, this value will be modified by range_q/k") + .insert("logits_soft_cap", "0", "attention logits soft capping value.") + .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.") + .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.") + .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") + .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.") + .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.") + .insert("squant", + "auto", + "if using static quantization fusion or not. auto: fp8 will default use squant, " + "other will not\n" + "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to " + "P and O.\n" + "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, " + "range_p, range_o") + .insert("iperm", + "1", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "1", "permute output") + .insert("bias", + "n", + "n or 0, no bias\n" + "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" + "a(libi) or 2, alibi with 1*h. a:1, b*h") + .insert("prec", "fp16", "data type. 
fp16/bf16/fp8/bf8") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, window_size negative is " + "causal, positive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, window_size negative is " + "causal, positive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " + "now)") + .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") + .insert("lse", "0", "0 not store lse, 1 store lse") + .insert("kname", "0", "if set to 1 will print kernel name") + .insert("init", + "uf", + "init method:\n ui or 0 - uniform random int\n ni - normalized random int" + "\n uf or 1 - uniform random float\n nf - normalized random float" + "\n tf or 2 - trig float\n uf:q or ufq or 3 - fp8 quantization") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") + .insert("p_drop", "0", "0~1 probability of dropout") + .insert("drop_seed", "1", "seed for dropout random number generator") + .insert("drop_offset", "0", "offset for dropout random number generator") + .insert( + "drop_prefs", + "0", + "whether dropout seed and offset values are present on GPU; 0 - host, 1 - device/GPU") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert( + "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all") + .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE") + .insert("num_splits", + "1", + "# of splits for key/value. 0 to determine actual number by heuristic") + .insert("page_block_size", "0", "paged-kvcache block size. 
0 means not use paged-kvcahe") + .insert("cache_batch_idx", "0", "whether to use index map to the kvcache") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "fmha_fwd.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +auto run(const ck_tile::ArgParser& arg_parser) +{ + int do_validation = arg_parser.get_int("v"); + mode_enum mode = static_cast(arg_parser.get_uint32("mode")); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + auto seqlen_qs = arg_parser.get_int_vec("s"); + auto seqlen_ks = arg_parser.get_int_vec("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew"); + auto seqlen_kpads = arg_parser.get_int_vec("s_kpad"); + ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + float scale_s = arg_parser.get_float("scale_s"); + float logits_soft_cap = arg_parser.get_float("logits_soft_cap"); + bool is_v_rowmajor = arg_parser.get_str("vlayout") == "r"; + bool lse = arg_parser.get_bool("lse"); + ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size"); + bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx"); + std::string bias_str = arg_parser.get_str("bias"); + float p_drop = arg_parser.get_float("p_drop"); + uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); + uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + bool drop_prefs = arg_parser.get_bool("drop_prefs"); + std::string mask_str = arg_parser.get_str("mask"); + float range_q = arg_parser.get_float("range_q"); + float range_k = arg_parser.get_float("range_k"); + float range_v = arg_parser.get_float("range_v"); + float range_p = arg_parser.get_float("range_p"); + float range_o = arg_parser.get_float("range_o"); + bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved"); + ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); + std::string init_method = arg_parser.get_str("init"); + uint32_t seed = arg_parser.get_uint32("seed"); + + bool squant = [&]() { + if(arg_parser.get_str("squant") == "auto") + return std::is_same_v; + else + return arg_parser.get_bool("squant"); + }(); + + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0), + arg_parser.get_int("warmup"), + arg_parser.get_int("repeat"), + arg_parser.get_str("timer") == std::string("gpu")}; + + auto json = arg_parser.get_int("json") == 1 + ? 
std::optional{arg_parser.get_str("jsonfile")} + : std::nullopt; + + return fmha_fwd_run(mode, + batch, + nhead, + nhead_k, + seqlen_qs, + seqlen_ks, + hdim_q, + hdim_v, + seqlen_knew, + seqlen_kpads, + rotary_dim, + i_perm, + o_perm, + scale_s, + logits_soft_cap, + is_v_rowmajor, + lse, + page_block_size, + use_cache_batch_idx, + bias_str, + p_drop, + drop_seed, + drop_offset, + drop_prefs, + mask_str, + range_q, + range_k, + range_v, + range_p, + range_o, + squant, + is_rotary_interleaved, + num_splits, + init_method, + seed, + do_validation, + stream_config, + json); +} + +int main(int argc, char* argv[]) +{ + try + { + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "fp8") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + std::cerr << "Unsupported precision: " << data_type << std::endl; + return -1; + } + catch(const std::invalid_argument& e) + { + std::cerr << "Invalid argument: " << e.what() << std::endl; + return -1; + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << std::endl; + return -2; + } +} diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp similarity index 76% rename from example/ck_tile/01_fmha/fmha_bwd.cpp rename to example/ck_tile/01_fmha/fmha_bwd_runner.hpp index cc2663e751..3a5b5b4603 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -#include "fmha_bwd.hpp" +#pragma once + #include "ck_tile/host.hpp" -#include "mask.hpp" +#include "fmha_bwd.hpp" #include "utils.hpp" #include "ck_tile/utility/json_dump.hpp" @@ -17,91 +18,13 @@ #include #include -template -std::ostream& operator<<(std::ostream& os, const std::vector& v) +enum class bwd_result { - using size_type = typename std::vector::size_type; - - os << "["; - for(size_type idx = 0; idx < v.size(); ++idx) - { - if(0 < idx) - { - os << ", "; - } - os << v[idx]; - } - return os << "]"; -} - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("v", "1", "weather do CPU validation or not") - .insert("mode", "0", "kernel mode. 0:batch, 1:group") - .insert("b", "2", "batch size") - .insert("h", "8", "num of head, for q") - .insert("h_k", - "-1", - "num of head, for k/v, -1 means equal to h\n" - "if not equal to h, then this is GQA/MQA case") - .insert("s", - "3328", - "seqlen_q. if group-mode, means the average value of seqlen_q\n" - "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") - .insert("s_k", "-1", "seqlen_k, -1 means equal to s") - .insert("d", "128", "head dim for q, k") - .insert("d_v", "-1", "head dim for v, -1 means equal to d") - .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)") - .insert("iperm", - "1", - "permute input\n" - "if true, will be b*h*s*d, else b*s*h*d") - .insert("operm", "1", "permute output") - .insert("bias", - "n", - "n or 0, no bias\n" - "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" - "a(libi) or 2, alibi with 1*h. 
a:1, b*h") - .insert("dbias", "0", "output bias gradient or not") - .insert("prec", "fp16", "data type. fp16 or bf16") - .insert("mask", - "0", - "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" - "'t', top-left causal mask, 'b', bottom-r causal mask\n" - "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" - "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" - "'xt:window_size', xformer style masking from top-left, window_size negative is " - "causal, positive is swa\n" - "'xb:window_size', xformer style masking from bottom-r, window_size negative is " - "causal, positive is swa\n" - "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " - "now)") - .insert("kname", "0", "if set to 1 will print kernel name") - .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float") - .insert("seed", - "11939", - "random seed used for initializing input tensors. 0 for " - "non-deterministic seed") - .insert("p_drop", "0", "0~1 probability of dropout") - .insert("drop_seed", "1", "seed for random number generator") - .insert("drop_offset", "0", "offset for random number generator") - .insert("drop_prefs", - "0", - "seed and offset values are present on GPU; 0 - host, 1 - device/GPU") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert("warmup", "5", "number of iterations before benchmark the kernel") - .insert("repeat", "20", "number of iterations to benchmark the kernel") - .insert("deterministic", - "0", - "if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion " - "will not be used") - .insert("json", "0", "0: No Json, 1: Dump Results in Json format") - .insert("jsonfile", "fmha_bwd.json", "json file name to dump results"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} + success, + failure, + invalid_args, + no_instance, +}; // different threshold for different dtype template @@ -125,57 +48,82 @@ auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) return ck_tile::make_tuple(rtol, atol); } +extern template float fmha_bwd<2>(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); + template -bool run(const ck_tile::ArgParser& arg_parser) +bwd_result fmha_bwd_run(mode_enum mode, + ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t nhead_k, + std::vector seqlen_qs, + std::vector seqlen_ks, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + bool i_perm, + bool o_perm, + float scale, + std::string bias_str, + bool use_dbias, + float p_drop, + uint64_t drop_seed, + uint64_t drop_offset, + bool drop_prefs, + std::string mask_str, + bool deterministic, + std::string init_method, + uint32_t seed, + int do_validation, + const ck_tile::stream_config& stream_config, + std::optional json = std::nullopt) { - std::string data_type = arg_parser.get_str("prec"); - int do_validation = arg_parser.get_int("v"); - auto mode = static_cast(arg_parser.get_uint32("mode")); - ck_tile::index_t batch = arg_parser.get_int("b"); - ck_tile::index_t nhead = arg_parser.get_int("h"); - ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + const std::string data_type = []() { + if constexpr(std::is_same_v) + return "fp16"; + else if constexpr(std::is_same_v) + return "bf16"; + else + static_assert(false); + }(); + if(nhead_k < 0) nhead_k = nhead; - if(nhead % nhead_k != 0) { std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; - return false; + 
return bwd_result::invalid_args; } - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - if(seqlen_k < 0) - seqlen_k = seqlen_q; - ck_tile::index_t hdim_q = arg_parser.get_int("d"); - ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + std::mt19937 random_engine(seed != 0 ? seed : std::random_device{}()); + auto next_seed = [&random_engine]() { return static_cast(random_engine()); }; + if(hdim_v < 0) hdim_v = hdim_q; - bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim - bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim - - float scale = arg_parser.get_float("scale"); if(scale == .0f) scale = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); - bias_info bias = bias_info::decode(arg_parser.get_str("bias")); - bool use_dbias = arg_parser.get_bool("dbias"); - float p_drop = arg_parser.get_float("p_drop"); - uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); - uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); - bool drop_prefs = arg_parser.get_bool("drop_prefs"); + bias_info bias = bias_info::decode(bias_str); if(use_dbias && bias.type != bias_enum::elementwise_bias) { std::cerr << "dbias only exists when bias type is elementwise" << std::endl; - return false; + return bwd_result::invalid_args; } + std::vector seqlen_kpads; + std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) = + generate_missing_seqlens(mode, batch, seqlen_qs, seqlen_ks, {}, 0, false, random_engine); + ck_tile::ignore = seqlen_kpads; +#if 0 + std::cout << "seqlen_qs: " << seqlen_qs << std::endl; + std::cout << "seqlen_ks: " << seqlen_ks << std::endl; +#endif + + mask_info mask = mask_info::decode(mask_str, seqlen_qs[0], seqlen_ks[0]); if(p_drop < 0.0f || p_drop > 1.0f) { std::cerr << "The value of p_drop should be 0~1" << std::endl; - return false; + return bwd_result::invalid_args; } float p_undrop = 1.0 - p_drop; uint8_t p_undrop_in_uint8_t = @@ -188,29 +136,8 @@ bool run(const ck_tile::ArgParser& arg_parser) s_randval = true; } - mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); - - int init_method = arg_parser.get_int("init"); - std::optional seed = arg_parser.get_uint32("seed"); - if(*seed == 0) - { - seed.reset(); - } - - int stream_warmup = arg_parser.get_int("warmup"); - int stream_repeat = arg_parser.get_int("repeat"); - bool kname = arg_parser.get_bool("kname"); - bool deterministic = arg_parser.get_bool("deterministic"); - - ck_tile::stream_config stream_config{nullptr, - true, - /* log_level = */ (kname ? 1 : 0), - stream_warmup, - stream_repeat, - arg_parser.get_str("timer") == std::string("gpu")}; - - const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); - const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + const auto seqstart_q_host = to_seqstarts(seqlen_qs); + const auto seqstart_k_host = to_seqstarts(seqlen_ks); using TypeConfig = FmhaBwdTypeConfig; @@ -283,10 +210,13 @@ bool run(const ck_tile::ArgParser& arg_parser) // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); const ck_tile::index_t shape_seqlen_q = - (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); + (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); const ck_tile::index_t shape_seqlen_k = - (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); - const ck_tile::index_t kN0 = (hdim_q <= 128) ? 
128 : 64; + (mode == mode_enum::batch ? seqlen_ks[0] : seqstart_k_host.back()); + // Keep it equal to or smaller than minimal bn0 of all tiles in fmha_bwd.py + // TODO: add API for requesting kN0/nsplits/workspace_size? It is not safe to rely on internal + // implementation details in client code. + const ck_tile::index_t kN0 = 16; const ck_tile::index_t nsplits = deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1; @@ -331,23 +261,25 @@ bool run(const ck_tile::ArgParser& arg_parser) ? std::array{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q} : std::array{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q}); - if(init_method == 0) + if(init_method == "ui" || init_method == "0") { - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(do_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, next_seed()}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, next_seed()}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, next_seed()}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, next_seed()}( + bias_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, next_seed()}( + do_host); } - else if(init_method == 1) + else if(init_method == "uf" || init_method == "1") { - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(do_host); + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(bias_host); + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(do_host); } - else if(init_method == 2) + else if(init_method == "tf" || init_method == "2") { ck_tile::FillTrigValue{}(q_host); ck_tile::FillTrigValue{}(k_host); @@ -355,6 +287,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillTrigValue{}(bias_host); ck_tile::FillTrigValue{}(do_host); } + else + { + std::cerr << "Unknown value for init argument: " << init_method << std::endl; + return bwd_result::invalid_args; + } + if(bias.type == bias_enum::alibi) { auto slopes = ck_tile::get_alibi_slopes(nhead); @@ -415,22 +353,19 @@ bool run(const ck_tile::ArgParser& arg_parser) else return layout_str(iperm_) + std::string("-") + layout_str(operm_); }; // clang-format on - const std::string prec = arg_parser.get_str("prec"); - std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch - << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k - << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias - << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval - << ", deterministic:" << deterministic << ", mask:" << mask << std::flush; + const std::size_t workspace_size_in_megabytes = + ck_tile::integer_divide_ceil(dq_acc_host.get_element_space_size_in_bytes(), 1024 * 1024); - std::size_t workspace_size = - 
dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024); - - if(deterministic == 1) - { - std::cout << "\nDeterministic mode ON: " << workspace_size - << " MByte memory workspace allocated" << std::endl; - } + std::cout << "[" << data_type << "|" << mode << "|" << io_layout(i_perm, o_perm) + << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] + << "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale + << ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop + << ", s_randval:" << s_randval << ", deterministic:" << deterministic + << (deterministic ? std::string(", workspace:") + + std::to_string(workspace_size_in_megabytes) + "MiB" + : "") + << ", mask:" << mask << std::flush; auto fmha_traits = fmha_bwd_traits{hdim_q, hdim_v, @@ -443,7 +378,6 @@ bool run(const ck_tile::ArgParser& arg_parser) s_randval, deterministic}; auto fmha_args = [&]() { - assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// seqlen_k] in this example, hence both the 'batch_stride_bias' & /// 'nhead_stride_bias' are 0. @@ -572,20 +506,21 @@ bool run(const ck_tile::ArgParser& arg_parser) drop_seed_offset}; }(); - float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config); + const float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config); if(ave_time < 0) { std::cout << ", not supported yet" << std::flush << std::endl; - return false; + return bwd_result::no_instance; } - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " - << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec - << " GB/s" << std::flush; + const float tflops = static_cast(flop) / 1.E9 / ave_time; + const float gb_per_sec = num_byte / 1.E6 / ave_time; + if(stream_config.time_kernel_) + { + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) + << gb_per_sec << " GB/s" << std::flush; + } bool pass = true; if(!do_validation) @@ -635,17 +570,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t nr = nhead / nhead_k; // clang-format off - // permute - if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); - else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); - if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); - else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); - // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); - // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] - 
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); // clang-format on // reference @@ -760,18 +695,40 @@ bool run(const ck_tile::ArgParser& arg_parser) real_seqlen_k, mask.type == mask_enum::mask_top_left)); } + const ck_tile::HostTensor masked_s_host_ref = s_host_ref; ck_tile::reference_batched_softmax( s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref); if(p_drop > 0) { p_dropped_hp_host_ref = p_hp_host_ref; - randval_host_ref.ForEach([&](auto& self, auto idx) { - self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); - }); + ck_tile::reference_batched_dropout_randval( + randval_host_ref, wb, drop_seed, drop_offset); ck_tile::reference_batched_dropout( p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType(); + + ck_tile::HostTensor randval_host_result( + {nhead, real_seqlen_q, real_seqlen_k}); + randval_host_result.ForEach([&](auto& self, const auto& idx) { + self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); + }); + masked_s_host_ref.ForEach([&](const auto& self, const auto& idx) { + // Ignore all masked values in validation check + if(std::isinf(self(idx))) + { + randval_host_ref(idx) = 0; + randval_host_result(idx) = 0; + } + }); + bool cur_pass = ck_tile::check_err(randval_host_result, + randval_host_ref, + "DROPOUT RANDVAL Error: Incorrect results!"); + pass &= cur_pass; + if(!cur_pass) + { + break; + } } else { @@ -783,11 +740,11 @@ bool run(const ck_tile::ArgParser& arg_parser) p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n // clang-format off - // permute - if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); }); - else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); }); + // permute + if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); }); + else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); }); - lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); }); + lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); }); // clang-format on q_host_refs.push_back(q_host_ref); @@ -816,8 +773,7 @@ bool run(const ck_tile::ArgParser& arg_parser) dbias_buf.SetZero(); dq_acc_buf.SetZero(); - ck_tile::stream_config stream_config_v{ - nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; + ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; fmha_bwd(fmha_traits, fmha_args, stream_config_v); dq_buf.FromDevice(dq_host.data()); @@ -855,8 +811,8 @@ bool run(const ck_tile::ArgParser& arg_parser) {nhead, real_seqlen_k, hdim_v}); // dv_g_n_o // clang-format off - if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); }); - else do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[1] + query_offset, 
i[0], i[2]); }); + if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); }); + else do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[1] + query_offset, i[0], i[2]); }); // clang-format on // dP = dO@V x Z w/ dropout @@ -934,21 +890,21 @@ bool run(const ck_tile::ArgParser& arg_parser) {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n // clang-format off - // permute - if(i_perm) dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[0], idx[1] + query_offset, idx[2]); }); - else dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + // permute + if(i_perm) dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[1] + query_offset, idx[0], idx[2]); }); - if(i_perm) dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[0], idx[1] + key_offset, idx[2]); }); - else dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[1] + key_offset, idx[0], idx[2]); }); + if(i_perm) dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[0], idx[1] + key_offset, idx[2]); }); + else dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[1] + key_offset, idx[0], idx[2]); }); - if(i_perm) dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[0], idx[1] + key_offset, idx[2]); }); - else dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[1] + key_offset, idx[0], idx[2]); }); + if(i_perm) dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[0], idx[1] + key_offset, idx[2]); }); + else dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[1] + key_offset, idx[0], idx[2]); }); - if(use_dbias) - { - if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); }); - else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); }); - } + if(use_dbias) + { + if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + } // clang-format on auto [rtol, atol] = get_elimit(hdim_q, hdim_v); @@ -994,10 +950,10 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } - if(arg_parser.get_int("json") == 1) + if(json) { dump_fmha_bwd_json_results( - arg_parser.get_str("jsonfile"), + *json, data_type, mode == mode_enum::batch ? "batch" : "group", i_perm ? 
"true" : "false", @@ -1005,8 +961,8 @@ bool run(const ck_tile::ArgParser& arg_parser) batch, nhead, nhead_k, - seqlen_q, - seqlen_k, + seqlen_qs[0], + seqlen_ks[0], hdim_q, hdim_v, scale, @@ -1027,30 +983,12 @@ bool run(const ck_tile::ArgParser& arg_parser) : "mask_generic"))), mask.left, mask.right, - workspace_size, + workspace_size_in_megabytes, pass, ave_time, tflops, gb_per_sec); } - return pass; -} - -int main(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - - const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") - { - return run(arg_parser) ? 0 : -2; - } - else if(data_type == "bf16") - { - return run(arg_parser) ? 0 : -2; - } - - return -3; + + return pass ? bwd_result::success : bwd_result::failure; } diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp similarity index 72% rename from example/ck_tile/01_fmha/fmha_fwd.cpp rename to example/ck_tile/01_fmha/fmha_fwd_runner.hpp index dd7c444557..397245ab32 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -1,11 +1,11 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -#include "fmha_fwd.hpp" +#pragma once + #include "ck_tile/host.hpp" #include "ck_tile/ref/naive_attention.hpp" -#include "mask.hpp" -#include "rotary.hpp" +#include "fmha_fwd.hpp" #include "utils.hpp" #include "ck_tile/utility/json_dump.hpp" @@ -24,128 +24,13 @@ #error "we should enable fmha_fwd_splitkv() api in order to cooperate with fmha_fwd_appendkv()" #endif -template -std::ostream& operator<<(std::ostream& os, const std::vector& v) +enum class fwd_result { - using size_type = typename std::vector::size_type; - - os << "["; - for(size_type idx = 0; idx < v.size(); ++idx) - { - if(0 < idx) - { - os << ", "; - } - os << v[idx]; - } - return os << "]"; -} - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("v", "1", "0:no validation, 2:cpu validation, 2:gpu validation(experimental)") - .insert("mode", "0", "kernel mode. 0:batch, 1:group") - .insert("b", "2", "batch size") - .insert("h", "8", "num of head, for q") - .insert("h_k", - "-1", - "num of head, for k/v, -1 means equal to h\n" - "if not equal to h, then this is GQA/MQA case") - .insert( - "s", - "3328", - "seqlen_q. if group-mode, means the average value of seqlen_q\n" - "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n" - "also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)") - .insert("s_k", "-1", "seqlen_k (including new key/value), -1 means equal to s") - .insert("s_knew", - "0", - "seqlen_k for new key/value, 0 means not to use this at all; " - "-1 to choose s_knew in [1, s] randomly.") - .insert("s_kpad", - "-1", - "seqlen_k stride between 2 batches, currently used in group-mode only\n" - "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n" - "along seqlen, instead of packed. same as xformer kv_padding") - .insert("d", "128", "head dim for q, k") - .insert("d_v", "-1", "head dim for v, -1 means equal to d") - .insert("scale_s", - "0", - "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" - "note when squant=1, this value will be modified by range_q/k") - .insert("logits_soft_cap", "0", "attention logits soft capping value.") - .insert("range_q", "16", "per-tensor quantization range of q. 
used if squant=1.") - .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.") - .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") - .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.") - .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.") - .insert("squant", - "auto", - "if using static quantization fusion or not. auto: fp8 will default use squant, " - "other will not\n" - "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to " - "P and O.\n" - "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, " - "range_p, range_o") - .insert("iperm", - "1", - "permute input\n" - "if true, will be b*h*s*d, else b*s*h*d") - .insert("operm", "1", "permute output") - .insert("bias", - "n", - "n or 0, no bias\n" - "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" - "a(libi) or 2, alibi with 1*h. a:1, b*h") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") - .insert("mask", - "0", - "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" - "'t', top-left causal mask, 'b', bottom-r causal mask\n" - "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" - "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" - "'xt:window_size', xformer style masking from top-left, window_size negative is " - "causal, positive is swa\n" - "'xb:window_size', xformer style masking from bottom-r, window_size negative is " - "causal, positive is swa\n" - "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " - "now)") - .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") - .insert("lse", "0", "0 not store lse, 1 store lse") - .insert("kname", "0", "if set to 1 will print kernel name") - .insert("init", - "uf", - "init method. ui, uniform random int, ni, normalized random int\n" - "uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, " - "quantization") - .insert("seed", - "11939", - "random seed used for initializing input tensors. 0 for " - "non-deterministic seed") - .insert("p_drop", "0", "0~1 probability of dropout") - .insert("drop_seed", "1", "seed for random number generator") - .insert("drop_offset", "0", "offset for random number generator") - .insert("drop_prefs", - "0", - "seed and offset values are present on GPU; 0 - host, 1 - device/GPU") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert( - "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all") - .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE") - .insert("num_splits", - "1", - "# of splits for key/value. 0 to determine actual number by heuristic") - .insert("page_block_size", "0", "paged-kvcache block size. 
0 means not use paged-kvcahe") - .insert("cache_batch_idx", "0", "whether to use index map to the kvcache") - .insert("warmup", "5", "number of iterations before benchmark the kernel") - .insert("repeat", "20", "number of iterations to benchmark the kernel") - .insert("json", "0", "0: No Json, 1: Dump Results in Json format") - .insert("jsonfile", "fmha_fwd.json", "json file name to dump results"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} + success, + failure, + invalid_args, + no_instance, +}; // different threshold for different dtype template @@ -247,35 +132,72 @@ int override_num_splits_if_necessary( } template -bool run(const ck_tile::ArgParser& arg_parser) +fwd_result fmha_fwd_run(mode_enum mode, + ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t nhead_k, + std::vector seqlen_qs, + std::vector seqlen_ks, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t seqlen_knew, + std::vector seqlen_kpads, + ck_tile::index_t rotary_dim, + bool i_perm, + bool o_perm, + float scale_s, + float logits_soft_cap, + bool is_v_rowmajor, + bool lse, + ck_tile::index_t page_block_size, + bool use_cache_batch_idx, + std::string bias_str, + float p_drop, + uint64_t drop_seed, + uint64_t drop_offset, + bool drop_prefs, + std::string mask_str, + float range_q, + float range_k, + float range_v, + float range_p, + float range_o, + bool squant, + bool is_rotary_interleaved, + ck_tile::index_t num_splits, + std::string init_method, + uint32_t seed, + int do_validation, + const ck_tile::stream_config& stream_config, + std::optional json = std::nullopt) { - std::string data_type = arg_parser.get_str("prec"); - int do_validation = arg_parser.get_int("v"); - auto mode = static_cast(arg_parser.get_uint32("mode")); - ck_tile::index_t batch = arg_parser.get_int("b"); - ck_tile::index_t nhead = arg_parser.get_int("h"); - ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + const std::string data_type = []() { + if constexpr(std::is_same_v) + return "fp16"; + else if constexpr(std::is_same_v) + return "bf16"; + else if constexpr(std::is_same_v) + return "fp8"; + else if constexpr(std::is_same_v) + return "bf8"; + else + static_assert(false); + }(); + if(nhead_k < 0) nhead_k = nhead; - if(nhead % nhead_k != 0) { std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; - return false; + return fwd_result::invalid_args; } - std::optional seed = arg_parser.get_uint32("seed"); - if(*seed == 0) - { - seed.reset(); - } + std::mt19937 random_engine(seed != 0 ? 
seed : std::random_device{}()); + auto next_seed = [&random_engine]() { return static_cast(random_engine()); }; - ck_tile::index_t hdim_q = arg_parser.get_int("d"); - ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); if(hdim_v < 0) hdim_v = hdim_q; - ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew"); #if !CK_TILE_FMHA_FWD_APPENDKV_API if(seqlen_knew != 0) { @@ -286,17 +208,16 @@ bool run(const ck_tile::ArgParser& arg_parser) #endif if(seqlen_knew < 0) { - seqlen_knew = randint(1, arg_parser.get_int("s"), seed); + seqlen_knew = randint(1, seqlen_qs[0], random_engine); } - ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); if constexpr(!(std::is_same_v || std::is_same_v)) { if(0 < rotary_dim) { std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl; - return false; + return fwd_result::invalid_args; } } #if !CK_TILE_FMHA_FWD_APPENDKV_API @@ -317,15 +238,14 @@ bool run(const ck_tile::ArgParser& arg_parser) if(!(rotary_dim <= hdim_q)) { std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl; - return false; + return fwd_result::invalid_args; } else if(!(rotary_dim % 16 == 0)) { std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl; - return false; + return fwd_result::invalid_args; } - ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size"); #if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \ CK_TILE_FMHA_FWD_PAGEDKV_API)) if(0 < page_block_size) @@ -339,10 +259,9 @@ bool run(const ck_tile::ArgParser& arg_parser) { std::cerr << "only paged-kvcache block size divisible by 128 are currently supported" << std::endl; - return false; + return fwd_result::invalid_args; } - bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx"); #if !(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API) if(use_cache_batch_idx) { @@ -371,14 +290,23 @@ bool run(const ck_tile::ArgParser& arg_parser) #endif const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); - auto [seqlen_qs, seqlen_ks, seqlen_kpads] = - decode_seqlen(mode, - batch, - arg_parser.get_str("s"), - arg_parser.get_str("s_k"), - arg_parser.get_str("s_kpad"), - /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0, - need_append_kvcache); + std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) = + generate_missing_seqlens(mode, + batch, + seqlen_qs, + seqlen_ks, + seqlen_kpads, + /*seqlen_k_min=*/0 < seqlen_knew ? 
seqlen_knew : 0, + need_append_kvcache, + random_engine); + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + if(seqlen_kpads[wb] > 0 && seqlen_kpads[wb] < seqlen_ks[wb]) + { + std::cerr << "kpad must be greater than or equal to seqlen for k" << std::endl; + return fwd_result::invalid_args; + } + } // compute kvcache seqlen_k (before appending knew/vnew) auto cache_seqlen_ks = seqlen_ks; std::transform(cache_seqlen_ks.begin(), @@ -387,64 +315,32 @@ bool run(const ck_tile::ArgParser& arg_parser) [&](auto seqlen_k) { return seqlen_k - seqlen_knew; }); #if 0 - // clang-format off - std::cout << "seqlen_qs:"; for(auto xx : seqlen_qs) { std::cout << xx << ","; } std::cout << std::endl; - std::cout << "seqlen_ks:"; for(auto xx : seqlen_ks) { std::cout << xx << ","; } std::cout << std::endl; - std::cout << "seqlen_kpads:"; for(auto xx : seqlen_kpads) { std::cout << xx << ","; } std::cout << std::endl; - // clang-format on + std::cout << "seqlen_qs: " << seqlen_qs << std::endl; + std::cout << "seqlen_ks: " << seqlen_ks << std::endl; + std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl; + std::cout << "cache_seqlen_ks: " << cache_seqlen_ks << std::endl; #endif - bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim - bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim - - float scale_s = arg_parser.get_float("scale_s"); if(scale_s == .0f) scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? - const float logits_soft_cap = arg_parser.get_float("logits_soft_cap"); + bias_info bias = bias_info::decode(bias_str); - std::string squant_str = arg_parser.get_str("squant"); - bool squant = [&]() { - if(squant_str == "auto") - { - if(data_type == "fp8") - return true; - else - return false; - } - else - return atoi(squant_str.c_str()) != 0 ? true : false; - }(); - - std::string vlayout = arg_parser.get_str("vlayout"); - bool lse = arg_parser.get_bool("lse"); - - bias_info bias = bias_info::decode(arg_parser.get_str("bias")); - mask_info mask = mask_info::decode( - arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore - - float p_drop = arg_parser.get_float("p_drop"); - uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); - uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); - bool drop_prefs = arg_parser.get_bool("drop_prefs"); + mask_info mask = + mask_info::decode(mask_str, seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore if(p_drop < 0.0f || p_drop > 1.0f) { std::cerr << "The value of p_drop should be 0~1" << std::endl; - return false; + return fwd_result::invalid_args; } bool s_randval = false; - if(p_drop > 0.0f && do_validation != 0) + if(p_drop > 0.0f && do_validation) { s_randval = true; } - std::string init_method = arg_parser.get_str("init"); - - const bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved"); - - ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); #if !CK_TILE_FMHA_FWD_SPLITKV_API if(num_splits != 1) { @@ -453,17 +349,6 @@ bool run(const ck_tile::ArgParser& arg_parser) } #endif - int stream_warmup = arg_parser.get_int("warmup"); - int stream_repeat = arg_parser.get_int("repeat"); - bool kname = arg_parser.get_bool("kname"); - - ck_tile::stream_config stream_config{nullptr, - true, - /* log_level = */ (kname ? 
1 : 0), - stream_warmup, - stream_repeat, - arg_parser.get_str("timer") == std::string("gpu")}; - const auto seqstart_q_host = to_seqstarts(seqlen_qs); const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); @@ -482,12 +367,6 @@ bool run(const ck_tile::ArgParser& arg_parser) using OaccDataType = typename TypeConfig::OaccDataType; using ODataType = typename TypeConfig::ODataType; - float range_q = arg_parser.get_float("range_q"); - float range_k = arg_parser.get_float("range_k"); - float range_v = arg_parser.get_float("range_v"); - float range_p = arg_parser.get_float("range_p"); - float range_o = arg_parser.get_float("range_o"); - float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); @@ -549,12 +428,12 @@ bool run(const ck_tile::ArgParser& arg_parser) if(128 < num_splits) { std::cerr << "num_splits greater than 128 is not supported" << std::endl; - return false; + return fwd_result::invalid_args; } #if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(0 < p_drop && (1 < num_splits || use_kvcache)) { - std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option" + std::cerr << "dropout is not supported by split-kv kernels. ignoring the 'p_drop' option" << std::endl; p_drop = 0.0f; } @@ -571,8 +450,6 @@ bool run(const ck_tile::ArgParser& arg_parser) return std::array{b, s, h, d}; }; - bool is_v_rowmajor = vlayout == std::string("r"); - // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); const ck_tile::index_t shape_seqlen_q = @@ -617,7 +494,7 @@ bool run(const ck_tile::ArgParser& arg_parser) : std::array{1, 1}); auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin( - std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed); + std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, next_seed()); ck_tile::HostTensor lse_acc_host( 1 < num_splits || use_kvcache @@ -654,39 +531,41 @@ bool run(const ck_tile::ArgParser& arg_parser) if(init_method == "ui" || init_method == "0") { - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(k_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(v_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(knew_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(vnew_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}( + bias_host); } else if(init_method == "ni") { - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(q_host); - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(k_host); - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(v_host); - 
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
-        ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
+        ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
+        ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, next_seed()}(k_host);
+        ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, next_seed()}(knew_host);
+        ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, next_seed()}(v_host);
+        ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, next_seed()}(vnew_host);
+        ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, next_seed()}(
+            bias_host);
     }
     else if(init_method == "uf" || init_method == "1")
     {
-        ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
-        ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
-        ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
-        ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
-        ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
-        ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
+        ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, next_seed()}(q_host);
+        ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, next_seed()}(k_host);
+        ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, next_seed()}(knew_host);
+        ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, next_seed()}(v_host);
+        ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, next_seed()}(vnew_host);
+        ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, next_seed()}(bias_host);
     }
     else if(init_method == "nf")
     {
-        ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host);
-        ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(k_host);
-        ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(knew_host);
-        ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host);
-        ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(vnew_host);
-        ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
+        ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, next_seed()}(q_host);
+        ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, next_seed()}(k_host);
+        ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, next_seed()}(knew_host);
+        ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, next_seed()}(v_host);
+        ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, next_seed()}(vnew_host);
+        ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, next_seed()}(bias_host);
     }
     else if(init_method == "tf" || init_method == "2")
     {
@@ -697,20 +576,32 @@ bool run(const ck_tile::ArgParser& arg_parser)
         ck_tile::FillTrigValue<VDataType>{}(vnew_host);
         ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
     }
-    else if(init_method == "ufq" || init_method == "uf:q" ||
-            init_method == "3") // suitable for fp8 quantization
+    else if(init_method == "ufq" || init_method == "uf:q" || init_method == "3")
     {
-        ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, seed}(q_host);
-        ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(k_host);
-        ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(knew_host);
-        ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(v_host);
-        ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(vnew_host);
+        // suitable for fp8 quantization
+        if(!squant)
+        {
+            std::cerr << "init method " << init_method << " cannot be used without quantization"
+                      << std::endl;
+            return fwd_result::invalid_args;
+        }
+        ck_tile::FillUniformDistribution<QDataType>{0.f, q_dtype_max, next_seed()}(q_host);
+        ck_tile::FillUniformDistribution<KDataType>{0.f, k_dtype_max, next_seed()}(k_host);
+        ck_tile::FillUniformDistribution<KDataType>{0.f, k_dtype_max, next_seed()}(knew_host);
+        ck_tile::FillUniformDistribution<VDataType>{0.f, v_dtype_max, next_seed()}(v_host);
+ ck_tile::FillUniformDistribution{0.f, v_dtype_max, next_seed()}(vnew_host); // bias_fp8 = qscale_bias * bias_fp32 float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k); - // Assume bias is in [-1.f, 1.f] in original fp32 - ck_tile::FillUniformDistribution{-qscale_bias, qscale_bias, seed}(bias_host); + // Assume bias is in [0.f, 1.f] in original fp32 + ck_tile::FillUniformDistribution{0.f, qscale_bias, next_seed()}(bias_host); } + else + { + std::cerr << "Unknown value for init argument: " << init_method << std::endl; + return fwd_result::invalid_args; + } + if(bias.type == bias_enum::alibi) { auto slopes = ck_tile::get_alibi_slopes(nhead); @@ -729,8 +620,8 @@ bool run(const ck_tile::ArgParser& arg_parser) } } } - iota_shuffle(block_table_host.begin(), block_table_host.end(), 0); - iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0); + iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); + iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); @@ -790,15 +681,15 @@ bool run(const ck_tile::ArgParser& arg_parser) else return layout_str(iperm_) + std::string("-") + layout_str(operm_); }; // clang-format on - const std::string prec = arg_parser.get_str("prec"); - std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch - << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] << "/" << seqlen_ks[0] + std::cout << "[" << data_type << "|" << mode << "|" << io_layout(i_perm, o_perm) + << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] + << "/" << seqlen_ks[0] << (seqlen_kpads[0] < 0 ? "" : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant - << ", mask:" << mask << ", v:" << vlayout; + << ", mask:" << mask << ", v:" << (is_v_rowmajor ? "r" : "c"); #if CK_TILE_FMHA_FWD_APPENDKV_API if(0 < rotary_dim) { @@ -850,13 +741,12 @@ bool run(const ck_tile::ArgParser& arg_parser) else if constexpr(std::is_same_v>) { - traits.use_pagedkv = use_kvcache; + traits.use_pagedkv = (0 < page_block_size); } } }; const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) { - assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// seqlen_k] in this example, hence both the 'batch_stride_bias' & /// 'nhead_stride_bias' are 0. 
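The next hunk reorders the forward kernel dispatch: the paged-KV kernel is now tried first and falls back to split-KV when no instance matches, instead of split-KV shadowing it. Below is a minimal standalone sketch of the resulting selection order, with simplifications: the CK_TILE_FMHA_FWD_*_API guards are dropped, the std::function stand-ins replace the real fmha_fwd_pagedkv/fmha_fwd_splitkv/fmha_fwd calls, and the only convention kept from the example is that a negative average time means "no instance".

#include <functional>

// Illustrative sketch, not the example's real API: the selection order for
// the forward kernels after the fix. A negative run time means "no instance".
inline float dispatch_fwd_kernels(int num_splits,
                                  bool use_kvcache,
                                  const std::function<float()>& run_pagedkv,
                                  const std::function<float()>& run_splitkv,
                                  const std::function<float()>& run_fwd)
{
    // Paged-KV kernel is tried first, but only when the work is not split.
    if(num_splits == 1 && use_kvcache)
    {
        const float ave_time = run_pagedkv();
        if(0.0f <= ave_time)
            return ave_time; // instance found
        // no matching instance: fall back to the split-kv path below
    }
    // Split-KV kernel covers explicit splits and any remaining kv-cache work.
    if(1 < num_splits || use_kvcache)
        return run_splitkv();
    // Everything else goes to the regular forward kernel.
    return run_fwd();
}

Trying paged-KV only for unsplit work, with a split-KV fallback, addresses both failure modes this commit fixes: split work no longer reaches the regular kernel, and AppendKV cases still run when no paged-KV instance exists.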
@@ -1089,10 +979,35 @@ bool run(const ck_tile::ArgParser& arg_parser) #endif return 0.0f; }(); + if(appendkv_ave_time < 0.0f) + { + std::cout << ", not supported yet" << std::flush << std::endl; + return fwd_result::no_instance; + } const float fwd_ave_time = [&] { +#if CK_TILE_FMHA_FWD_PAGEDKV_API + if(1 == num_splits && use_kvcache) + { + fmha_fwd_pagedkv_traits fmha_pagedkv_traits; + init_traits(fmha_pagedkv_traits); + + fmha_fwd_pagedkv_args fmha_pagedkv_args; + init_args(fmha_pagedkv_args); + + const float ave_time = + fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config); #if CK_TILE_FMHA_FWD_SPLITKV_API - if(1 < num_splits && use_kvcache) + // If there is no instance for these args, fallback to fmha_fwd_splitkv + if(ave_time >= 0.0f) + return ave_time; +#else + return ave_time; +#endif + } +#endif // CK_TILE_FMHA_FWD_PAGEDKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(1 < num_splits || use_kvcache) { fmha_fwd_splitkv_traits fmha_splitkv_traits; init_traits(fmha_splitkv_traits); @@ -1102,19 +1017,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); } -#endif -#if CK_TILE_FMHA_FWD_PAGEDKV_API - if(use_kvcache) - { - fmha_fwd_pagedkv_traits fmha_pagedkv_traits; - init_traits(fmha_pagedkv_traits); - - fmha_fwd_pagedkv_args fmha_pagedkv_args; - init_args(fmha_pagedkv_args); - - return fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config); - } -#endif +#endif // CK_TILE_FMHA_FWD_SPLITKV_API fmha_fwd_traits fmha_traits; init_traits(fmha_traits); @@ -1123,22 +1026,21 @@ bool run(const ck_tile::ArgParser& arg_parser) return fmha_fwd(fmha_traits, fmha_args, stream_config); }(); - - if(appendkv_ave_time < 0.0f || fwd_ave_time < 0.0f) + if(fwd_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; - return false; + return fwd_result::no_instance; } - const float ave_time = (appendkv_ave_time + fwd_ave_time); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " - << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec - << " GB/s" << std::flush << std::endl; + const float ave_time = appendkv_ave_time + fwd_ave_time; + const float tflops = static_cast(flop) / 1.E9 / ave_time; + const float gb_per_sec = num_byte / 1.E6 / ave_time; + if(stream_config.time_kernel_) + { + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) + << gb_per_sec << " GB/s" << std::flush; + } bool pass = true; if(do_validation == 0) @@ -1197,20 +1099,21 @@ bool run(const ck_tile::ArgParser& arg_parser) } else { - o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); randval_buf.FromDevice(randval_host.data()); + constexpr bool supports_squant = std::is_same_v; + auto p_compute_element_func = [&]() { - if constexpr(std::is_same_v) + if constexpr(supports_squant) return ck_tile::scales{scale_p}; else return ck_tile::identity{}; }(); auto oacc_element_func = [&]() { - if constexpr(std::is_same_v) + if constexpr(supports_squant) return ck_tile::composes(ck_tile::saturates{}, ck_tile::scales{scale_o}); else @@ -1252,140 +1155,146 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t nr = nhead / nhead_k; // clang-format off - // permute - if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { 
self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); }); - else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); }); + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); }); + // clang-format on #if CK_TILE_FMHA_FWD_APPENDKV_API - // optionally apply RoPE to the q_host_ref - if(0 < rotary_dim) - { - decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); - - auto [rotary_cos_slice, rotary_sin_slice] = - slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q); - - ck_tile::reference_batched_rotary_position_embedding( - q_host_ref, rotary_cos_slice, rotary_sin_slice, is_rotary_interleaved, q_host_ref_ro, - /*use_1_row_sin_cos=*/mask.type == mask_enum::no_mask); - - q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); - } -#endif -#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API - if(0 < page_block_size) { - if(i_perm) { - k_host_ref.ForEach([&](auto& self, auto i) { - self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); - }); - } else { - k_host_ref.ForEach([&](auto& self, auto i) { - self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); - }); - } - } else -#endif - { - if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); - else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); - } - -#if CK_TILE_FMHA_FWD_APPENDKV_API - // copy Knew to the end of K - if(0 < seqlen_knew) - { - ck_tile::HostTensor knew_host_ref({nhead, seqlen_knew, hdim_q}); - if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); }); - else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); }); - - // optionally apply RoPE to the knew_host_ref - auto* real_knew_host_ref = &knew_host_ref; - std::optional knew_host_ref_ro; + // optionally apply RoPE to the q_host_ref if(0 < rotary_dim) { - knew_host_ref_ro.emplace(knew_host_ref.get_lengths()); + decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); - auto [rotary_cos_slice, rotary_sin_slice] = - slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew); + auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin( + rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q); ck_tile::reference_batched_rotary_position_embedding( - knew_host_ref, + q_host_ref, rotary_cos_slice, rotary_sin_slice, is_rotary_interleaved, - knew_host_ref_ro.value()); + q_host_ref_ro, + /*use_1_row_sin_cos=*/mask.type == mask_enum::no_mask); - real_knew_host_ref = &knew_host_ref_ro.value(); + q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); } - - (*real_knew_host_ref).ForEach([&](auto& self, auto i) { - k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i); - }); - } #endif #if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API - if(0 < page_block_size) { - if(is_v_rowmajor) { - if(i_perm) { - v_host_ref.ForEach([&](auto& self, auto i) { - self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); - 
}); - } else { - v_host_ref.ForEach([&](auto& self, auto i) { - self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]); - }); - } + if(0 < page_block_size) + { + // clang-format off + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); }); + // clang-format on } else - { - if(i_perm) { - v_host_ref.ForEach([&](auto& self, auto i) { - self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); - }); - } else { - v_host_ref.ForEach([&](auto& self, auto i) { - self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[1], i[0] / nr, i[2] % page_block_size); - }); - } - } - } else #endif - { - if(is_v_rowmajor) { - // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); }); - // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] - else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); }); - } - else { - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); }); - else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); }); + // clang-format off + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); + // clang-format on } - } #if CK_TILE_FMHA_FWD_APPENDKV_API - // copy Vnew to the end of V - if(0 < seqlen_knew) - { - ck_tile::HostTensor vnew_host_ref({nhead, hdim_v, seqlen_knew}); - if(is_v_rowmajor) + // copy Knew to the end of K + if(0 < seqlen_knew) { - if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); }); - else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); }); + ck_tile::HostTensor knew_host_ref({nhead, seqlen_knew, hdim_q}); + // clang-format off + if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); }); + else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); }); + // clang-format on + + // optionally apply RoPE to the knew_host_ref + auto* real_knew_host_ref = &knew_host_ref; + std::optional knew_host_ref_ro; + if(0 < rotary_dim) + { + knew_host_ref_ro.emplace(knew_host_ref.get_lengths()); + + auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin( + rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew); + + ck_tile::reference_batched_rotary_position_embedding(knew_host_ref, + rotary_cos_slice, + rotary_sin_slice, + is_rotary_interleaved, + knew_host_ref_ro.value()); + + real_knew_host_ref = &knew_host_ref_ro.value(); + } + + (*real_knew_host_ref).ForEach([&](auto& self, auto i) { + k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i); + }); + } +#endif +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API + if(0 < page_block_size) + { + if(is_v_rowmajor) + { + // 
clang-format off + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]); }); + // clang-format on + } + else + { + // clang-format off + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[1], i[0] / nr, i[2] % page_block_size); }); + // clang-format on + } } else +#endif { - if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); }); - else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); }); + if(is_v_rowmajor) + { + // clang-format off + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); }); + // clang-format on + } + else + { + // clang-format off + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); }); + // clang-format on + } } - vnew_host_ref.ForEach([&](auto& self, auto i) { - v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i); - }); - } +#if CK_TILE_FMHA_FWD_APPENDKV_API + // copy Vnew to the end of V + if(0 < seqlen_knew) + { + ck_tile::HostTensor vnew_host_ref({nhead, hdim_v, seqlen_knew}); + if(is_v_rowmajor) + { + // clang-format off + if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); }); + else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); }); + // clang-format on + } + else + { + // clang-format off + if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); }); + else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); }); + // clang-format on + } + + vnew_host_ref.ForEach([&](auto& self, auto i) { + v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i); + }); + } #endif - // clang-format on // reference ck_tile:: @@ -1412,10 +1321,8 @@ bool run(const ck_tile::ArgParser& arg_parser) // elementwise bias ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); // clang-format off - if(i_perm) - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); - else - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); + if(i_perm) bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); + else bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); // clang-format on // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, @@ -1509,6 +1416,7 @@ bool 
run(const ck_tile::ArgParser& arg_parser) real_seqlen_k, mask.type == mask_enum::mask_top_left)); } + const ck_tile::HostTensor masked_s_host_ref = s_host_ref; if(lse) { ck_tile:: @@ -1526,11 +1434,32 @@ bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::HostTensor randval_host_ref( {nhead, real_seqlen_q, real_seqlen_k}); - randval_host_ref.ForEach([&](auto& self, auto idx) { - self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); - }); + ck_tile::reference_batched_dropout_randval( + randval_host_ref, wb, drop_seed, drop_offset); ck_tile::reference_batched_dropout( p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + + ck_tile::HostTensor randval_host_result( + {nhead, real_seqlen_q, real_seqlen_k}); + randval_host_result.ForEach([&](auto& self, const auto& idx) { + self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); + }); + masked_s_host_ref.ForEach([&](const auto& self, const auto& idx) { + // Ignore all masked values in validation check + if(std::isinf(self(idx))) + { + randval_host_ref(idx) = 0; + randval_host_result(idx) = 0; + } + }); + bool cur_pass = ck_tile::check_err(randval_host_result, + randval_host_ref, + "DROPOUT RANDVAL Error: Incorrect results!"); + pass &= cur_pass; + if(!cur_pass) + { + break; + } } ck_tile::reference_batched_gemm( @@ -1543,9 +1472,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); // clang-format off - // permute - if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); - else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); + // permute + if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); + else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); // clang-format on auto [rtol, atol] = get_elimit(init_method); @@ -1597,10 +1526,10 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } - if(arg_parser.get_int("json") == 1) + if(json) { - dump_fmha_fwd_json_results(arg_parser.get_str("jsonfile"), - prec, + dump_fmha_fwd_json_results(*json, + data_type, mode == mode_enum::batch ? "batch" : "group", io_layout(i_perm, o_perm), batch, @@ -1618,35 +1547,12 @@ bool run(const ck_tile::ArgParser& arg_parser) bias.type == bias_enum::elementwise_bias ? "elementwise_bias" : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), - vlayout, + is_v_rowmajor ? "r" : "c", pass, ave_time, tflops, gb_per_sec); } - return pass; -} - -int main(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - - const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") - { - return run(arg_parser) ? 0 : -2; - } - else if(data_type == "bf16") - { - return run(arg_parser) ? 0 : -2; - } - else if(data_type == "fp8") - { - return run(arg_parser) ? 0 : -2; - } - - return -3; + return pass ? 
fwd_result::success : fwd_result::failure; } diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index af38ff0214..2dfe0e7c52 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -39,6 +39,7 @@ struct mask_info os << "g(" << y << ":" << x << ")"; } } + static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) { ck_tile::index_t x_total = seqlen_k; @@ -54,7 +55,7 @@ struct mask_info if(t == "xt" || t == "xb") { // xformer style sliding window attn from top-left - ck_tile::index_t window_size = atoi(v.c_str()); + ck_tile::index_t window_size = std::stoi(v); ck_tile::index_t left_size = -1; ck_tile::index_t right_size = 0; if(window_size > 0) @@ -71,18 +72,15 @@ struct mask_info tmp.left = left_size; tmp.right = right_size; } - else + else if(t == "t" || t == "b" || t == "g") { auto found_1 = v.find(","); if(found_1 == std::string::npos) { - printf("not supported value %s, %s\n", v.c_str(), str.c_str()); - assert(0); + throw std::invalid_argument("invalid mask value: " + str); } - tmp.type = mask_enum::window_generic; - ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); - ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); - // TODO: some validation + ck_tile::index_t v0 = std::stoi(v.substr(0, found_1)); + ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1)); if(t == "t") { tmp.type = mask_enum::mask_top_left; @@ -105,53 +103,45 @@ struct mask_info } else if(t == "g") { + tmp.type = mask_enum::window_generic; tmp.y = v0; tmp.x = v1; tmp.left = v0; // TODO: don't use this? tmp.right = v1; } - else - { - printf("not supported type %s, %s\n", t.c_str(), str.c_str()); - assert(0); - } } + else + { + throw std::invalid_argument("invalid mask value: " + str); + } + } + else if(str == "0") + { + tmp.type = mask_enum::no_mask; + } + else if(str == "1" || str == "t") + { + tmp.type = mask_enum::mask_top_left; + tmp.y = seqlen_q; + tmp.x = 1; + tmp.left = -1; + tmp.right = 0; + } + else if(str == "2" || str == "b") + { + tmp.type = mask_enum::mask_bottom_right; + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + tmp.left = -1; + tmp.right = 0; } else { - auto set_causal_top_left = [&]() { - tmp.type = mask_enum::mask_top_left; - tmp.y = seqlen_q; - tmp.x = 1; - tmp.left = -1; - tmp.right = 0; - }; - auto set_causal_bottom_right = [&]() { - tmp.type = mask_enum::mask_bottom_right; - tmp.y = seqlen_q; - tmp.x = seqlen_k - seqlen_q + 1; - tmp.left = -1; - tmp.right = 0; - }; - if(str == "t") - set_causal_top_left(); - else if(str == "b") - set_causal_bottom_right(); - else - { - tmp.type = static_cast(atoi(str.c_str())); - if(tmp.type == mask_enum::mask_top_left) - { - set_causal_top_left(); - } - else if(tmp.type == mask_enum::mask_bottom_right) - { - set_causal_bottom_right(); - } - } + throw std::invalid_argument("invalid mask value: " + str); } return tmp; } + ck_tile::index_t get_unmaskarea() const { if(type == mask_enum::no_mask) @@ -168,6 +158,7 @@ struct mask_info } return area; } + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) { mi.serialize(os); diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp index faf3f08437..7f44d87180 100644 --- a/example/ck_tile/01_fmha/utils.hpp +++ b/example/ck_tile/01_fmha/utils.hpp @@ -1,11 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include #include -#include #include #include #include @@ -28,6 +27,23 @@ std::ostream& operator<<(std::ostream& stream, mode_enum mode) return stream << (mode == mode_enum::batch ? "batch" : "group"); } +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} + std::vector to_seqstarts(ck_tile::span seqlens) { std::vector seqstarts = {0}; @@ -39,12 +55,13 @@ std::vector to_seqstarts(ck_tile::span seqlens) return seqstarts; } +template std::vector generate_seqlens(mode_enum mode, unsigned count, int32_t seqlen_avg, - int32_t seqlen_min = -1, // if not negative, clamp min - int32_t seqlen_max = -1, // if not negative, clamp max - std::optional seed = std::nullopt) + int32_t seqlen_min, // if not negative, clamp min + int32_t seqlen_max, // if not negative, clamp max + RandomEngine& random_engine) { assert(0 < count); @@ -58,7 +75,6 @@ std::vector generate_seqlens(mode_enum mode, { using size_type = std::vector::size_type; - std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); std::uniform_int_distribution idx_dist(0, count - 1); auto next_idx = std::bind(idx_dist, std::ref(random_engine)); @@ -89,43 +105,31 @@ std::vector generate_seqlens(mode_enum mode, return seqlens; } -std::vector generate_seqstarts(mode_enum mode, - unsigned count, - int32_t seqlen_avg, - int32_t seqlen_min = -1, - int32_t seqlen_max = -1, - std::optional seed = std::nullopt) -{ - return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed)); -} - // return random integer generated uniformly in range [low, high] -template -auto randint(Int low, Int high, std::optional seed = std::nullopt) - -> std::enable_if_t, Int> +template +auto randint(Int low, + Int high, + RandomEngine& random_engine) -> std::enable_if_t, Int> { - std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}()); std::uniform_int_distribution dist(low, high); - return dist(engine); + return dist(random_engine); } // return random integers generated uniformly in range [low, high] -template +template auto randints(ForwardIterator first, ForwardIterator last, Int low, Int high, - std::optional seed = std::nullopt) - -> std::enable_if_t> + RandomEngine& random_engine) -> std::enable_if_t> { - std::mt19937 engine(seed.has_value() ? 
*seed : std::random_device{}()); std::uniform_int_distribution dist(low, high); - std::generate(first, last, [&] { return dist(engine); }); + std::generate(first, last, [&] { return dist(random_engine); }); } /* - * decode the seqlen string from cmdline + * generate missing values in *_val randomly when the number of values is smaller than batch * example (assume batch=3) * q_val=1,2,3 k_val=4,5,6 -> OK * q_val=1,2,3 -> OK, k same as q @@ -136,23 +140,23 @@ auto randints(ForwardIterator first, * q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q * q_val=1,2 k_val=4 -> not OK, k must have same splits with q */ +template std::tuple, std::vector, std::vector> -decode_seqlen(mode_enum mode, - ck_tile::index_t batch, - std::string q_val, - std::string k_val, - std::string k_pad_val, - ck_tile::index_t seqlen_k_min = 0, - bool need_append_kvcache = false, - std::optional seed = std::nullopt) +generate_missing_seqlens(mode_enum mode, + ck_tile::index_t batch, + const std::vector& q_val, + const std::vector& k_val, + const std::vector& k_pad_val, + ck_tile::index_t seqlen_k_min, + bool need_append_kvcache, + RandomEngine& random_engine) { -#define _S2I_(str_) static_cast(std::atoi((str_).c_str())) if(mode == mode_enum::batch) { - ck_tile::index_t q = _S2I_(q_val); - ck_tile::index_t k = _S2I_(k_val); + ck_tile::index_t q = q_val[0]; + ck_tile::index_t k = k_val[0]; auto s_q = std::vector(batch, q); auto s_k = [&] { @@ -166,7 +170,7 @@ decode_seqlen(mode_enum mode, seqlen_ks.end(), seqlen_k_min, seqlen_k_max, - seed); + random_engine); return seqlen_ks; } @@ -187,25 +191,19 @@ decode_seqlen(mode_enum mode, } else { - ck_tile::index_t idx = 0; - std::string::size_type pos_q = 0; - std::string::size_type pos_k = 0; - std::string::size_type pos_kp = 0; std::vector s_q; std::vector s_k; std::vector s_kpad; - while(true) + ck_tile::index_t idx = 0; + for(; idx < std::min(static_cast(q_val.size()), batch); ++idx) { - auto found_q = q_val.find(',', pos_q); - auto found_k = k_val.find(',', pos_k); - auto found_kp = k_pad_val.find(',', pos_kp); - - ck_tile::index_t q = _S2I_( - q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q)); - ck_tile::index_t k = _S2I_( - k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k)); - ck_tile::index_t kp = _S2I_(k_pad_val.substr( - pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp)); + ck_tile::index_t q = q_val[idx]; + ck_tile::index_t k = + k_val[std::min(idx, static_cast(k_val.size()) - 1)]; + ck_tile::index_t kp = + k_pad_val.empty() + ? -1 + : k_pad_val[std::min(idx, static_cast(k_pad_val.size()) - 1)]; s_q.push_back(q); s_k.push_back(k < 0 ? q : k); @@ -219,21 +217,13 @@ decode_seqlen(mode_enum mode, << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; throw std::runtime_error(msg.str()); } - - idx++; - if(found_q == std::string::npos || idx >= batch) - { - break; - } - pos_q = found_q + 1; - pos_k = found_k == std::string::npos ? pos_k : found_k + 1; - pos_kp = found_kp == std::string::npos ? 
pos_kp : found_kp + 1; } if(idx < batch) { - auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed); - auto rem_k = - generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed); + auto rem_q = + generate_seqlens(mode, batch - idx, s_q.back(), 1, s_q.back(), random_engine); + auto rem_k = generate_seqlens( + mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), random_engine); s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); @@ -241,26 +231,14 @@ decode_seqlen(mode_enum mode, } return std::make_tuple(s_q, s_k, s_kpad); } -#undef _S2I_ } -int env_get_int(const char* var_name, int default_int) -{ - char* v = getenv(var_name); - int r = default_int; - if(v) - r = std::atoi(v); - return r; -} - -template +template std::enable_if_t> iota_shuffle(RandomAccessIterator first, RandomAccessIterator last, Int value, - std::optional seed = std::nullopt) + RandomEngine& random_engine) { std::iota(first, last, value); - - std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}()); - std::shuffle(first, last, engine); + std::shuffle(first, last, random_engine); } diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index 04ca950641..890e507894 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -399,9 +399,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0) } mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa } - // The value is smaller than min f8 denormal and results in zero (the early exit also prevents + // The value is <= than min f8 denormal/2 and results in zero (the early exit also prevents // an undefined behavior of bit shifts >= type width). - if(exponent_diff > DstT_mant) + if(exponent_diff > DstT_mant + 1) { return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant)); } diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 41f5200413..86110d57ec 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -18,6 +18,7 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/ranges.hpp" #include "ck_tile/host/reference/reference_batched_dropout.hpp" +#include "ck_tile/host/reference/reference_batched_dropout_randval.hpp" #include "ck_tile/host/reference/reference_batched_elementwise.hpp" #include "ck_tile/host/reference/reference_batched_gemm.hpp" #include "ck_tile/host/reference/reference_batched_masking.hpp" diff --git a/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp b/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp new file mode 100644 index 0000000000..2a02adaee3 --- /dev/null +++ b/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +CK_TILE_HOST void +reference_batched_dropout_randval(HostTensor& randval_b_m_n, + index_t batch, + uint64_t drop_seed, + uint64_t drop_offset) +{ + const index_t nhead = randval_b_m_n.mDesc.get_lengths()[0]; + const index_t real_seqlen_q = randval_b_m_n.mDesc.get_lengths()[1]; + const index_t real_seqlen_k = randval_b_m_n.mDesc.get_lengths()[2]; + + static_assert(std::is_same_v); + + // BlockDropout generates random numbers by 32x32 tiles. 
Even when warp gemm 16x16 is used, the + // order of values in the bigger 32x32 tile must be the same because fwd and bwd may use + // different warp gemms (16x16 or 32x32). + // To compute 32x32 tiles, WarpGemmMfmaF16F16F32M32N32K16SwizzleA is used. It is + // WarpGemmAttributeMfmaImplF16F16F32M32N32K8 with SFactor = 2 (swizzling factor). + // Matrix element to register mapping for WarpGemmAttributeMfmaImplF16F16F32M32N32K8: + // C i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4) + // C j: (lane % 32) + // With SFactor = 2 it becomes: + // C i: (16 * floor(GPR_num / 8) % 32) + 8 * floor(lane / 32) + (GPR_num % 8) + // C j: (lane % 32) + + constexpr index_t max_warp_size = 64; + constexpr index_t warp_gemm_mn = 32; + + const index_t rows = integer_divide_ceil(real_seqlen_q, warp_gemm_mn); + const index_t cols = integer_divide_ceil(real_seqlen_k, warp_gemm_mn); + + auto f = [&](index_t i_h, index_t row, index_t col) { + uint2 rowcol = make_uint2(row, col); + for(index_t lane = 0; lane < max_warp_size; lane++) + { + philox ph(drop_seed, drop_offset + (batch * nhead + i_h) * max_warp_size + lane); + + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); + + for(auto r = 0; r < 16; r++) + { + index_t i = (16 * (r / 8) % 32) + 8 * (lane / 32) + (r % 8); + index_t j = (lane % 32); + index_t m = row * warp_gemm_mn + i; + index_t n = col * warp_gemm_mn + j; + + if(m < real_seqlen_q && n < real_seqlen_k) + { + randval_b_m_n(i_h, m, n) = random_uint8_t[r]; + } + } + } + }; + + make_ParallelTensorFunctor(f, nhead, rows, cols)(std::thread::hardware_concurrency()); +} + +} // namespace ck_tile diff --git a/include/ck_tile/utility/json_dump.hpp b/include/ck_tile/utility/json_dump.hpp index 05d6a66024..d7c96d77b8 100644 --- a/include/ck_tile/utility/json_dump.hpp +++ b/include/ck_tile/utility/json_dump.hpp @@ -611,7 +611,7 @@ void dump_fmha_fwd_json_results(const std::string& json_filename, float p_drop, bool lse, bool squant, - const std::string& bais, + const std::string& bias, const std::string& vlayout, bool pass, float ave_time, @@ -636,7 +636,7 @@ void dump_fmha_fwd_json_results(const std::string& json_filename, ADD_KEY_VALUE("p_drop", p_drop); ADD_KEY_VALUE("lse", lse); ADD_KEY_VALUE("squant", squant); - ADD_KEY_VALUE("bias", bais); + ADD_KEY_VALUE("bias", bias); ADD_KEY_VALUE("vlayout", vlayout); ADD_KEY_VALUE("verification", pass ? 
"pass" : "fail"); ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a2196ad2b2..947d5136be 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -38,6 +38,11 @@ set(REGRESSION_TESTS test_conv_tensor_rearrange test_gemm_mx test_ck_tile_batched_transpose + test_ck_tile_fmha_bwd_bf16 + test_ck_tile_fmha_bwd_fp16 + test_ck_tile_fmha_fwd_bf16 + test_ck_tile_fmha_fwd_fp16 + test_ck_tile_fmha_fwd_fp8 ) function(add_test_executable TEST_NAME) diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 1dde6ccdf7..993df2ec40 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -25,3 +25,4 @@ add_subdirectory(utility) add_subdirectory(reduce) add_subdirectory(epilogue) add_subdirectory(atomic_add_op) +add_subdirectory(fmha) diff --git a/test/ck_tile/data_type/test_fp8.cpp b/test/ck_tile/data_type/test_fp8.cpp index 49fd68591f..cb3e8ae53e 100644 --- a/test/ck_tile/data_type/test_fp8.cpp +++ b/test/ck_tile/data_type/test_fp8.cpp @@ -94,7 +94,9 @@ TYPED_TEST(ConvertTest, ToFp8) EXPECT_EQ(c(+ck_tile::numeric::denorm_min()), 0b0'0000'000); EXPECT_EQ(c(-ck_tile::numeric::denorm_min()), 0b1'0000'000); - // All values smaller than min f8 subnormal must be converted to f8 zero + // All values <= min f8 subnormal/2 must be converted to f8 zero + EXPECT_EQ(c(+0.001953125f * 0.6f), 0b0'0000'001); + EXPECT_EQ(c(-0.001953125f * 0.6f), 0b1'0000'001); constexpr int src_min_subnorm_exp = -(ck_tile::numeric_traits::bias + ck_tile::numeric_traits::mant - 1); constexpr int dst_min_subnorm_exp = @@ -176,7 +178,9 @@ TYPED_TEST(ConvertTest, ToFp8) EXPECT_EQ(c(+ck_tile::numeric::denorm_min()), 0b0'0000'000); EXPECT_EQ(c(-ck_tile::numeric::denorm_min()), 0b0'0000'000); - // All values smaller than min f8 subnormal must be converted to f8 zero + // All values <= min f8 subnormal/2 must be converted to f8 zero + EXPECT_EQ(c(+0.0009765625f * 0.6f), 0b0'0000'001); + EXPECT_EQ(c(-0.0009765625f * 0.6f), 0b1'0000'001); constexpr int src_min_subnorm_exp = -(ck_tile::numeric_traits::bias + ck_tile::numeric_traits::mant - 1); constexpr int dst_min_subnorm_exp = @@ -282,7 +286,9 @@ TYPED_TEST(ConvertTest, ToBf8) EXPECT_EQ(c(+ck_tile::numeric::denorm_min()), 0b0'00000'00); EXPECT_EQ(c(-ck_tile::numeric::denorm_min()), 0b1'00000'00); - // All values smaller than min f8 subnormal must be converted to f8 zero + // All values <= min f8 subnormal/2 must be converted to f8 zero + EXPECT_EQ(c(+1.52587890625e-05f * 0.6f), 0b0'0000'001); + EXPECT_EQ(c(-1.52587890625e-05f * 0.6f), 0b1'0000'001); constexpr int src_min_subnorm_exp = -(ck_tile::numeric_traits::bias + ck_tile::numeric_traits::mant - 1); constexpr int dst_min_subnorm_exp = @@ -373,7 +379,9 @@ TYPED_TEST(ConvertTest, ToBf8) EXPECT_EQ(c(+ck_tile::numeric::denorm_min()), 0b0'00000'00); EXPECT_EQ(c(-ck_tile::numeric::denorm_min()), 0b0'00000'00); - // All values smaller than min f8 subnormal must be converted to f8 zero + // All values <= min f8 subnormal/2 must be converted to f8 zero + EXPECT_EQ(c(+7.62939453125e-06f * 0.6f), 0b0'0000'001); + EXPECT_EQ(c(-7.62939453125e-06f * 0.6f), 0b1'0000'001); constexpr int src_min_subnorm_exp = -(ck_tile::numeric_traits::bias + ck_tile::numeric_traits::mant - 1); constexpr int dst_min_subnorm_exp = diff --git a/test/ck_tile/fmha/CMakeLists.txt b/test/ck_tile/fmha/CMakeLists.txt new file mode 100644 index 0000000000..b17d682560 --- /dev/null +++ b/test/ck_tile/fmha/CMakeLists.txt @@ -0,0 +1,31 @@ +# Keep in sync with 
example/ck_tile/01_fmha/CMakeLists.txt
+if(NOT SUPPORTED_GPU_TARGETS MATCHES "gfx9")
+    return()
+endif()
+
+set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
+set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
+
+add_gtest_executable(test_ck_tile_fmha_bwd_bf16 test_fmha_bwd_bf16.cpp)
+target_link_libraries(test_ck_tile_fmha_bwd_bf16 PRIVATE ${FMHA_BWD_INSTANCES})
+
+add_gtest_executable(test_ck_tile_fmha_bwd_fp16 test_fmha_bwd_fp16.cpp)
+target_link_libraries(test_ck_tile_fmha_bwd_fp16 PRIVATE ${FMHA_BWD_INSTANCES})
+
+add_gtest_executable(test_ck_tile_fmha_fwd_bf16 test_fmha_fwd_bf16.cpp)
+target_link_libraries(test_ck_tile_fmha_fwd_bf16 PRIVATE ${FMHA_FWD_INSTANCES})
+
+add_gtest_executable(test_ck_tile_fmha_fwd_fp16 test_fmha_fwd_fp16.cpp)
+target_link_libraries(test_ck_tile_fmha_fwd_fp16 PRIVATE ${FMHA_FWD_INSTANCES})
+
+add_gtest_executable(test_ck_tile_fmha_fwd_fp8 test_fmha_fwd_fp8.cpp)
+target_link_libraries(test_ck_tile_fmha_fwd_fp8 PRIVATE ${FMHA_FWD_INSTANCES})
+
+add_custom_target(test_ck_tile_fmha
+    DEPENDS
+    test_ck_tile_fmha_bwd_bf16
+    test_ck_tile_fmha_bwd_fp16
+    test_ck_tile_fmha_fwd_bf16
+    test_ck_tile_fmha_fwd_fp16
+    test_ck_tile_fmha_fwd_fp8
+)
diff --git a/test/ck_tile/fmha/test_fmha_bwd.inc b/test/ck_tile/fmha/test_fmha_bwd.inc
new file mode 100644
index 0000000000..1ad321ec99
--- /dev/null
+++ b/test/ck_tile/fmha/test_fmha_bwd.inc
@@ -0,0 +1,344 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+using ::testing::Bool;
+using ::testing::Combine;
+using ::testing::TestWithParam;
+using ::testing::Values;
+using ::testing::ValuesIn;
+
+// Random seed used for initializing input tensors. 0 for non-deterministic seed
+CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)
+
+// Whether to run long tests (from smoke_test_bwd.sh)
+CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS)
+
+// A missing kernel instance is reported as a skipped test rather than a failure.
+#define CHECK_RESULT(result)                                         \
+    do                                                               \
+    {                                                                \
+        if(result == bwd_result::no_instance)                        \
+            GTEST_SKIP() << "No instance for current parameters";    \
+        ASSERT_EQ(result, bwd_result::success);                      \
+    } while(0)
+
+const ck_tile::stream_config stream_config{
+    nullptr, // stream_id_
+    false,   // time_kernel_
+    1,       // log_level_
+    0,       // cold_niters_
+    1,       // nrepeat_
+    true,    // is_gpu_timer_
+    false,   // flush_cache_
+    1,       // rotating_count_
+};
+
+#define COMMON_ARGS                                                                           \
+    init_method, static_cast<uint64_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \
+        stream_config
+
+// An empty generator leaves the parameterized suite uninstantiated when the condition is
+// false (see GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST below).
+auto EnableTestIf(bool condition)
+{
+    return ValuesIn(condition ? 
std::vector{true} : std::vector{}); +} + +class AllLong : public TestWithParam, + bool, + mode_enum, + std::string, + float, + std::tuple>> +{ +}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong); + +// Test cases from example/ck_tile/01_fmha/script/smoke_test_bwd.sh + +INSTANTIATE_TEST_SUITE_P( + TestCkTileFmhaBwd, + AllLong, + Combine(EnableTestIf(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS))), + HDimValues, + Bool(), + ModeValues, + Values("n", "a"), + Values(0.0f, 0.2f), + Values(std::tuple{1, 4, 2, 259, -1, "0"}, + std::tuple{2, 2, -1, 516, 253, "0"}, + std::tuple{1, 4, 1, 500, 251, "1"}, + std::tuple{1, 2, -1, 900, 258, "2"}, + std::tuple{2, 1, -1, 987, 219, "t:128,30"}, + std::tuple{2, 3, 1, 244, 499, "b:4,35"}))); + +TEST_P(AllLong, Test) +{ + auto [_, hdims, perm, mode, bias_str, p_drop, dims_mask] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; + + auto result = fmha_bwd_run(mode, + batch, + nhead, + nhead_k, + {seqlen_q}, + {seqlen_k}, + hdim_q, + hdim_v, + perm, // i_perm + perm, // o_perm + 0, // scale + bias_str, // bias_str + false, // use_dbias + p_drop, // p_drop + 123, // drop_seed + 1024, // drop_offset + true, // drop_prefs + mask_str, // mask_str + false, // deterministic + COMMON_ARGS); + CHECK_RESULT(result); +} + +class HDimPadding + : public TestWithParam, + bool, + mode_enum, + std::tuple>> +{ +}; + +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, + HDimPadding, + Combine(Values(std::tuple{24, 48}, + std::tuple{120, 160}, + std::tuple{256, 108}, + std::tuple{40, 64}), + Bool(), + ModeValues, + Values(std::tuple{1, 4, 2, 480, -1, "0"}, + std::tuple{2, 2, -1, 300, 400, "t:64,64"}, + std::tuple{1, 4, 1, 512, 201, "1"}, + std::tuple{1, 2, -1, 900, 256, "0"}, + std::tuple{2, 1, -1, 256, 256, "1"}))); + +TEST_P(HDimPadding, Test) +{ + auto [hdims, perm, mode, dims_mask] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; + + auto result = fmha_bwd_run(mode, + batch, + nhead, + nhead_k, + {seqlen_q}, + {seqlen_k}, + hdim_q, + hdim_v, + perm, // i_perm + perm, // o_perm + 0, // scale + "n", // bias_str + false, // use_dbias + 0.0f, // p_drop + 0, // drop_seed + 0, // drop_offset + false, // drop_prefs + mask_str, // mask_str + false, // deterministic + COMMON_ARGS); + CHECK_RESULT(result); +} + +class ElementwiseBias + : public TestWithParam, + bool, + mode_enum, + std::string, + bool, + std::tuple>> +{ +}; + +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, + ElementwiseBias, + Combine(HDimValues, + Bool(), // layouts of bias and dbias are controlled by i_perm + ModeValues, + Values("e:0", "e:1", "e:2"), + Bool(), + Values(std::tuple{1, 4, 2, 1024, 100, "0"}, + std::tuple{3, 2, -1, 128, 256, "2"}, + std::tuple{2, 2, -1, 130, 499, "t:50,64"}))); + +TEST_P(ElementwiseBias, Test) +{ + auto [hdims, i_perm, mode, bias_str, use_dbias, dims_mask] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; + + auto result = fmha_bwd_run(mode, + batch, + nhead, + nhead_k, + {seqlen_q}, + {seqlen_k}, + hdim_q, + hdim_v, + i_perm, // i_perm + false, // o_perm + 0, // scale + bias_str, // bias_str + use_dbias, // use_dbias + 0.0f, // p_drop + 123, // drop_seed + 1024, // drop_offset + true, // drop_prefs + mask_str, // mask_str + false, // deterministic + COMMON_ARGS); + CHECK_RESULT(result); +} + +class Alibi : public TestWithParam, + mode_enum, + std::string, + 
std::tuple, + std::string>> +{ +}; + +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, + Alibi, + Combine(HDimValues, + ModeValues, + Values("a:0", "a:1"), + Values(std::tuple{1, 3, 3, 1024, 1000}, + std::tuple{3, 5, 5, 128, 256}, + std::tuple{2, 8, 4, 130, 320}), + Values("0", "t", "b", "t:50,64", "b:32,40"))); + +TEST_P(Alibi, Test) +{ + auto [hdims, mode, bias_str, dims, mask_str] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k] = dims; + + auto result = fmha_bwd_run(mode, + batch, + nhead, + nhead_k, + {seqlen_q}, + {seqlen_k}, + hdim_q, + hdim_v, + true, // i_perm + true, // o_perm + 0, // scale + bias_str, // bias_str + false, // use_dbias + 0.0f, // p_drop + 0, // drop_seed + 0, // drop_offset + false, // drop_prefs + mask_str, // mask_str + false, // deterministic + COMMON_ARGS); + CHECK_RESULT(result); +} + +class Dropout : public TestWithParam, + mode_enum, + float, + std::tuple, + std::tuple>> +{ +}; + +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, + Dropout, + Combine(HDimValues, + ModeValues, + Values(0.123f, 0.5f), + Values(std::tuple{10, 123, false}, + std::tuple{34534564645, 7876878876864, true}), + Values(std::tuple{2, 6, 2, 180, 512, "0"}, + std::tuple{3, 2, 2, 256, 128, "1"}, + std::tuple{4, 2, 1, 100, 768, "2"}))); + +TEST_P(Dropout, Test) +{ + auto [hdims, mode, p_drop, drop_seed_offset_prefs, dims_mask] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [drop_seed, drop_offset, drop_prefs] = drop_seed_offset_prefs; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; + + auto result = fmha_bwd_run(mode, + batch, + nhead, + nhead_k, + {seqlen_q}, + {seqlen_k}, + hdim_q, + hdim_v, + true, // i_perm + true, // o_perm + 0.1f, // scale + "n", // bias_str + false, // use_dbias + p_drop, // p_drop + drop_seed, // drop_seed + drop_offset, // drop_offset + drop_prefs, // drop_prefs + mask_str, // mask_str + false, // deterministic + COMMON_ARGS); + CHECK_RESULT(result); +} + +class Deterministic + : public TestWithParam, + bool, + mode_enum, + std::tuple>> +{ +}; + +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, + Deterministic, + Combine(HDimValues, + Bool(), + ModeValues, + Values(std::tuple{2, 6, 2, 180, 512, "0"}, + std::tuple{3, 3, 1, 256, 128, "1"}, + std::tuple{4, 2, 2, 768, 100, "2"}))); + +TEST_P(Deterministic, Test) +{ + auto [hdims, i_perm, mode, dims_mask] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; + + auto result = fmha_bwd_run(mode, + batch, + nhead, + nhead_k, + {seqlen_q}, + {seqlen_k}, + hdim_q, + hdim_v, + i_perm, // i_perm + true, // o_perm + 0, // scale + "n", // bias_str + false, // use_dbias + 0.0f, // p_drop + 0, // drop_seed + 0, // drop_offset + false, // drop_prefs + mask_str, // mask_str + true, // deterministic + COMMON_ARGS); + CHECK_RESULT(result); +} diff --git a/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp b/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp new file mode 100644 index 0000000000..cd143e8e83 --- /dev/null +++ b/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
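+
+// Per-data-type parameters (head-dim set, modes, init method) consumed by the shared
+// parameterized bwd suite in test_fmha_bwd.inc, which is included at the bottom of this file.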
+
+#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
+#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp"
+
+#include "gtest/gtest.h"
+
+using DataTypeConfig = FmhaBwdBf16;
+
+using ::testing::Values;
+using ::testing::ValuesIn;
+
+const auto HDimValues =
+    Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});
+
+const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
+
+const std::string init_method = "uf";
+
+#include "test_fmha_bwd.inc"
diff --git a/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp b/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp
new file mode 100644
index 0000000000..4bb1e04ad0
--- /dev/null
+++ b/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp
@@ -0,0 +1,21 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
+#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp"
+
+#include "gtest/gtest.h"
+
+using DataTypeConfig = FmhaBwdFp16;
+
+using ::testing::Values;
+using ::testing::ValuesIn;
+
+const auto HDimValues =
+    Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});
+
+const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
+
+const std::string init_method = "uf";
+
+#include "test_fmha_bwd.inc"
diff --git a/test/ck_tile/fmha/test_fmha_fwd.inc b/test/ck_tile/fmha/test_fmha_fwd.inc
new file mode 100644
index 0000000000..9ff5b442b4
--- /dev/null
+++ b/test/ck_tile/fmha/test_fmha_fwd.inc
@@ -0,0 +1,628 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+using ::testing::Bool;
+using ::testing::Combine;
+using ::testing::TestWithParam;
+using ::testing::Values;
+using ::testing::ValuesIn;
+
+// Random seed used for initializing input tensors. 0 for non-deterministic seed
+CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)
+
+// Whether to run long tests (from smoke_test_fwd.sh)
+CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS)
+
+// A missing kernel instance is reported as a skipped test rather than a failure.
+#define CHECK_RESULT(result)                                         \
+    do                                                               \
+    {                                                                \
+        if(result == fwd_result::no_instance)                        \
+            GTEST_SKIP() << "No instance for current parameters";    \
+        ASSERT_EQ(result, fwd_result::success);                      \
+    } while(0)
+
+const ck_tile::stream_config stream_config{
+    nullptr, // stream_id_
+    false,   // time_kernel_
+    1,       // log_level_
+    0,       // cold_niters_
+    1,       // nrepeat_
+    true,    // is_gpu_timer_
+    false,   // flush_cache_
+    1,       // rotating_count_
+};
+
+// range_q, range_k, range_v, range_p, range_o, squant
+#define QUANT_ARGS 1, 1, 1, 1, 1, squant
+
+#define COMMON_ARGS                                                                           \
+    init_method, static_cast<uint64_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \
+        stream_config
+
+// An empty generator leaves the parameterized suite uninstantiated when the condition is
+// false (see GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST below).
+auto EnableTestIf(bool condition)
+{
+    return ValuesIn(condition ? 
std::vector{true} : std::vector{}); +} + +class AllLong : public TestWithParam< + std::tuple, + bool, + bool, + mode_enum, + bool, + std::string, + float, + std::tuple>> +{ +}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong); + +// Test cases from example/ck_tile/01_fmha/script/smoke_test_fwd.sh + +INSTANTIATE_TEST_SUITE_P( + TestCkTileFmhaFwd, + AllLong, + Combine(EnableTestIf(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS))), + HDimValues, + Bool(), + IsVRowmajorValues, + ModeValues, + Bool(), + Values("n", "e", "a"), + Values(0.0f, 0.2f), + Values(std::tuple{2, 2, 1, 16, -1, 55, 256, -1, "0"}, + std::tuple{1, 3, -1, -1, -1, 100, 51, -1, "0"}, + std::tuple{2, 1, -1, 16, -1, 99, 256, -1, "1"}, + std::tuple{1, 2, 1, -1, -1, 1024, 256, -1, "2"}, + std::tuple{2, 1, -1, -1, 24, 3, 99, -1, "2"}, + std::tuple{3, 2, 1, -1, -1, 200, 520, -1, "t:128,30"}, + std::tuple{2, 1, -1, -1, -1, 99, 32, -1, "b:4,35"}, + std::tuple{1, 2, 1, -1, -1, 33, 0, -1, "2"}, + std::tuple{1, 2, 1, -1, -1, 1, 10, 32, "2"}))); + +TEST_P(AllLong, Test) +{ + auto [_, hdims, perm, is_v_rowmajor, mode, lse, bias_str, p_drop, dims_mask] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [batch, nhead, nhead_k, hdim_q_, hdim_v_, seqlen_q, seqlen_k, seqlen_kpad, mask_str] = + dims_mask; + + hdim_q = hdim_q_ == -1 ? hdim_q : hdim_q_; + hdim_v = hdim_v_ == -1 ? hdim_v : hdim_v_; + + auto result = fmha_fwd_run(mode, + batch, + nhead, + nhead_k, + {adjust_seqlen(seqlen_q)}, + {adjust_seqlen(seqlen_k)}, + hdim_q, + hdim_v, + 0, // seqlen_knew + {seqlen_kpad}, // seqlen_kpads + 0, // rotary_dim + perm, // i_perm + perm, // o_perm + 0, // scale_s + 0, // logits_soft_cap + is_v_rowmajor, // is_v_rowmajor + lse, // lse + 0, // page_block_size + false, // use_cache_batch_idx + bias_str, // bias_str + p_drop, // p_drop + 123, // drop_seed + 1024, // drop_offset + false, // drop_prefs + mask_str, // mask_str + QUANT_ARGS, + true, // is_rotary_interleaved + 1, // num_splits + COMMON_ARGS); + CHECK_RESULT(result); +} + +class HDimPadding + : public TestWithParam, + bool, + bool, + mode_enum, + std::tuple>> +{ +}; + +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, + HDimPadding, + Combine(Values(std::tuple{24, 48}, + std::tuple{120, 160}, + std::tuple{256, 108}, + std::tuple{40, 64}), + Bool(), + IsVRowmajorValues, + ModeValues, + Values(std::tuple{1, 4, 2, 480, -1, -1, "0"}, + std::tuple{2, 2, -1, 300, 400, 512, "t:64,64"}, + std::tuple{1, 4, 1, 512, 201, 256, "1"}, + std::tuple{1, 2, -1, 900, 256, -1, "0"}, + std::tuple{2, 1, -1, 256, 256, -1, "1"}))); + +TEST_P(HDimPadding, Test) +{ + auto [hdims, perm, is_v_rowmajor, mode, dims_mask] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, seqlen_kpad, mask_str] = dims_mask; + + auto result = fmha_fwd_run(mode, + batch, + nhead, + nhead_k, + {adjust_seqlen(seqlen_q)}, + {adjust_seqlen(seqlen_k)}, + hdim_q, + hdim_v, + 0, // seqlen_knew + {seqlen_kpad}, // seqlen_kpads + 0, // rotary_dim + perm, // i_perm + perm, // o_perm + 0, // scale_s + 0, // logits_soft_cap + is_v_rowmajor, // is_v_rowmajor + def_lse, // lse + 0, // page_block_size + false, // use_cache_batch_idx + "n", // bias_str + 0.0f, // p_drop + 0, // drop_seed + 0, // drop_offset + false, // drop_prefs + mask_str, // mask_str + QUANT_ARGS, + true, // is_rotary_interleaved + 1, // num_splits + COMMON_ARGS); + CHECK_RESULT(result); +} + +class ElementwiseBias + : public TestWithParam, + bool, + mode_enum, + std::string, + std::tuple>> +{ +}; + 
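+// Note on the parameter tuples used throughout this file: each dims_mask tuple decodes in
+// TEST_P via
+//   auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
+// where nhead_k == -1 presumably lets the runner derive nhead_k from nhead, and mask_str
+// follows the example's -mask syntax ("0", "1", "2", "t:50,64", "b:4,35", ...).
+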
+INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, + ElementwiseBias, + Combine(HDimValues, + Bool(), // layout of bias is controlled by i_perm + ModeValues, + Values("e:0", "e:1", "e:2"), + Values(std::tuple{1, 4, 2, 1024, 100, "0"}, + std::tuple{3, 2, -1, 128, 256, "2"}, + std::tuple{2, 2, -1, 130, 499, "t:50,64"}))); + +TEST_P(ElementwiseBias, Test) +{ + auto [hdims, i_perm, mode, bias_str, dims_mask] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; + + auto result = fmha_fwd_run(mode, + batch, + nhead, + nhead_k, + {adjust_seqlen(seqlen_q)}, + {adjust_seqlen(seqlen_k)}, + hdim_q, + hdim_v, + 0, // seqlen_knew + {-1}, // seqlen_kpads + 0, // rotary_dim + i_perm, // i_perm + false, // o_perm + 0, // scale_s + 0, // logits_soft_cap + def_is_v_rowmajor, // is_v_rowmajor + def_lse, // lse + 0, // page_block_size + false, // use_cache_batch_idx + bias_str, // bias_str + 0.0f, // p_drop + 0, // drop_seed + 0, // drop_offset + false, // drop_prefs + mask_str, // mask_str + QUANT_ARGS, + true, // is_rotary_interleaved + 1, // num_splits + COMMON_ARGS); + CHECK_RESULT(result); +} + +class Alibi : public TestWithParam, + mode_enum, + std::string, + std::tuple, + std::string>> +{ +}; + +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, + Alibi, + Combine(HDimValues, + ModeValues, + Values("a:0", "a:1"), + Values(std::tuple{1, 3, 3, 1024, 1000}, + std::tuple{3, 5, 5, 128, 256}, + std::tuple{2, 8, 2, 300, 355}), + Values("0", "t", "b", "t:50,64", "b:32,40"))); + +TEST_P(Alibi, Test) +{ + auto [hdims, mode, bias_str, dims, mask_str] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k] = dims; + + auto result = fmha_fwd_run(mode, + batch, + nhead, + nhead_k, + {adjust_seqlen(seqlen_q)}, + {adjust_seqlen(seqlen_k)}, + hdim_q, + hdim_v, + 0, // seqlen_knew + {-1}, // seqlen_kpads + 0, // rotary_dim + true, // i_perm + true, // o_perm + 0, // scale_s + 0, // logits_soft_cap + def_is_v_rowmajor, // is_v_rowmajor + def_lse, // lse + 0, // page_block_size + false, // use_cache_batch_idx + bias_str, // bias_str + 0.0f, // p_drop + 0, // drop_seed + 0, // drop_offset + false, // drop_prefs + mask_str, // mask_str + QUANT_ARGS, + true, // is_rotary_interleaved + 1, // num_splits + COMMON_ARGS); + CHECK_RESULT(result); +} + +class Dropout : public TestWithParam, + mode_enum, + float, + std::tuple, + std::tuple>> +{ +}; + +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, + Dropout, + Combine(HDimValues, + ModeValues, + Values(0.123f, 0.5f), + Values(std::tuple{10, 123, false}, + std::tuple{34534564645, 7876878876864, true}), + Values(std::tuple{2, 4, 2, 280, 512, "0"}, + std::tuple{3, 2, 2, 256, 128, "1"}, + std::tuple{4, 3, 1, 100, 768, "2"}))); + +TEST_P(Dropout, Test) +{ + auto [hdims, mode, p_drop, drop_seed_offset_prefs, dims_mask] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [drop_seed, drop_offset, drop_prefs] = drop_seed_offset_prefs; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; + + auto result = fmha_fwd_run(mode, + batch, + nhead, + nhead_k, + {adjust_seqlen(seqlen_q)}, + {adjust_seqlen(seqlen_k)}, + hdim_q, + hdim_v, + 0, // seqlen_knew + {-1}, // seqlen_kpads + 0, // rotary_dim + false, // i_perm + false, // o_perm + 0, // scale_s + 0, // logits_soft_cap + def_is_v_rowmajor, // is_v_rowmajor + def_lse, // lse + 0, // page_block_size + false, // use_cache_batch_idx + "n", // bias_str + p_drop, // p_drop + drop_seed, // drop_seed + drop_offset, // drop_offset + 
drop_prefs, // drop_prefs + mask_str, // mask_str + QUANT_ARGS, + true, // is_rotary_interleaved + 1, // num_splits + COMMON_ARGS); + CHECK_RESULT(result); +} + +#if CK_TILE_FMHA_FWD_PAGEDKV_API + +class PagedKV : public TestWithParam, + bool, + bool, + mode_enum, + int, + std::tuple>> +{ +}; + +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, + PagedKV, + Combine(SplitKVHDimValues, + Bool(), // layouts of k and v are controlled by i_perm + IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor + ModeValues, + Values(128, 256), + Values(std::tuple{2, 3, 1, 200, 1024, "0"}, + std::tuple{3, 2, -1, 128, 768, "2"}, + std::tuple{2, 2, -1, 230, 899, "t:50,64"}))); + +TEST_P(PagedKV, Test) +{ + auto [hdims, i_perm, is_v_rowmajor, mode, page_block_size, dims_mask] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; + + auto result = fmha_fwd_run(mode, + batch, + nhead, + nhead_k, + {adjust_seqlen(seqlen_q)}, + {adjust_seqlen(seqlen_k)}, + hdim_q, + hdim_v, + 0, // seqlen_knew + {-1}, // seqlen_kpads + 0, // rotary_dim + i_perm, // i_perm + false, // o_perm + 0, // scale_s + 0, // logits_soft_cap + is_v_rowmajor, // is_v_rowmajor + def_lse, // lse + page_block_size, // page_block_size + false, // use_cache_batch_idx + "n", // bias_str + 0.0f, // p_drop + 0, // drop_seed + 0, // drop_offset + false, // drop_prefs + mask_str, // mask_str + QUANT_ARGS, + true, // is_rotary_interleaved + 1, // num_splits + COMMON_ARGS); + CHECK_RESULT(result); +} + +#endif // CK_TILE_FMHA_FWD_PAGEDKV_API + +#if CK_TILE_FMHA_FWD_SPLITKV_API + +class SplitKV : public TestWithParam, + bool, + bool, + std::tuple, + int, + std::tuple>> +{ +}; + +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, + SplitKV, + Combine(SplitKVHDimValues, + Bool(), // layouts of k and v are controlled by i_perm + IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor + Values(std::tuple{mode_enum::batch, false}, + std::tuple{mode_enum::batch, true}, + std::tuple{mode_enum::group, false}), + Values(3, 4), + Values(std::tuple{4, 3, 1, 200, 1024, "0"}, + std::tuple{2, 2, -1, 512, 2000, "0"}, + std::tuple{3, 2, -1, 230, 899, "t:128,128"}))); + +TEST_P(SplitKV, Test) +{ + auto [hdims, i_perm, is_v_rowmajor, mode_use_cache_batch_idx, num_splits, dims_mask] = + GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [mode, use_cache_batch_idx] = mode_use_cache_batch_idx; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; + + auto result = fmha_fwd_run(mode, + batch, + nhead, + nhead_k, + {adjust_seqlen(seqlen_q)}, + {adjust_seqlen(seqlen_k)}, + hdim_q, + hdim_v, + 0, // seqlen_knew + {-1}, // seqlen_kpads + 0, // rotary_dim + i_perm, // i_perm + false, // o_perm + 0, // scale_s + 0, // logits_soft_cap + is_v_rowmajor, // is_v_rowmajor + def_lse, // lse + 0, // page_block_size + use_cache_batch_idx, // use_cache_batch_idx + "n", // bias_str + 0.0f, // p_drop + 0, // drop_seed + 0, // drop_offset + false, // drop_prefs + mask_str, // mask_str + QUANT_ARGS, + true, // is_rotary_interleaved + num_splits, // num_splits + COMMON_ARGS); + CHECK_RESULT(result); +} + +#endif // CK_TILE_FMHA_FWD_SPLITKV_API + +#if CK_TILE_FMHA_FWD_APPENDKV_API + +class AppendKV : public TestWithParam, + bool, + bool, + std::tuple, + int, + std::tuple>> +{ +}; + +INSTANTIATE_TEST_SUITE_P( + TestCkTileFmhaFwd, + AppendKV, + Combine(AppendKVHDimValues, + Bool(), // layouts of k and v are controlled by i_perm + IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor + 
ValuesIn({std::tuple{0, true}, std::tuple{0, false}, std::tuple{128, false}}), + Values(1, 64, -1), + Values(std::tuple{3, 3, -1, 60, 129, "t:32,32"}, + std::tuple{3, 2, 2, 256, 256, "0"}, + std::tuple{2, 3, 1, 264, 265, "1"}, + std::tuple{4, 4, 2, 71, 64, "1"}))); + +TEST_P(AppendKV, Test) +{ + auto [hdims, + i_perm, + is_v_rowmajor, + page_block_size_use_cache_batch_idx, + seqlen_knew, + dims_mask] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [page_block_size, use_cache_batch_idx] = page_block_size_use_cache_batch_idx; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; + + seqlen_knew = seqlen_knew == -1 ? seqlen_k : seqlen_knew; + + auto result = fmha_fwd_run(mode_enum::batch, + batch, + nhead, + nhead_k, + {adjust_seqlen(seqlen_q)}, + {adjust_seqlen(seqlen_k)}, + hdim_q, + hdim_v, + seqlen_knew, // seqlen_knew + {-1}, // seqlen_kpads + 0, // rotary_dim + i_perm, // i_perm + true, // o_perm + 0, // scale_s + 0, // logits_soft_cap + is_v_rowmajor, // is_v_rowmajor + def_lse, // lse + page_block_size, // page_block_size + use_cache_batch_idx, // use_cache_batch_idx + "n", // bias_str + 0.0f, // p_drop + 0, // drop_seed + 0, // drop_offset + false, // drop_prefs + mask_str, // mask_str + QUANT_ARGS, + false, // is_rotary_interleaved + 1, // num_splits + COMMON_ARGS); + CHECK_RESULT(result); +} + +class AppendKVRoPE + : public TestWithParam, + bool, + bool, + std::tuple, + int, + std::tuple>> +{ +}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKVRoPE); + +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, + AppendKVRoPE, + Combine(EnableTestIf(!std::is_same_v), + AppendKVHDimValues, + Bool(), // layouts of k and v are controlled by i_perm + IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor + Values(std::tuple{0, false}, + std::tuple{16, true}, + std::tuple{32, false}, + std::tuple{-1, true}), + Values(16, 50, -1), + Values(std::tuple{2, 3, -1, 60, 129, "t:32,32"}, + std::tuple{1, 2, 1, 128, 55, "0"}, + std::tuple{3, 4, 2, 72, 128, "1"}))); + +TEST_P(AppendKVRoPE, Test) +{ + auto [_, hdims, i_perm, is_v_rowmajor, rotary, seqlen_knew, dims_mask] = GetParam(); + auto [hdim_q, hdim_v] = hdims; + auto [rotary_dim, is_rotary_interleaved] = rotary; + auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; + + rotary_dim = rotary_dim == -1 ? hdim_q : rotary_dim; + seqlen_knew = seqlen_knew == -1 ? seqlen_k : seqlen_knew; + + auto result = fmha_fwd_run(mode_enum::batch, + batch, + nhead, + nhead_k, + {adjust_seqlen(seqlen_q)}, + {adjust_seqlen(seqlen_k)}, + hdim_q, + hdim_v, + seqlen_knew, // seqlen_knew + {-1}, // seqlen_kpads + rotary_dim, // rotary_dim + i_perm, // i_perm + true, // o_perm + 0, // scale_s + 0, // logits_soft_cap + is_v_rowmajor, // is_v_rowmajor + true, // lse + 0, // page_block_size + false, // use_cache_batch_idx + "n", // bias_str + 0.0f, // p_drop + 0, // drop_seed + 0, // drop_offset + false, // drop_prefs + mask_str, // mask_str + QUANT_ARGS, + is_rotary_interleaved, // is_rotary_interleaved + 1, // num_splits + COMMON_ARGS); + CHECK_RESULT(result); +} + +#endif // CK_TILE_FMHA_FWD_APPENDKV_API diff --git a/test/ck_tile/fmha/test_fmha_fwd_bf16.cpp b/test/ck_tile/fmha/test_fmha_fwd_bf16.cpp new file mode 100644 index 0000000000..fbc6449a6a --- /dev/null +++ b/test/ck_tile/fmha/test_fmha_fwd_bf16.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
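+
+// Per-data-type parameters (head-dim sets, modes, defaults) consumed by the shared
+// parameterized fwd suite in test_fmha_fwd.inc, which is included at the bottom of this file.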
+
+#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
+#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"
+
+#include "gtest/gtest.h"
+
+#include <string>
+#include <tuple>
+
+using ::testing::Values;
+
+using DataTypeConfig = FmhaFwdBf16;
+
+const auto HDimValues = Values(std::tuple{32, -1},
+                               std::tuple{64, -1},
+                               std::tuple{96, 128},
+                               std::tuple{128, -1},
+                               std::tuple{192, 128},
+                               std::tuple{192, -1},
+                               std::tuple{256, -1});
+
+const auto SplitKVHDimValues = Values(std::tuple{32, -1},
+                                      std::tuple{64, -1},
+                                      std::tuple{96, -1},
+                                      std::tuple{128, -1},
+                                      std::tuple{256, -1});
+
+const auto AppendKVHDimValues =
+    Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});
+
+const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
+
+const auto IsVRowmajorValues = Values(false, true);
+
+const bool squant = false;
+const std::string init_method = "uf";
+const bool def_lse = true;
+const bool def_is_v_rowmajor = true;
+
+int adjust_seqlen(int seqlen) { return seqlen; }
+
+#include "test_fmha_fwd.inc"
diff --git a/test/ck_tile/fmha/test_fmha_fwd_fp16.cpp b/test/ck_tile/fmha/test_fmha_fwd_fp16.cpp
new file mode 100644
index 0000000000..abc2c44726
--- /dev/null
+++ b/test/ck_tile/fmha/test_fmha_fwd_fp16.cpp
@@ -0,0 +1,44 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
+#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"
+
+#include "gtest/gtest.h"
+
+#include <string>
+#include <tuple>
+
+using ::testing::Values;
+
+using DataTypeConfig = FmhaFwdFp16;
+
+const auto HDimValues = Values(std::tuple{32, -1},
+                               std::tuple{64, -1},
+                               std::tuple{96, 128},
+                               std::tuple{128, -1},
+                               std::tuple{192, 128},
+                               std::tuple{192, -1},
+                               std::tuple{256, -1});
+
+const auto SplitKVHDimValues = Values(std::tuple{32, -1},
+                                      std::tuple{64, -1},
+                                      std::tuple{96, -1},
+                                      std::tuple{128, -1},
+                                      std::tuple{256, -1});
+
+const auto AppendKVHDimValues =
+    Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});
+
+const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
+
+const auto IsVRowmajorValues = Values(false, true);
+
+const bool squant = false;
+const std::string init_method = "uf";
+const bool def_lse = true;
+const bool def_is_v_rowmajor = true;
+
+int adjust_seqlen(int seqlen) { return seqlen; }
+
+#include "test_fmha_fwd.inc"
diff --git a/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp b/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp
new file mode 100644
index 0000000000..46ed8f4125
--- /dev/null
+++ b/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp
@@ -0,0 +1,43 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
+#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"
+
+#include "gtest/gtest.h"
+
+#include <string>
+#include <tuple>
+
+using ::testing::Values;
+
+using DataTypeConfig = FmhaFwdFp8;
+
+// Currently there are no fp8 instances for splitkv and pagedkv by default (the tests pass if
+// such instances are added). The corresponding tests are therefore not disabled but skipped,
+// so they start running automatically once such instances are added in the future.
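+// (Skipping happens via CHECK_RESULT in test_fmha_fwd.inc, which maps
+// fwd_result::no_instance to GTEST_SKIP.)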
+
+const auto HDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});
+
+const auto SplitKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});
+
+const auto AppendKVHDimValues =
+    Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});
+
+// There are no fp8 instances with seqlen padding (mode_enum::group requires it)
+const auto ModeValues = Values(mode_enum::batch);
+
+const auto IsVRowmajorValues = Values(false);
+
+const bool squant = true;
+const std::string init_method = "ufq";
+const bool def_lse = false;
+const bool def_is_v_rowmajor = false;
+
+int adjust_seqlen(int seqlen)
+{
+    // There are no fp8 instances with padding, so round seqlen up to a multiple of 128
+    // (e.g. 55 -> 128, 256 stays 256) to avoid skipping most of the tests
+    return ck_tile::integer_least_multiple(seqlen, 128);
+}
+
+#include "test_fmha_fwd.inc"
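+
+// The suites above honor two environment variables: CK_TILE_TEST_SEED (input-init seed,
+// 0 = non-deterministic) and CK_TILE_FMHA_LONG_TESTS (enables the long AllLong cases
+// mirrored from the smoke_test_*.sh scripts).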