Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-04-19 14:29:05 +00:00)
[CK_TILE] Add gtests for FMHA (#2744)
* Improve random number generation
  * use different seed for each input (Q, K, V...)
  * use deterministic generation of:
    * seqstart_q/k (for group mode)
    * block_table (for paged-kvcache)
    * cache_batch_idx (for kvcache)
* Extract arg_parser-related code from run functions to use them as tests
* Split examples into main programs and fmha runners, build instances separately
* Add dummy tests that use instances and runners
* Fix a missed corner case of f32->f8 conversion
  When the value is < min f8 denormal but > min f8 denormal / 2, it must be rounded to the min f8 denormal (i.e. 0b1), not to 0.
* Fix incorrect fp8 scales for P and O in validation code
  DataTypeConfig was incorrectly compared with fp8_t.
* Add host generation of dropout random values and use it for validation
  Previously host validation (reference_batched_dropout) used random numbers generated by BlockDropout of the kernel, meaning that incorrect generation on device (bad distribution, repeated numbers, too many zeros, etc.) would not trigger any validation errors.
* Implement tests from smoke_test_bwd.sh
  * Return result as enum to distinguish failure and missing instance
  * Add tests for bwd features: bias, alibi, dropout
* Implement tests from smoke_test_fwd.sh
  * Pass seqlen_q/k as vectors to fwd and bwd runners
  * Add tests for fwd features: bias, alibi, dropout
  * Add tests for pagedkv and splitkv
* Fix conditions for when to use splitkv and pagedkv kernels
  splitkv was executed only when use_kvcache, which == (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size).
  In the SplitKV tests: the regular fwd kernel was executed if use_cache_batch_idx was not requested, even when num_splitkv > 1.
  In the AppendKV tests: the pagedkv kernel was executed, but it often failed to find an instance.
* Add tests for appendkv
* Use is_v_rowmajor = true because there are no instances with column layout anymore
* Split public and private compile options for instances
  Tests and examples need to know only about CK_TILE_FMHA_FWD_*_API.
* Improve parsing validation in bias and mask
* Pass bias as string for consistency with mask
* Catch parsing and other exceptions
* Add bwd test for deterministic flag
* Initialize fp8 tensors (-init=ufq) similarly to uf
* Fix splitkv/pagedkv invocation: use padded sk when seqlen_k_ptr is not null
  seqlen_k cannot be used to determine padding when seqlen_k_ptr is provided. The actual seqlen_k is taken from seqlen_k_ptr[b]. Even seqlen_k values (% bn0 == 0) use padded seqlen_k while seqlen_k_ptr may contain arbitrary values. In the example or tests this produces incorrect results with appendkv (for example, -d=32 -s=1 -s_k=64 -s_knew=7 -vlayout=c -b=8).
* Fix use_pagedkv value when kvcache = true but page_block_size = 0
  In this case block_table_ptr is nullptr, which is accessed in the kernel.
* Clean up bwd tests
* Unify fwd tests for f16/bf16 and fp8
* Use better explicit instantiation declaration for fmha_bwd<2>
* Use the same seed for all tests; allow overriding it with an env variable
* Undo clang-format of one irrelevant file
  For some reason my local clang-format-18 and the one in CI work differently.
* Do not build instances and tests on unsupported archs
* Build instance libraries as OBJECT libraries
* CI: Enable sccache for HIP
  There are source files with LANGUAGE HIP; they need -DCMAKE_HIP_COMPILER_LAUNCHER=sccache
* Add tests to REGRESSION_TESTS
* Fix OOB accesses in deterministic bwd due to incorrectly assumed kN0
  The runner assumed kN0 = (hdim_q <= 128) ? 128 : 64, but there are smaller tiles (for tr_load or fp32). This can create a too-small dq_acc_buf.
* Pass CK_TILE_FMHA_FWD_*_API as INTERFACE compile options
  The instances don't actually depend on them, only examples and tests do. Passing these definitions as INTERFACE allows changing FMHA_FWD_ENABLE_APIS without recompiling instances that are already in ccache.
* Fix formatting and names
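The f32->f8 rounding corner case described above can be illustrated with a small numeric sketch. This is not the CK conversion code; it assumes an e4m3-style format whose smallest positive denormal is 2^-9, purely for illustration:

```python
import math

# Assumed for illustration only (e4m3-style format):
# smallest positive f8 denormal = 2^-9.
MIN_F8_DENORM = 2.0 ** -9

def round_positive_to_f8_grid(x: float) -> float:
    """Round a small positive float to the nearest multiple of the f8 denormal.

    The corner case from the commit message: values in
    (MIN_F8_DENORM / 2, MIN_F8_DENORM) must round to MIN_F8_DENORM
    (encoding 0b1), not to 0.
    """
    return math.floor(x / MIN_F8_DENORM + 0.5) * MIN_F8_DENORM

# more than half of the smallest denormal: rounds up to the denormal, not 0
assert round_positive_to_f8_grid(MIN_F8_DENORM * 0.6) == MIN_F8_DENORM
# less than half of the smallest denormal: rounds to 0
assert round_positive_to_f8_grid(MIN_F8_DENORM * 0.4) == 0.0
```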
2
Jenkinsfile
vendored
@@ -321,7 +321,7 @@ def cmake_build(Map conf=[:]){
                ${redis_pre_setup_cmd}
            """)
            sh cmd1
            setup_args = " -DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache " + setup_args
            setup_args = " -DCMAKE_HIP_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache " + setup_args
        }
        catch(Exception err){
            echo "could not connect to redis server: ${err.getMessage()}. will not use sccache."
@@ -184,7 +184,7 @@ hours to 1-2 minutes. In order to invoke sccache, you need to run:

then add the following flags to the cmake command line:

```bash
-DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache
-DCMAKE_HIP_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache
```

You may need to clean up the build folder and repeat the cmake and make steps in order to take
@@ -1,7 +1,19 @@
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# Currently only gfx9 archs are supported by FMHA
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
if(NOT INST_TARGETS)
    message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
    return()
endif()

# validate user-specified fmha_fwd API list
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill")
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
    "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
if(BUILD_TESTING)
    # Build instances of all APIs for tests
    set(FMHA_FWD_ENABLE_APIS "all")
endif()
if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
    set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
endif()
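The API-selection logic in the CMake hunk above can be paraphrased as follows. This is a hedged Python sketch, not part of the build; the API names come from FMHA_FWD_KNOWN_APIS in the diff, while the function itself is illustrative:

```python
# API names taken from FMHA_FWD_KNOWN_APIS in the CMake hunk above.
KNOWN_APIS = ("fwd", "fwd_splitkv", "fwd_appendkv", "pagedkv_prefill")

def resolve_apis(spec: str, build_testing: bool = False) -> list:
    """Expand a user-facing API spec into a concrete list of APIs.

    BUILD_TESTING forces "all" because the tests exercise every API;
    otherwise "all" expands to the known list and unknown names are rejected.
    """
    if build_testing or spec == "all":
        return list(KNOWN_APIS)
    apis = spec.split(";")
    unknown = [a for a in apis if a not in KNOWN_APIS]
    if unknown:
        raise ValueError(f"unknown fmha_fwd APIs: {unknown}")
    return apis

assert resolve_apis("fwd;fwd_splitkv") == ["fwd", "fwd_splitkv"]
assert resolve_apis("fwd", build_testing=True) == list(KNOWN_APIS)
```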
@@ -77,72 +89,100 @@ add_custom_command(
    DEPENDS ${CODE_GEN_SCRIPTS}
)

set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}")
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")

set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp)
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}")
add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL)
target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS})
set_source_files_properties(${FMHA_FWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_property(TARGET ${FMHA_FWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

message(DEBUG "adding instances ${FMHA_BWD_INSTANCES}")
add_library(${FMHA_BWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL)
target_include_directories(${FMHA_BWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${FMHA_BWD_INSTANCES} PRIVATE ${FMHA_BWD_GEN_BLOBS})
set_source_files_properties(${FMHA_BWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_property(TARGET ${FMHA_BWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

set(FMHA_FWD_PRIVATE_COMPILE_OPTIONS)
set(FMHA_BWD_PRIVATE_COMPILE_OPTIONS)
set(FMHA_FWD_INTERFACE_COMPILE_OPTIONS)
set(FMHA_BWD_INTERFACE_COMPILE_OPTIONS)

# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
# ... because they are auto-generated
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template)
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template)

# Allow comparing floating points directly in order to check sentinel values
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal)

# NOTE: this is dangerous since will change the whole kernel to flush denormals
# WIP with compiler team for an exp2 intrinsic..., then remove this
if(NOT DEFINED FMHA_FWD_FAST_EXP2)
    set(FMHA_FWD_FAST_EXP2 true)
    set(FMHA_FWD_FAST_EXP2 ON)
endif()

set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS)
set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)

# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
# ... because they are auto-generated
if(FMHA_FWD_FAST_EXP2)
    list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
    list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else()
    list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
    list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero)

# conditionally enable call to the fwd_splitkv API in fmha_fwd example
# conditionally enable call to the fwd_splitkv API in fmha_fwd example and tests
if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
    list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
    list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
else()
    list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
    list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
endif()

# conditionally enable call to the fwd_appendkv API in fmha_fwd example
# conditionally enable call to the fwd_appendkv API in fmha_fwd example and tests
if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS)
    list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
    list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
else()
    list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
    list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
endif()

# conditionally enable call to the pagedkv_prefill API in fmha_fwd example
# conditionally enable call to the pagedkv_prefill API in fmha_fwd example and tests
if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS)
    list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1)
    list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1)
else()
    list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
    list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
endif()

# conditionally specify the use of OCP_FP8
if(CK_USE_OCP_FP8)
    list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
    list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
    list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()

# Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
target_compile_options(${FMHA_FWD_INSTANCES}
    PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS}
    INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS})
target_compile_options(${FMHA_BWD_INSTANCES}
    PRIVATE ${FMHA_BWD_PRIVATE_COMPILE_OPTIONS}
    INTERFACE ${FMHA_BWD_INTERFACE_COMPILE_OPTIONS})

target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS})
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")

message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}")
# not using add_example_executable() to add this target, since we don't want this to be included in
# "make all/install/check"
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp)
target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES})
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
# not using add_example_executable() to add this target, since we don't want this to be included in
# "make all/install/check"
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp)
target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES})
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

# add fmha_fwd_v3 example
set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3")
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -63,31 +63,45 @@ struct bias_info
    static bias_info decode(std::string str)
    {
        bias_info info{bias_enum::no_bias, 0};
        if(str == "0" || str == "n")
        auto found_0 = str.find(':');
        if(found_0 != std::string::npos)
        {
            std::string t = str.substr(0, found_0);
            std::string v = str.substr(found_0 + 1);
            if(t == "e" || t == "elementwise")
            {
                info.type      = bias_enum::elementwise_bias;
                info.rank_info = std::stoi(v);
                if(info.rank_info < 0 || info.rank_info > 2)
                    throw std::invalid_argument("invalid bias rank: " + str);
            }
            else if(t == "a" || t == "alibi")
            {
                info.type      = bias_enum::alibi;
                info.rank_info = std::stoi(v);
                if(info.rank_info < 0 || info.rank_info > 1)
                    throw std::invalid_argument("invalid bias rank: " + str);
            }
            else
            {
                throw std::invalid_argument("invalid bias value: " + str);
            }
        }
        else if(str == "0" || str == "n")
        {
            info.type = bias_enum::no_bias;
        }
        else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 ||
                str.compare(0, 11, "elementwise") == 0)
        else if(str == "1" || str == "e" || str == "elementwise")
        {
            info.type = bias_enum::elementwise_bias;
            auto found_0 = str.find(':');
            if(found_0 != std::string::npos)
            {
                std::string e  = str.substr(found_0 + 1);
                info.rank_info = atoi(e.c_str());
            }
            info.type = bias_enum::elementwise_bias;
        }
        else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 ||
                str.compare(0, 5, "alibi") == 0)
        else if(str == "2" || str == "a" || str == "alibi")
        {
            info.type = bias_enum::alibi;
            auto found_0 = str.find(':');
            if(found_0 != std::string::npos)
            {
                std::string e  = str.substr(found_0 + 1);
                info.rank_info = atoi(e.c_str());
            }
            info.type = bias_enum::alibi;
        }
        else
        {
            throw std::invalid_argument("invalid bias value: " + str);
        }
        return info;
    }

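The restructured decode() above accepts the new "type:rank" form up front and keeps the legacy bare tokens as fallbacks. A hedged Python transcription of that grammar (illustrative only, mirroring the C++ hunk, with the enum values shown as strings):

```python
def decode_bias(s: str):
    """Mirror of bias_info::decode from the hunk above: returns (type, rank)."""
    if ":" in s:
        t, v = s.split(":", 1)
        rank = int(v)  # mirrors std::stoi; raises ValueError on garbage
        if t in ("e", "elementwise"):
            if not 0 <= rank <= 2:
                raise ValueError("invalid bias rank: " + s)
            return ("elementwise_bias", rank)
        if t in ("a", "alibi"):
            if not 0 <= rank <= 1:
                raise ValueError("invalid bias rank: " + s)
            return ("alibi", rank)
        raise ValueError("invalid bias value: " + s)
    if s in ("0", "n"):
        return ("no_bias", 0)
    if s in ("1", "e", "elementwise"):
        return ("elementwise_bias", 0)
    if s in ("2", "a", "alibi"):
        return ("alibi", 0)
    raise ValueError("invalid bias value: " + s)

assert decode_bias("e:2") == ("elementwise_bias", 2)
assert decode_bias("a:1") == ("alibi", 1)
assert decode_bias("n") == ("no_bias", 0)
```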
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import copy

@@ -347,8 +347,8 @@ class FmhaFwdSplitKVApiTrait:
            if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
            else                 : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
        elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
            if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else                 : return f'a.seqlen_k % {self.bn0} == 0'
            if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else                 : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0'
        else: assert False

    @property

@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import copy

@@ -189,8 +189,8 @@ class FmhaFwdApiTrait:
            if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
            else                 : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
        elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
            if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else                 : return f'a.seqlen_k % {self.bn0} == 0'
            if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else                 : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0'
        else: assert False

    @property

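Both generator hunks make the same fix: the unpadded-seqlen_k fast path is only valid when seqlen_k_ptr is null, because otherwise the per-batch seqlen_k is read from the pointer on device and may be arbitrary. A minimal Python sketch of the corrected host-side predicate (illustrative, not the generated code):

```python
def can_skip_sk_padding(seqlen_k_ptr, seqlen_k: int, bn0: int) -> bool:
    """Corrected condition: only trust the host-side seqlen_k when no
    per-batch seqlen_k pointer is provided (bn0 is the kernel's N0 tile)."""
    return seqlen_k_ptr is None and seqlen_k % bn0 == 0

# host-side seqlen_k divisible by the tile: the unpadded kernel is fine
assert can_skip_sk_padding(None, 128, 64) is True
# with seqlen_k_ptr set, per-batch lengths may be arbitrary (e.g. 7), so the
# padded kernel must be used even though 128 % 64 == 0
assert can_skip_sk_padding([7, 64], 128, 64) is False
```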
183
example/ck_tile/01_fmha/example_fmha_bwd.cpp
Normal file
@@ -0,0 +1,183 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/host.hpp"
#include "fmha_bwd.hpp"
#include "fmha_bwd_runner.hpp"

#include <string>

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "whether do CPU validation or not")
        .insert("mode", "0", "kernel mode. 0:batch, 1:group")
        .insert("b", "2", "batch size")
        .insert("h", "8", "num of head, for q")
        .insert("h_k",
                "-1",
                "num of head, for k/v, -1 means equal to h\n"
                "if not equal to h, then this is GQA/MQA case")
        .insert("s",
                "3328",
                "seqlen_q. if group-mode, means the average value of seqlen_q\n"
                "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
                "also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
                "(group mode)")
        .insert("s_k",
                "-1",
                "seqlen_k, -1 means equal to s\n"
                "also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
                "(group mode)")
        .insert("d", "128", "head dim for q, k")
        .insert("d_v", "-1", "head dim for v, -1 means equal to d")
        .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
        .insert("iperm",
                "1",
                "permute input\n"
                "if true, will be b*h*s*d, else b*s*h*d")
        .insert("operm", "1", "permute output")
        .insert("bias",
                "n",
                "n or 0, no bias\n"
                "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
                "a(libi) or 2, alibi with 1*h. a:1, b*h")
        .insert("dbias", "0", "output bias gradient or not")
        .insert("prec", "fp16", "data type. fp16 or bf16")
        .insert("mask",
                "0",
                "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
                "'t', top-left causal mask, 'b', bottom-r causal mask\n"
                "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
                "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
                "'xt:window_size', xformer style masking from top-left, window_size negative is "
                "causal, positive is swa\n"
                "'xb:window_size', xformer style masking from bottom-r, window_size negative is "
                "causal, positive is swa\n"
                "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
                "now)")
        .insert("kname", "0", "if set to 1 will print kernel name")
        .insert("init",
                "uf",
                "init method:\n ui or 0 - uniform random int\n uf or 1 - uniform random float"
                "\n tf or 2 - trig float")
        .insert("seed",
                "11939",
                "random seed used for initializing input tensors. 0 for "
                "non-deterministic seed")
        .insert("p_drop", "0", "0~1 probability of dropout")
        .insert("drop_seed", "1", "seed for dropout random number generator")
        .insert("drop_offset", "0", "offset for dropout random number generator")
        .insert(
            "drop_prefs",
            "0",
            "whether dropout seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert("warmup", "5", "number of iterations before benchmark the kernel")
        .insert("repeat", "20", "number of iterations to benchmark the kernel")
        .insert("deterministic",
                "0",
                "if set to 1 will use multi-buffer reduction strategy for dq, atomic operation "
                "will not be used")
        .insert("json", "0", "0: No Json, 1: Dump Results in Json format")
        .insert("jsonfile", "fmha_bwd.json", "json file name to dump results");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

template <typename DataTypeConfig>
auto run(const ck_tile::ArgParser& arg_parser)
{
    std::string data_type    = arg_parser.get_str("prec");
    int do_validation        = arg_parser.get_int("v");
    mode_enum mode           = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
    ck_tile::index_t batch   = arg_parser.get_int("b");
    ck_tile::index_t nhead   = arg_parser.get_int("h");
    ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
    auto seqlen_qs           = arg_parser.get_int_vec("s");
    auto seqlen_ks           = arg_parser.get_int_vec("s_k");
    ck_tile::index_t hdim_q  = arg_parser.get_int("d");
    ck_tile::index_t hdim_v  = arg_parser.get_int("d_v");
    bool i_perm              = arg_parser.get_bool("iperm");
    bool o_perm              = arg_parser.get_bool("operm");
    float scale              = arg_parser.get_float("scale");
    std::string bias_str     = arg_parser.get_str("bias");
    bool use_dbias           = arg_parser.get_bool("dbias");
    float p_drop             = arg_parser.get_float("p_drop");
    uint64_t drop_seed       = arg_parser.get_uint64("drop_seed");
    uint64_t drop_offset     = arg_parser.get_uint64("drop_offset");
    bool drop_prefs          = arg_parser.get_bool("drop_prefs");
    std::string mask_str     = arg_parser.get_str("mask");
    bool deterministic       = arg_parser.get_bool("deterministic");
    std::string init_method  = arg_parser.get_str("init");
    uint32_t seed            = arg_parser.get_uint32("seed");

    ck_tile::stream_config stream_config{nullptr,
                                         true,
                                         /* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0),
                                         arg_parser.get_int("warmup"),
                                         arg_parser.get_int("repeat"),
                                         arg_parser.get_str("timer") == std::string("gpu")};

    auto json = arg_parser.get_int("json") == 1
                    ? std::optional<std::string>{arg_parser.get_str("jsonfile")}
                    : std::nullopt;

    return fmha_bwd_run<DataTypeConfig>(mode,
                                        batch,
                                        nhead,
                                        nhead_k,
                                        seqlen_qs,
                                        seqlen_ks,
                                        hdim_q,
                                        hdim_v,
                                        i_perm,
                                        o_perm,
                                        scale,
                                        bias_str,
                                        use_dbias,
                                        p_drop,
                                        drop_seed,
                                        drop_offset,
                                        drop_prefs,
                                        mask_str,
                                        deterministic,
                                        init_method,
                                        seed,
                                        do_validation,
                                        stream_config,
                                        json);
}

int main(int argc, char* argv[])
{
    try
    {
        auto [result, arg_parser] = create_args(argc, argv);
        if(!result)
            return -1;

        const std::string data_type = arg_parser.get_str("prec");
        if(data_type == "fp16")
        {
            return run<FmhaBwdFp16>(arg_parser) == bwd_result::success ? 0 : -2;
        }
        else if(data_type == "bf16")
        {
            return run<FmhaBwdBf16>(arg_parser) == bwd_result::success ? 0 : -2;
        }
        std::cerr << "Unsupported precision: " << data_type << std::endl;
        return -1;
    }
    catch(const std::invalid_argument& e)
    {
        std::cerr << "Invalid argument: " << e.what() << std::endl;
        return -1;
    }
    catch(const std::exception& e)
    {
        std::cerr << "Error: " << e.what() << std::endl;
        return -2;
    }
}
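Per the commit message, the new tests built on this runner share one fixed seed (the example's -seed default is 11939) and allow overriding it with an environment variable. A hedged sketch of that convention; the variable name below is hypothetical, not the actual test code:

```python
import os

DEFAULT_TEST_SEED = 11939  # matches the example's -seed default above

def test_seed(env=None) -> int:
    """Fixed seed for reproducible tests, overridable via an environment
    variable (the name FMHA_TEST_SEED is made up for this sketch)."""
    env = os.environ if env is None else env
    return int(env.get("FMHA_TEST_SEED", DEFAULT_TEST_SEED))

assert test_seed({}) == 11939
assert test_seed({"FMHA_TEST_SEED": "42"}) == 42
```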
253
example/ck_tile/01_fmha/example_fmha_fwd.cpp
Normal file
@@ -0,0 +1,253 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "fmha_fwd_runner.hpp"
|
||||
|
||||
#include <string>
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "0:no validation, 2:cpu validation, 2:gpu validation(experimental)")
|
||||
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
|
||||
.insert("b", "2", "batch size")
|
||||
.insert("h", "8", "num of head, for q")
|
||||
.insert("h_k",
|
||||
"-1",
|
||||
"num of head, for k/v, -1 means equal to h\n"
|
||||
"if not equal to h, then this is GQA/MQA case")
|
||||
.insert("s",
|
||||
"3328",
|
||||
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
|
||||
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
|
||||
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
|
||||
"(group mode)")
|
||||
.insert("s_k",
|
||||
"-1",
|
||||
"seqlen_k (including new key/value), -1 means equal to s\n"
|
||||
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
|
||||
"(group mode)")
|
||||
.insert("s_knew",
|
||||
"0",
|
||||
"seqlen_k for new key/value, 0 means not to use this at all; "
|
||||
"-1 to choose s_knew in [1, s] randomly.")
|
||||
.insert("s_kpad",
|
||||
"-1",
|
||||
"seqlen_k stride between 2 batches, currently used in group-mode only\n"
|
||||
"for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
|
||||
"along seqlen, instead of packed, same as xformer kv_padding,\n"
|
||||
"must be greater than or equal to s_k")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("scale_s",
|
||||
"0",
|
||||
"scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
|
||||
"note when squant=1, this value will be modified by range_q/k")
|
||||
.insert("logits_soft_cap", "0", "attention logits soft capping value.")
|
||||
.insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.")
|
||||
.insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.")
|
||||
.insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.")
|
||||
.insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.")
|
||||
.insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.")
|
||||
.insert("squant",
|
||||
"auto",
|
||||
"if using static quantization fusion or not. auto: fp8 will default use squant, "
|
||||
"other will not\n"
|
||||
"0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to "
|
||||
"P and O.\n"
|
||||
"calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, "
|
||||
"range_p, range_o")
|
||||
.insert("iperm",
|
||||
"1",
|
||||
"permute input\n"
|
||||
"if true, will be b*h*s*d, else b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("bias",
|
||||
"n",
|
||||
"n or 0, no bias\n"
|
||||
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
|
||||
"a(libi) or 2, alibi with 1*h. a:1, b*h")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("mask",
|
||||
"0",
|
||||
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
|
||||
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
|
||||
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
|
||||
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
|
||||
            "'xt:window_size', xformer style masking from top-left, window_size negative is "
            "causal, positive is swa\n"
            "'xb:window_size', xformer style masking from bottom-r, window_size negative is "
            "causal, positive is swa\n"
            "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
            "now)")
        .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
        .insert("lse", "0", "0 do not store lse, 1 store lse")
        .insert("kname", "0", "if set to 1 will print kernel name")
        .insert("init",
                "uf",
                "init method:\n ui or 0 - uniform random int\n ni - normalized random int"
                "\n uf or 1 - uniform random float\n nf - normalized random float"
                "\n tf or 2 - trig float\n uf:q or ufq or 3 - fp8 quantization")
        .insert("seed",
                "11939",
                "random seed used for initializing input tensors. 0 for "
                "non-deterministic seed")
        .insert("p_drop", "0", "0~1 probability of dropout")
        .insert("drop_seed", "1", "seed for dropout random number generator")
        .insert("drop_offset", "0", "offset for dropout random number generator")
        .insert(
            "drop_prefs",
            "0",
            "whether dropout seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert(
            "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means do not apply RoPE")
        .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE")
        .insert("num_splits",
                "1",
                "# of splits for key/value. 0 to determine actual number by heuristic")
        .insert("page_block_size", "0", "paged-kvcache block size. 0 means do not use paged-kvcache")
        .insert("cache_batch_idx", "0", "whether to use an index map into the kvcache")
        .insert("warmup", "5", "number of warmup iterations before benchmarking the kernel")
        .insert("repeat", "20", "number of iterations to benchmark the kernel")
        .insert("json", "0", "0: no JSON output, 1: dump results in JSON format")
        .insert("jsonfile", "fmha_fwd.json", "json file name to dump results");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

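The `'t:l,r'` and `'b:l,r'` mask forms above encode a tag plus two comma-separated window sizes. As a simplified, hypothetical sketch of how such strings can be split in the spirit of `mask_info::decode` (the real parsing, including the `xt`/`xb` forms, lives in `mask.hpp`; `parse_window` is an illustrative helper, not part of the example):

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical helper: split a mask option such as "t:-1,0" into its two
// window sizes, throwing on malformed input like the updated decode() does.
std::pair<long, long> parse_window(const std::string& opt)
{
    const auto colon = opt.find(':');
    if(colon == std::string::npos)
        throw std::invalid_argument("missing ':' in mask option: " + opt);
    const std::string v = opt.substr(colon + 1);
    const auto comma = v.find(',');
    if(comma == std::string::npos)
        throw std::invalid_argument("missing ',' in mask option: " + opt);
    // std::stol throws on non-numeric input, unlike atoi which silently returns 0
    return {std::stol(v.substr(0, comma)), std::stol(v.substr(comma + 1))};
}
```

Using `std::stol` instead of `atoi` mirrors the hardening of the parsing code further down in this PR.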
template <typename DataTypeConfig>
auto run(const ck_tile::ArgParser& arg_parser)
{
    int do_validation = arg_parser.get_int("v");
    mode_enum mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
    ck_tile::index_t batch = arg_parser.get_int("b");
    ck_tile::index_t nhead = arg_parser.get_int("h");
    ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
    auto seqlen_qs = arg_parser.get_int_vec("s");
    auto seqlen_ks = arg_parser.get_int_vec("s_k");
    ck_tile::index_t hdim_q = arg_parser.get_int("d");
    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
    ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew");
    auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
    ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
    bool i_perm = arg_parser.get_bool("iperm");
    bool o_perm = arg_parser.get_bool("operm");
    float scale_s = arg_parser.get_float("scale_s");
    float logits_soft_cap = arg_parser.get_float("logits_soft_cap");
    bool is_v_rowmajor = arg_parser.get_str("vlayout") == "r";
    bool lse = arg_parser.get_bool("lse");
    ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
    bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
    std::string bias_str = arg_parser.get_str("bias");
    float p_drop = arg_parser.get_float("p_drop");
    uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
    uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
    bool drop_prefs = arg_parser.get_bool("drop_prefs");
    std::string mask_str = arg_parser.get_str("mask");
    float range_q = arg_parser.get_float("range_q");
    float range_k = arg_parser.get_float("range_k");
    float range_v = arg_parser.get_float("range_v");
    float range_p = arg_parser.get_float("range_p");
    float range_o = arg_parser.get_float("range_o");
    bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved");
    ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
    std::string init_method = arg_parser.get_str("init");
    uint32_t seed = arg_parser.get_uint32("seed");

    bool squant = [&]() {
        if(arg_parser.get_str("squant") == "auto")
            return std::is_same_v<DataTypeConfig, FmhaFwdFp8>;
        else
            return arg_parser.get_bool("squant");
    }();

    ck_tile::stream_config stream_config{nullptr,
                                         true,
                                         /* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0),
                                         arg_parser.get_int("warmup"),
                                         arg_parser.get_int("repeat"),
                                         arg_parser.get_str("timer") == std::string("gpu")};

    auto json = arg_parser.get_int("json") == 1
                    ? std::optional<std::string>{arg_parser.get_str("jsonfile")}
                    : std::nullopt;

    return fmha_fwd_run<DataTypeConfig>(mode,
                                        batch,
                                        nhead,
                                        nhead_k,
                                        seqlen_qs,
                                        seqlen_ks,
                                        hdim_q,
                                        hdim_v,
                                        seqlen_knew,
                                        seqlen_kpads,
                                        rotary_dim,
                                        i_perm,
                                        o_perm,
                                        scale_s,
                                        logits_soft_cap,
                                        is_v_rowmajor,
                                        lse,
                                        page_block_size,
                                        use_cache_batch_idx,
                                        bias_str,
                                        p_drop,
                                        drop_seed,
                                        drop_offset,
                                        drop_prefs,
                                        mask_str,
                                        range_q,
                                        range_k,
                                        range_v,
                                        range_p,
                                        range_o,
                                        squant,
                                        is_rotary_interleaved,
                                        num_splits,
                                        init_method,
                                        seed,
                                        do_validation,
                                        stream_config,
                                        json);
}

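The `squant` default above uses an immediately-invoked lambda so the `"auto"` case and the explicit flag fold into a single initializer. The same idiom in isolation (`resolve_squant` and `is_fp8` are illustrative names standing in for the real parser call and `std::is_same_v<DataTypeConfig, FmhaFwdFp8>`):

```cpp
#include <cassert>
#include <string>

// Sketch of the immediately-invoked lambda idiom used for the "squant" option:
// resolve "auto" to a type-dependent default, otherwise honor the explicit value.
bool resolve_squant(const std::string& opt, bool is_fp8)
{
    return [&]() {
        if(opt == "auto")
            return is_fp8;     // fp8 runs default to static quantization
        else
            return opt == "1"; // explicit 0/1 overrides the default
    }();
}
```

This keeps `squant` a plain `bool` with no mutable intermediate state.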
int main(int argc, char* argv[])
{
    try
    {
        auto [result, arg_parser] = create_args(argc, argv);
        if(!result)
            return -1;

        const std::string data_type = arg_parser.get_str("prec");
        if(data_type == "fp16")
        {
            return run<FmhaFwdFp16>(arg_parser) == fwd_result::success ? 0 : -2;
        }
        else if(data_type == "bf16")
        {
            return run<FmhaFwdBf16>(arg_parser) == fwd_result::success ? 0 : -2;
        }
        else if(data_type == "fp8")
        {
            return run<FmhaFwdFp8>(arg_parser) == fwd_result::success ? 0 : -2;
        }
        std::cerr << "Unsupported precision: " << data_type << std::endl;
        return -1;
    }
    catch(const std::invalid_argument& e)
    {
        std::cerr << "Invalid argument: " << e.what() << std::endl;
        return -1;
    }
    catch(const std::exception& e)
    {
        std::cerr << "Error: " << e.what() << std::endl;
        return -2;
    }
}

@@ -1,9 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "fmha_bwd.hpp"
#pragma once

#include "ck_tile/host.hpp"
#include "mask.hpp"
#include "fmha_bwd.hpp"
#include "utils.hpp"
#include "ck_tile/utility/json_dump.hpp"

@@ -17,91 +18,13 @@
#include <utility>
#include <vector>

template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    using size_type = typename std::vector<T>::size_type;

    os << "[";
    for(size_type idx = 0; idx < v.size(); ++idx)
    {
        if(0 < idx)
        {
            os << ", ";
        }
        os << v[idx];
    }
    return os << "]";
}

enum class bwd_result
{
    success,
    failure,
    invalid_args,
    no_instance,
};

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "whether to do CPU validation or not")
        .insert("mode", "0", "kernel mode. 0:batch, 1:group")
        .insert("b", "2", "batch size")
        .insert("h", "8", "num of heads, for q")
        .insert("h_k",
                "-1",
                "num of heads, for k/v, -1 means equal to h\n"
                "if not equal to h, then this is the GQA/MQA case")
        .insert("s",
                "3328",
                "seqlen_q. if group-mode, means the average value of seqlen_q\n"
                "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary")
        .insert("s_k", "-1", "seqlen_k, -1 means equal to s")
        .insert("d", "128", "head dim for q, k")
        .insert("d_v", "-1", "head dim for v, -1 means equal to d")
        .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
        .insert("iperm",
                "1",
                "permute input\n"
                "if true, will be b*h*s*d, else b*s*h*d")
        .insert("operm", "1", "permute output")
        .insert("bias",
                "n",
                "n or 0, no bias\n"
                "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
                "a(libi) or 2, alibi with 1*h. a:1, b*h")
        .insert("dbias", "0", "output bias gradient or not")
        .insert("prec", "fp16", "data type. fp16 or bf16")
        .insert("mask",
                "0",
                "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
                "'t', top-left causal mask, 'b', bottom-r causal mask\n"
                "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
                "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
                "'xt:window_size', xformer style masking from top-left, window_size negative is "
                "causal, positive is swa\n"
                "'xb:window_size', xformer style masking from bottom-r, window_size negative is "
                "causal, positive is swa\n"
                "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
                "now)")
        .insert("kname", "0", "if set to 1 will print kernel name")
        .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float")
        .insert("seed",
                "11939",
                "random seed used for initializing input tensors. 0 for "
                "non-deterministic seed")
        .insert("p_drop", "0", "0~1 probability of dropout")
        .insert("drop_seed", "1", "seed for dropout random number generator")
        .insert("drop_offset", "0", "offset for dropout random number generator")
        .insert("drop_prefs",
                "0",
                "whether dropout seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert("warmup", "5", "number of warmup iterations before benchmarking the kernel")
        .insert("repeat", "20", "number of iterations to benchmark the kernel")
        .insert("deterministic",
                "0",
                "if set to 1 will use multi-buffer reduction strategy for dq, atomic operation "
                "will not be used")
        .insert("json", "0", "0: no JSON output, 1: dump results in JSON format")
        .insert("jsonfile", "fmha_bwd.json", "json file name to dump results");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

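The `bwd_result` enum above lets callers tell "kernel produced wrong results" apart from "no instance was built for this configuration", which the old `bool` return could not express. A small sketch of how a caller (for example one of the new gtests) might dispatch on it; `bwd_result` is repeated here and `describe` is a hypothetical helper so the snippet is self-contained:

```cpp
#include <cassert>
#include <string>

// Mirrors the enum defined in the runner header.
enum class bwd_result { success, failure, invalid_args, no_instance };

// Hypothetical mapping a test harness could use: a missing instance is a
// skip, not a failure.
std::string describe(bwd_result r)
{
    switch(r)
    {
    case bwd_result::success: return "pass";
    case bwd_result::failure: return "validation failed";
    case bwd_result::invalid_args: return "invalid arguments";
    case bwd_result::no_instance: return "skip: no kernel instance";
    }
    return "unknown";
}
```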
// different threshold for different dtype
template <typename DataTypeConfig>
@@ -125,57 +48,82 @@ auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
    return ck_tile::make_tuple(rtol, atol);
}

extern template float fmha_bwd<2>(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);

template <typename DataTypeConfig>
bool run(const ck_tile::ArgParser& arg_parser)
bwd_result fmha_bwd_run(mode_enum mode,
                        ck_tile::index_t batch,
                        ck_tile::index_t nhead,
                        ck_tile::index_t nhead_k,
                        std::vector<ck_tile::index_t> seqlen_qs,
                        std::vector<ck_tile::index_t> seqlen_ks,
                        ck_tile::index_t hdim_q,
                        ck_tile::index_t hdim_v,
                        bool i_perm,
                        bool o_perm,
                        float scale,
                        std::string bias_str,
                        bool use_dbias,
                        float p_drop,
                        uint64_t drop_seed,
                        uint64_t drop_offset,
                        bool drop_prefs,
                        std::string mask_str,
                        bool deterministic,
                        std::string init_method,
                        uint32_t seed,
                        int do_validation,
                        const ck_tile::stream_config& stream_config,
                        std::optional<std::string> json = std::nullopt)
{
    std::string data_type = arg_parser.get_str("prec");
    int do_validation = arg_parser.get_int("v");
    auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
    ck_tile::index_t batch = arg_parser.get_int("b");
    ck_tile::index_t nhead = arg_parser.get_int("h");
    ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
    const std::string data_type = []() {
        if constexpr(std::is_same_v<DataTypeConfig, FmhaBwdFp16>)
            return "fp16";
        else if constexpr(std::is_same_v<DataTypeConfig, FmhaBwdBf16>)
            return "bf16";
        else
            static_assert(false);
    }();

    if(nhead_k < 0)
        nhead_k = nhead;

    if(nhead % nhead_k != 0)
    {
        std::cerr << "nhead:" << nhead << " must be a multiple of nhead_k:" << nhead_k << std::endl;
        return false;
        return bwd_result::invalid_args;
    }

    ck_tile::index_t seqlen_q = arg_parser.get_int("s");
    ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
    if(seqlen_k < 0)
        seqlen_k = seqlen_q;
    ck_tile::index_t hdim_q = arg_parser.get_int("d");
    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
    std::mt19937 random_engine(seed != 0 ? seed : std::random_device{}());
    auto next_seed = [&random_engine]() { return static_cast<unsigned int>(random_engine()); };

    if(hdim_v < 0)
        hdim_v = hdim_q;

    bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
    bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim

    float scale = arg_parser.get_float("scale");
    if(scale == .0f)
        scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));

    bias_info bias = bias_info::decode(arg_parser.get_str("bias"));
    bool use_dbias = arg_parser.get_bool("dbias");
    float p_drop = arg_parser.get_float("p_drop");
    uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
    uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
    bool drop_prefs = arg_parser.get_bool("drop_prefs");
    bias_info bias = bias_info::decode(bias_str);

    if(use_dbias && bias.type != bias_enum::elementwise_bias)
    {
        std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
        return false;
        return bwd_result::invalid_args;
    }
    std::vector<ck_tile::index_t> seqlen_kpads;
    std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) =
        generate_missing_seqlens(mode, batch, seqlen_qs, seqlen_ks, {}, 0, false, random_engine);
    ck_tile::ignore = seqlen_kpads;
#if 0
    std::cout << "seqlen_qs: " << seqlen_qs << std::endl;
    std::cout << "seqlen_ks: " << seqlen_ks << std::endl;
#endif

    mask_info mask = mask_info::decode(mask_str, seqlen_qs[0], seqlen_ks[0]);

    if(p_drop < 0.0f || p_drop > 1.0f)
    {
        std::cerr << "The value of p_drop should be within 0~1" << std::endl;
        return false;
        return bwd_result::invalid_args;
    }
    float p_undrop = 1.0 - p_drop;
    uint8_t p_undrop_in_uint8_t =
@@ -188,29 +136,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
        s_randval = true;
    }

    mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k);

    int init_method = arg_parser.get_int("init");
    std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
    if(*seed == 0)
    {
        seed.reset();
    }

    int stream_warmup = arg_parser.get_int("warmup");
    int stream_repeat = arg_parser.get_int("repeat");
    bool kname = arg_parser.get_bool("kname");
    bool deterministic = arg_parser.get_bool("deterministic");

    ck_tile::stream_config stream_config{nullptr,
                                         true,
                                         /* log_level = */ (kname ? 1 : 0),
                                         stream_warmup,
                                         stream_repeat,
                                         arg_parser.get_str("timer") == std::string("gpu")};

    const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
    const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k);
    const auto seqstart_q_host = to_seqstarts(seqlen_qs);
    const auto seqstart_k_host = to_seqstarts(seqlen_ks);

    using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;

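The `to_seqstarts()` calls above turn per-batch sequence lengths into the seqstart offset arrays that group mode hands to the kernel. A plausible sketch of that helper (the actual implementation lives in `utils.hpp` and may differ): it is an inclusive prefix sum with a leading 0, so entry `b` is where batch `b` begins and the last entry is the total length.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Plausible equivalent of to_seqstarts(): prefix-sum the per-batch lengths,
// yielding n+1 offsets for n batches.
std::vector<int32_t> to_seqstarts(const std::vector<int32_t>& seqlens)
{
    std::vector<int32_t> seqstarts{0};
    for(int32_t len : seqlens)
        seqstarts.push_back(seqstarts.back() + len); // running total
    return seqstarts;
}
```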
@@ -283,10 +210,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
    // host memory for storing all the tensor elements
    const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
    const ck_tile::index_t shape_seqlen_q =
        (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
        (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
    const ck_tile::index_t shape_seqlen_k =
        (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
    const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64;
        (mode == mode_enum::batch ? seqlen_ks[0] : seqstart_k_host.back());
    // Keep it equal to or smaller than minimal bn0 of all tiles in fmha_bwd.py
    // TODO: add API for requesting kN0/nsplits/workspace_size? It is not safe to rely on internal
    // implementation details in client code.
    const ck_tile::index_t kN0 = 16;
    const ck_tile::index_t nsplits =
        deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;

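The `nsplits` computation above in plain arithmetic: the old code assumed `kN0 = (hdim_q <= 128) ? 128 : 64`, which under-allocated `dq_acc_buf` for kernels with smaller tiles; the fix pins `kN0` to 16, at or below the smallest `bn0` of any tile, so the split count is always sufficient. A minimal sketch, assuming `integer_divide_ceil` is ordinary ceiling division:

```cpp
#include <cassert>

// Ceiling division, as assumed for ck_tile::integer_divide_ceil.
int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

// Number of per-split dq accumulation buffers: one per kN0-sized key block in
// deterministic mode, a single buffer otherwise.
int compute_nsplits(bool deterministic, int max_seqlen_k, int kN0 = 16)
{
    return deterministic ? integer_divide_ceil(max_seqlen_k, kN0) : 1;
}
```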
@@ -331,23 +261,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
            ? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
            : std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});

    if(init_method == 0)
    if(init_method == "ui" || init_method == "0")
    {
        ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
        ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
        ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
        ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
        ck_tile::FillUniformDistributionIntegerValue<OGradDataType>{-2.f, 2.f, seed}(do_host);
        ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, next_seed()}(q_host);
        ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, next_seed()}(k_host);
        ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, next_seed()}(v_host);
        ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, next_seed()}(
            bias_host);
        ck_tile::FillUniformDistributionIntegerValue<OGradDataType>{-2.f, 2.f, next_seed()}(
            do_host);
    }
    else if(init_method == 1)
    else if(init_method == "uf" || init_method == "1")
    {
        ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
        ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
        ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
        ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
        ck_tile::FillUniformDistribution<OGradDataType>{0.f, 1.f, seed}(do_host);
        ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, next_seed()}(q_host);
        ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, next_seed()}(k_host);
        ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, next_seed()}(v_host);
        ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, next_seed()}(bias_host);
        ck_tile::FillUniformDistribution<OGradDataType>{0.f, 1.f, next_seed()}(do_host);
    }
    else if(init_method == 2)
    else if(init_method == "tf" || init_method == "2")
    {
        ck_tile::FillTrigValue<QDataType>{}(q_host);
        ck_tile::FillTrigValue<KDataType>{}(k_host);
@@ -355,6 +287,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
        ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
        ck_tile::FillTrigValue<OGradDataType>{}(do_host);
    }
    else
    {
        std::cerr << "Unknown value for init argument: " << init_method << std::endl;
        return bwd_result::invalid_args;
    }

    if(bias.type == bias_enum::alibi)
    {
        auto slopes = ck_tile::get_alibi_slopes<AccDataType>(nhead);
@@ -415,22 +353,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
        else return layout_str(iperm_) + std::string("-") + layout_str(operm_);
    };
    // clang-format on
    const std::string prec = arg_parser.get_str("prec");

    std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
              << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
              << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
              << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval
              << ", deterministic:" << deterministic << ", mask:" << mask << std::flush;
    const std::size_t workspace_size_in_megabytes =
        ck_tile::integer_divide_ceil(dq_acc_host.get_element_space_size_in_bytes(), 1024 * 1024);

    std::size_t workspace_size =
        dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024);

    if(deterministic == 1)
    {
        std::cout << "\nDeterministic mode ON: " << workspace_size
                  << " MByte memory workspace allocated" << std::endl;
    }
    std::cout << "[" << data_type << "|" << mode << "|" << io_layout(i_perm, o_perm)
              << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0]
              << "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale
              << ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop
              << ", s_randval:" << s_randval << ", deterministic:" << deterministic
              << (deterministic ? std::string(", workspace:") +
                                      std::to_string(workspace_size_in_megabytes) + "MiB"
                                : "")
              << ", mask:" << mask << std::flush;

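The workspace accounting above is worth spelling out: the removed line multiplied a byte count by `sizeof(AccDataType)` again, overstating the figure, while the new line rounds the raw byte count up to whole MiB. A sketch of that arithmetic, assuming dq_acc holds one fp32 accumulator per `(split, batch, head, seqlen_q, hdim_q)` element (the tensor layout is taken from the shapes above):

```cpp
#include <cassert>
#include <cstddef>

// Total bytes of the dq accumulation workspace (one float per element).
std::size_t dq_acc_bytes(std::size_t nsplits, std::size_t batch, std::size_t nhead,
                         std::size_t seqlen_q, std::size_t hdim_q)
{
    return nsplits * batch * nhead * seqlen_q * hdim_q * sizeof(float);
}

// Round bytes up to whole MiB, mirroring integer_divide_ceil(bytes, 1024*1024).
std::size_t to_mib_rounded_up(std::size_t bytes)
{
    return (bytes + (1024 * 1024 - 1)) / (1024 * 1024);
}
```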
    auto fmha_traits = fmha_bwd_traits{hdim_q,
                                       hdim_v,
@@ -443,7 +378,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
                                       s_randval,
                                       deterministic};
    auto fmha_args = [&]() {
        assert(nhead % nhead_k == 0);
        /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
        /// seqlen_k] in this example, hence both the 'batch_stride_bias' &
        /// 'nhead_stride_bias' are 0.
@@ -572,20 +506,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
                         drop_seed_offset};
    }();

    float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
    const float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
    if(ave_time < 0)
    {
        std::cout << ", not supported yet" << std::flush << std::endl;
        return false;
        return bwd_result::no_instance;
    }

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
              << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
              << " GB/s" << std::flush;
    const float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
    const float gb_per_sec = num_byte / 1.E6 / ave_time;
    if(stream_config.time_kernel_)
    {
        std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
                  << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2)
                  << gb_per_sec << " GB/s" << std::flush;
    }

    bool pass = true;
    if(!do_validation)
@@ -635,17 +570,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
        ck_tile::index_t nr = nhead / nhead_k;

        // clang-format off
        // permute
        if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); });
        else       q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });

        if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
        else       k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });

        // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
        if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); });
        // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
        else       v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); });
        // clang-format on

        // reference
@@ -760,18 +695,40 @@ bool run(const ck_tile::ArgParser& arg_parser)
                    real_seqlen_k,
                    mask.type == mask_enum::mask_top_left));
        }
        const ck_tile::HostTensor<AccDataType> masked_s_host_ref = s_host_ref;
        ck_tile::reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
            s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref);

        if(p_drop > 0)
        {
            p_dropped_hp_host_ref = p_hp_host_ref;
            randval_host_ref.ForEach([&](auto& self, auto idx) {
                self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
            });
            ck_tile::reference_batched_dropout_randval(
                randval_host_ref, wb, drop_seed, drop_offset);
            ck_tile::reference_batched_dropout(
                p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
            p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();

            ck_tile::HostTensor<RandValOutputDataType> randval_host_result(
                {nhead, real_seqlen_q, real_seqlen_k});
            randval_host_result.ForEach([&](auto& self, const auto& idx) {
                self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
            });
            masked_s_host_ref.ForEach([&](const auto& self, const auto& idx) {
                // Ignore all masked values in the validation check
                if(std::isinf(self(idx)))
                {
                    randval_host_ref(idx) = 0;
                    randval_host_result(idx) = 0;
                }
            });
            bool cur_pass = ck_tile::check_err(randval_host_result,
                                               randval_host_ref,
                                               "DROPOUT RANDVAL Error: Incorrect results!");
            pass &= cur_pass;
            if(!cur_pass)
            {
                break;
            }
        }
        else
        {
@@ -783,11 +740,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
            p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n

        // clang-format off
        // permute
        if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
        else       o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });

        lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); });
        // clang-format on

        q_host_refs.push_back(q_host_ref);
@@ -816,8 +773,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
        dbias_buf.SetZero();
        dq_acc_buf.SetZero();

        ck_tile::stream_config stream_config_v{
            nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
        ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
        fmha_bwd(fmha_traits, fmha_args, stream_config_v);

        dq_buf.FromDevice(dq_host.data());
@@ -855,8 +811,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
            {nhead, real_seqlen_k, hdim_v}); // dv_g_n_o

        // clang-format off
        if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); });
        else       do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[1] + query_offset, i[0], i[2]); });
        // clang-format on

        // dP = dO@V x Z w/ dropout
@@ -934,21 +890,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
            {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n

        // clang-format off
        // permute
        if(i_perm) dq_host_result.ForEach([&](auto& self, auto idx) { self(idx) = dq_host(b, idx[0], idx[1] + query_offset, idx[2]); });
        else       dq_host_result.ForEach([&](auto& self, auto idx) { self(idx) = dq_host(b, idx[1] + query_offset, idx[0], idx[2]); });

        if(i_perm) dk_host_result.ForEach([&](auto& self, auto idx) { self(idx) = dk_host(b, idx[0], idx[1] + key_offset, idx[2]); });
        else       dk_host_result.ForEach([&](auto& self, auto idx) { self(idx) = dk_host(b, idx[1] + key_offset, idx[0], idx[2]); });

        if(i_perm) dv_host_result.ForEach([&](auto& self, auto idx) { self(idx) = dv_host(b, idx[0], idx[1] + key_offset, idx[2]); });
        else       dv_host_result.ForEach([&](auto& self, auto idx) { self(idx) = dv_host(b, idx[1] + key_offset, idx[0], idx[2]); });

        if(use_dbias)
        {
            if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) { self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); });
            else       dbias_host_result.ForEach([&](auto& self, auto idx) { self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); });
        }
        // clang-format on

        auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
@@ -994,10 +950,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }

    if(arg_parser.get_int("json") == 1)
    if(json)
    {
        dump_fmha_bwd_json_results(
            arg_parser.get_str("jsonfile"),
            *json,
            data_type,
            mode == mode_enum::batch ? "batch" : "group",
            i_perm ? "true" : "false",
@@ -1005,8 +961,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
            batch,
            nhead,
            nhead_k,
            seqlen_q,
            seqlen_k,
            seqlen_qs[0],
            seqlen_ks[0],
            hdim_q,
            hdim_v,
            scale,
@@ -1027,30 +983,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
                                : "mask_generic"))),
            mask.left,
            mask.right,
            workspace_size,
            workspace_size_in_megabytes,
            pass,
            ave_time,
            tflops,
            gb_per_sec);
    }
    return pass;
}

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    const std::string data_type = arg_parser.get_str("prec");
    if(data_type == "fp16")
    {
        return run<FmhaBwdFp16>(arg_parser) ? 0 : -2;
    }
    else if(data_type == "bf16")
    {
        return run<FmhaBwdBf16>(arg_parser) ? 0 : -2;
    }

    return -3;

    return pass ? bwd_result::success : bwd_result::failure;
}
File diff suppressed because it is too large
@@ -39,6 +39,7 @@ struct mask_info
os << "g(" << y << ":" << x << ")";
}
}

static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
{
ck_tile::index_t x_total = seqlen_k;
@@ -54,7 +55,7 @@ struct mask_info
if(t == "xt" || t == "xb")
{
// xformer style sliding window attn from top-left
ck_tile::index_t window_size = atoi(v.c_str());
ck_tile::index_t window_size = std::stoi(v);
ck_tile::index_t left_size = -1;
ck_tile::index_t right_size = 0;
if(window_size > 0)
@@ -71,18 +72,15 @@ struct mask_info
tmp.left = left_size;
tmp.right = right_size;
}
else
else if(t == "t" || t == "b" || t == "g")
{
auto found_1 = v.find(",");
if(found_1 == std::string::npos)
{
printf("not supported value %s, %s\n", v.c_str(), str.c_str());
assert(0);
throw std::invalid_argument("invalid mask value: " + str);
}
tmp.type = mask_enum::window_generic;
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
// TODO: some validation
ck_tile::index_t v0 = std::stoi(v.substr(0, found_1));
ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1));
if(t == "t")
{
tmp.type = mask_enum::mask_top_left;
@@ -105,53 +103,45 @@ struct mask_info
}
else if(t == "g")
{
tmp.type = mask_enum::window_generic;
tmp.y = v0;
tmp.x = v1;
tmp.left = v0; // TODO: don't use this?
tmp.right = v1;
}
else
{
printf("not supported type %s, %s\n", t.c_str(), str.c_str());
assert(0);
}
}
else
{
throw std::invalid_argument("invalid mask value: " + str);
}
}
else if(str == "0")
{
tmp.type = mask_enum::no_mask;
}
else if(str == "1" || str == "t")
{
tmp.type = mask_enum::mask_top_left;
tmp.y = seqlen_q;
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
}
else if(str == "2" || str == "b")
{
tmp.type = mask_enum::mask_bottom_right;
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
}
else
{
auto set_causal_top_left = [&]() {
tmp.type = mask_enum::mask_top_left;
tmp.y = seqlen_q;
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
};
auto set_causal_bottom_right = [&]() {
tmp.type = mask_enum::mask_bottom_right;
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
};
if(str == "t")
set_causal_top_left();
else if(str == "b")
set_causal_bottom_right();
else
{
tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
if(tmp.type == mask_enum::mask_top_left)
{
set_causal_top_left();
}
else if(tmp.type == mask_enum::mask_bottom_right)
{
set_causal_bottom_right();
}
}
throw std::invalid_argument("invalid mask value: " + str);
}
return tmp;
}

ck_tile::index_t get_unmaskarea() const
{
if(type == mask_enum::no_mask)
@@ -168,6 +158,7 @@ struct mask_info
}
return area;
}

friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
{
mi.serialize(os);

@@ -1,11 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <optional>
#include <ostream>
@@ -28,6 +27,23 @@ std::ostream& operator<<(std::ostream& stream, mode_enum mode)
return stream << (mode == mode_enum::batch ? "batch" : "group");
}

template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
using size_type = typename std::vector<T>::size_type;

os << "[";
for(size_type idx = 0; idx < v.size(); ++idx)
{
if(0 < idx)
{
os << ", ";
}
os << v[idx];
}
return os << "]";
}

std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
{
std::vector<int32_t> seqstarts = {0};
@@ -39,12 +55,13 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
return seqstarts;
}

template <typename RandomEngine>
std::vector<int32_t> generate_seqlens(mode_enum mode,
unsigned count,
int32_t seqlen_avg,
int32_t seqlen_min = -1, // if not negative, clamp min
int32_t seqlen_max = -1, // if not negative, clamp max
std::optional<unsigned> seed = std::nullopt)
int32_t seqlen_min, // if not negative, clamp min
int32_t seqlen_max, // if not negative, clamp max
RandomEngine& random_engine)
{
assert(0 < count);

@@ -58,7 +75,6 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
{
using size_type = std::vector<int32_t>::size_type;

std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<size_type> idx_dist(0, count - 1);
auto next_idx = std::bind(idx_dist, std::ref(random_engine));

@@ -89,43 +105,31 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
return seqlens;
}

std::vector<int32_t> generate_seqstarts(mode_enum mode,
unsigned count,
int32_t seqlen_avg,
int32_t seqlen_min = -1,
int32_t seqlen_max = -1,
std::optional<unsigned> seed = std::nullopt)
{
return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed));
}

// return random integer generated uniformly in range [low, high]
template <typename Int = int>
auto randint(Int low, Int high, std::optional<unsigned> seed = std::nullopt)
-> std::enable_if_t<std::is_integral_v<Int>, Int>
template <typename Int = int, typename RandomEngine>
auto randint(Int low,
Int high,
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>, Int>
{
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<Int> dist(low, high);
return dist(engine);
return dist(random_engine);
}

// return random integers generated uniformly in range [low, high]
template <typename Int, typename ForwardIterator>
template <typename Int, typename ForwardIterator, typename RandomEngine>
auto randints(ForwardIterator first,
ForwardIterator last,
Int low,
Int high,
std::optional<unsigned> seed = std::nullopt)
-> std::enable_if_t<std::is_integral_v<Int>>
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>>
{
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<Int> dist(low, high);

std::generate(first, last, [&] { return dist(engine); });
std::generate(first, last, [&] { return dist(random_engine); });
}

/*
* decode the seqlen string from cmdline
* generate missing values in *_val randomly when the number of values is smaller than batch
* example (assume batch=3)
* q_val=1,2,3 k_val=4,5,6 -> OK
* q_val=1,2,3 -> OK, k same as q
@@ -136,23 +140,23 @@ auto randints(ForwardIterator first,
* q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q
* q_val=1,2 k_val=4 -> not OK, k must have same splits with q
*/
template <typename RandomEngine>
std::tuple<std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>>
decode_seqlen(mode_enum mode,
ck_tile::index_t batch,
std::string q_val,
std::string k_val,
std::string k_pad_val,
ck_tile::index_t seqlen_k_min = 0,
bool need_append_kvcache = false,
std::optional<unsigned> seed = std::nullopt)
generate_missing_seqlens(mode_enum mode,
ck_tile::index_t batch,
const std::vector<ck_tile::index_t>& q_val,
const std::vector<ck_tile::index_t>& k_val,
const std::vector<ck_tile::index_t>& k_pad_val,
ck_tile::index_t seqlen_k_min,
bool need_append_kvcache,
RandomEngine& random_engine)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
if(mode == mode_enum::batch)
{
ck_tile::index_t q = _S2I_(q_val);
ck_tile::index_t k = _S2I_(k_val);
ck_tile::index_t q = q_val[0];
ck_tile::index_t k = k_val[0];

auto s_q = std::vector<ck_tile::index_t>(batch, q);
auto s_k = [&] {
@@ -166,7 +170,7 @@ decode_seqlen(mode_enum mode,
seqlen_ks.end(),
seqlen_k_min,
seqlen_k_max,
seed);
random_engine);
return seqlen_ks;
}

@@ -187,25 +191,19 @@ decode_seqlen(mode_enum mode,
}
else
{
ck_tile::index_t idx = 0;
std::string::size_type pos_q = 0;
std::string::size_type pos_k = 0;
std::string::size_type pos_kp = 0;
std::vector<ck_tile::index_t> s_q;
std::vector<ck_tile::index_t> s_k;
std::vector<ck_tile::index_t> s_kpad;
while(true)
ck_tile::index_t idx = 0;
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
{
auto found_q = q_val.find(',', pos_q);
auto found_k = k_val.find(',', pos_k);
auto found_kp = k_pad_val.find(',', pos_kp);

ck_tile::index_t q = _S2I_(
q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q));
ck_tile::index_t k = _S2I_(
k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k));
ck_tile::index_t kp = _S2I_(k_pad_val.substr(
pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp));
ck_tile::index_t q = q_val[idx];
ck_tile::index_t k =
k_val[std::min(idx, static_cast<ck_tile::index_t>(k_val.size()) - 1)];
ck_tile::index_t kp =
k_pad_val.empty()
? -1
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];

s_q.push_back(q);
s_k.push_back(k < 0 ? q : k);
@@ -219,21 +217,13 @@ decode_seqlen(mode_enum mode,
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}

idx++;
if(found_q == std::string::npos || idx >= batch)
{
break;
}
pos_q = found_q + 1;
pos_k = found_k == std::string::npos ? pos_k : found_k + 1;
pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1;
}
if(idx < batch)
{
auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed);
auto rem_k =
generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed);
auto rem_q =
generate_seqlens(mode, batch - idx, s_q.back(), 1, s_q.back(), random_engine);
auto rem_k = generate_seqlens(
mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), random_engine);

s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
@@ -241,26 +231,14 @@ decode_seqlen(mode_enum mode,
}
return std::make_tuple(s_q, s_k, s_kpad);
}
#undef _S2I_
}

int env_get_int(const char* var_name, int default_int)
{
char* v = getenv(var_name);
int r = default_int;
if(v)
r = std::atoi(v);
return r;
}

template <typename RandomAccessIterator, typename Int>
template <typename RandomAccessIterator, typename Int, typename RandomEngine>
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
RandomAccessIterator last,
Int value,
std::optional<unsigned> seed = std::nullopt)
RandomEngine& random_engine)
{
std::iota(first, last, value);

std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::shuffle(first, last, engine);
std::shuffle(first, last, random_engine);
}

@@ -399,9 +399,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
}
mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa
}
// The value is smaller than min f8 denormal and results in zero (the early exit also prevents
// The value is <= than min f8 denormal/2 and results in zero (the early exit also prevents
// an undefined behavior of bit shifts >= type width).
if(exponent_diff > DstT_mant)
if(exponent_diff > DstT_mant + 1)
{
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
}

@@ -18,6 +18,7 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
#include "ck_tile/host/reference/reference_batched_dropout_randval.hpp"
#include "ck_tile/host/reference/reference_batched_elementwise.hpp"
#include "ck_tile/host/reference/reference_batched_gemm.hpp"
#include "ck_tile/host/reference/reference_batched_masking.hpp"

@@ -0,0 +1,70 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {

template <typename RandValOutputDataType>
CK_TILE_HOST void
reference_batched_dropout_randval(HostTensor<RandValOutputDataType>& randval_b_m_n,
index_t batch,
uint64_t drop_seed,
uint64_t drop_offset)
{
const index_t nhead = randval_b_m_n.mDesc.get_lengths()[0];
const index_t real_seqlen_q = randval_b_m_n.mDesc.get_lengths()[1];
const index_t real_seqlen_k = randval_b_m_n.mDesc.get_lengths()[2];

static_assert(std::is_same_v<RandValOutputDataType, uint8_t>);

// BlockDropout generates random numbers by 32x32 tiles. Even when warp gemm 16x16 is used, the
// order of values in the bigger 32x32 tile must be the same because fwd and bwd may use
// different warp gemms (16x16 or 32x32).
// To compute 32x32 tiles, WarpGemmMfmaF16F16F32M32N32K16SwizzleA is used. It is
// WarpGemmAttributeMfmaImplF16F16F32M32N32K8 with SFactor = 2 (swizzling factor).
// Matrix element to register mapping for WarpGemmAttributeMfmaImplF16F16F32M32N32K8:
// C i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4)
// C j: (lane % 32)
// With SFactor = 2 it becomes:
// C i: (16 * floor(GPR_num / 8) % 32) + 8 * floor(lane / 32) + (GPR_num % 8)
// C j: (lane % 32)

constexpr index_t max_warp_size = 64;
constexpr index_t warp_gemm_mn = 32;

const index_t rows = integer_divide_ceil(real_seqlen_q, warp_gemm_mn);
const index_t cols = integer_divide_ceil(real_seqlen_k, warp_gemm_mn);

auto f = [&](index_t i_h, index_t row, index_t col) {
uint2 rowcol = make_uint2(row, col);
for(index_t lane = 0; lane < max_warp_size; lane++)
{
philox ph(drop_seed, drop_offset + (batch * nhead + i_h) * max_warp_size + lane);

uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));

for(auto r = 0; r < 16; r++)
{
index_t i = (16 * (r / 8) % 32) + 8 * (lane / 32) + (r % 8);
index_t j = (lane % 32);
index_t m = row * warp_gemm_mn + i;
index_t n = col * warp_gemm_mn + j;

if(m < real_seqlen_q && n < real_seqlen_k)
{
randval_b_m_n(i_h, m, n) = random_uint8_t[r];
}
}
}
};

make_ParallelTensorFunctor(f, nhead, rows, cols)(std::thread::hardware_concurrency());
}

} // namespace ck_tile

@@ -611,7 +611,7 @@ void dump_fmha_fwd_json_results(const std::string& json_filename,
float p_drop,
bool lse,
bool squant,
const std::string& bais,
const std::string& bias,
const std::string& vlayout,
bool pass,
float ave_time,
@@ -636,7 +636,7 @@ void dump_fmha_fwd_json_results(const std::string& json_filename,
ADD_KEY_VALUE("p_drop", p_drop);
ADD_KEY_VALUE("lse", lse);
ADD_KEY_VALUE("squant", squant);
ADD_KEY_VALUE("bias", bais);
ADD_KEY_VALUE("bias", bias);
ADD_KEY_VALUE("vlayout", vlayout);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)

@@ -38,6 +38,11 @@ set(REGRESSION_TESTS
test_conv_tensor_rearrange
test_gemm_mx
test_ck_tile_batched_transpose
test_ck_tile_fmha_bwd_bf16
test_ck_tile_fmha_bwd_fp16
test_ck_tile_fmha_fwd_bf16
test_ck_tile_fmha_fwd_fp16
test_ck_tile_fmha_fwd_fp8
)

function(add_test_executable TEST_NAME)

@@ -25,3 +25,4 @@ add_subdirectory(utility)
add_subdirectory(reduce)
add_subdirectory(epilogue)
add_subdirectory(atomic_add_op)
add_subdirectory(fmha)

@@ -94,7 +94,9 @@ TYPED_TEST(ConvertTest, ToFp8)
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'0000'000);
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::denorm_min()), 0b1'0000'000);

// All values smaller than min f8 subnormal must be converted to f8 zero
// All values <= min f8 subnormal/2 must be converted to f8 zero
EXPECT_EQ(c(+0.001953125f * 0.6f), 0b0'0000'001);
EXPECT_EQ(c(-0.001953125f * 0.6f), 0b1'0000'001);
constexpr int src_min_subnorm_exp =
-(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
constexpr int dst_min_subnorm_exp =
@@ -176,7 +178,9 @@ TYPED_TEST(ConvertTest, ToFp8)
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'0000'000);
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::denorm_min()), 0b0'0000'000);

// All values smaller than min f8 subnormal must be converted to f8 zero
// All values <= min f8 subnormal/2 must be converted to f8 zero
EXPECT_EQ(c(+0.0009765625f * 0.6f), 0b0'0000'001);
EXPECT_EQ(c(-0.0009765625f * 0.6f), 0b1'0000'001);
constexpr int src_min_subnorm_exp =
-(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
constexpr int dst_min_subnorm_exp =
@@ -282,7 +286,9 @@ TYPED_TEST(ConvertTest, ToBf8)
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'00000'00);
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::denorm_min()), 0b1'00000'00);

// All values smaller than min f8 subnormal must be converted to f8 zero
// All values <= min f8 subnormal/2 must be converted to f8 zero
EXPECT_EQ(c(+1.52587890625e-05f * 0.6f), 0b0'0000'001);
EXPECT_EQ(c(-1.52587890625e-05f * 0.6f), 0b1'0000'001);
constexpr int src_min_subnorm_exp =
-(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
constexpr int dst_min_subnorm_exp =
@@ -373,7 +379,9 @@ TYPED_TEST(ConvertTest, ToBf8)
EXPECT_EQ(c(+ck_tile::numeric<SrcT>::denorm_min()), 0b0'00000'00);
EXPECT_EQ(c(-ck_tile::numeric<SrcT>::denorm_min()), 0b0'00000'00);

// All values smaller than min f8 subnormal must be converted to f8 zero
// All values <= min f8 subnormal/2 must be converted to f8 zero
EXPECT_EQ(c(+7.62939453125e-06f * 0.6f), 0b0'0000'001);
EXPECT_EQ(c(-7.62939453125e-06f * 0.6f), 0b1'0000'001);
constexpr int src_min_subnorm_exp =
-(ck_tile::numeric_traits<SrcT>::bias + ck_tile::numeric_traits<SrcT>::mant - 1);
constexpr int dst_min_subnorm_exp =

test/ck_tile/fmha/CMakeLists.txt (new file)
@@ -0,0 +1,31 @@
# Keep in sync with example/ck_tile/01_fmha/CMakeLists.txt
if(NOT SUPPORTED_GPU_TARGETS MATCHES "gfx9")
return()
endif()

set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")

add_gtest_executable(test_ck_tile_fmha_bwd_bf16 test_fmha_bwd_bf16.cpp)
target_link_libraries(test_ck_tile_fmha_bwd_bf16 PRIVATE ${FMHA_BWD_INSTANCES})

add_gtest_executable(test_ck_tile_fmha_bwd_fp16 test_fmha_bwd_fp16.cpp)
target_link_libraries(test_ck_tile_fmha_bwd_fp16 PRIVATE ${FMHA_BWD_INSTANCES})

add_gtest_executable(test_ck_tile_fmha_fwd_bf16 test_fmha_fwd_bf16.cpp)
target_link_libraries(test_ck_tile_fmha_fwd_bf16 PRIVATE ${FMHA_FWD_INSTANCES})

add_gtest_executable(test_ck_tile_fmha_fwd_fp16 test_fmha_fwd_fp16.cpp)
target_link_libraries(test_ck_tile_fmha_fwd_fp16 PRIVATE ${FMHA_FWD_INSTANCES})

add_gtest_executable(test_ck_tile_fmha_fwd_fp8 test_fmha_fwd_fp8.cpp)
target_link_libraries(test_ck_tile_fmha_fwd_fp8 PRIVATE ${FMHA_FWD_INSTANCES})

add_custom_target(test_ck_tile_fmha
DEPENDS
test_ck_tile_fmha_bwd_bf16
test_ck_tile_fmha_bwd_fp16
test_ck_tile_fmha_fwd_bf16
test_ck_tile_fmha_fwd_fp16
test_ck_tile_fmha_fwd_fp8
)

test/ck_tile/fmha/test_fmha_bwd.inc (new file)
@@ -0,0 +1,344 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

using ::testing::Bool;
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::ValuesIn;

// Random seed used for initializing input tensors. 0 for non-deterministic seed
CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)

// Whether to run long tests (from smoke_test_fwd.sh)
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS)

#define CHECK_RESULT(result) \
do \
{ \
if(result == bwd_result::no_instance) \
GTEST_SKIP() << "No instance for current parameters"; \
ASSERT_EQ(result, bwd_result::success); \
} while(0)

const ck_tile::stream_config stream_config{
nullptr, // stream_id_
false, // time_kernel_
1, // log_level_
0, // cold_niters_
1, // nrepeat_
true, // is_gpu_timer_
false, // flush_cache_
1, // rotating_count_
};

#define COMMON_ARGS \
init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \
stream_config

auto EnableTestIf(bool condition)
{
return ValuesIn(condition ? std::vector<bool>{true} : std::vector<bool>{});
}

class AllLong : public TestWithParam<std::tuple<bool,
std::tuple<int, int>,
bool,
mode_enum,
std::string,
float,
std::tuple<int, int, int, int, int, std::string>>>
{
};

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong);

// Test cases from example/ck_tile/01_fmha/script/smoke_test_bwd.sh

INSTANTIATE_TEST_SUITE_P(
TestCkTileFmhaBwd,
AllLong,
Combine(EnableTestIf(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS))),
HDimValues,
Bool(),
ModeValues,
Values("n", "a"),
Values(0.0f, 0.2f),
Values(std::tuple{1, 4, 2, 259, -1, "0"},
std::tuple{2, 2, -1, 516, 253, "0"},
std::tuple{1, 4, 1, 500, 251, "1"},
std::tuple{1, 2, -1, 900, 258, "2"},
std::tuple{2, 1, -1, 987, 219, "t:128,30"},
std::tuple{2, 3, 1, 244, 499, "b:4,35"})));

TEST_P(AllLong, Test)
{
auto [_, hdims, perm, mode, bias_str, p_drop, dims_mask] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

auto result = fmha_bwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
{seqlen_q},
{seqlen_k},
hdim_q,
hdim_v,
perm, // i_perm
perm, // o_perm
0, // scale
bias_str, // bias_str
false, // use_dbias
p_drop, // p_drop
123, // drop_seed
1024, // drop_offset
true, // drop_prefs
mask_str, // mask_str
false, // deterministic
COMMON_ARGS);
CHECK_RESULT(result);
}

class HDimPadding
: public TestWithParam<std::tuple<std::tuple<int, int>,
bool,
mode_enum,
std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
HDimPadding,
Combine(Values(std::tuple{24, 48},
std::tuple{120, 160},
std::tuple{256, 108},
std::tuple{40, 64}),
Bool(),
ModeValues,
Values(std::tuple{1, 4, 2, 480, -1, "0"},
std::tuple{2, 2, -1, 300, 400, "t:64,64"},
std::tuple{1, 4, 1, 512, 201, "1"},
std::tuple{1, 2, -1, 900, 256, "0"},
std::tuple{2, 1, -1, 256, 256, "1"})));

TEST_P(HDimPadding, Test)
{
auto [hdims, perm, mode, dims_mask] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

auto result = fmha_bwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
{seqlen_q},
{seqlen_k},
hdim_q,
hdim_v,
perm, // i_perm
perm, // o_perm
0, // scale
"n", // bias_str
false, // use_dbias
0.0f, // p_drop
0, // drop_seed
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
false, // deterministic
COMMON_ARGS);
CHECK_RESULT(result);
}

class ElementwiseBias
: public TestWithParam<std::tuple<std::tuple<int, int>,
bool,
mode_enum,
std::string,
bool,
std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
ElementwiseBias,
Combine(HDimValues,
Bool(), // layouts of bias and dbias are controlled by i_perm
ModeValues,
Values("e:0", "e:1", "e:2"),
Bool(),
Values(std::tuple{1, 4, 2, 1024, 100, "0"},
std::tuple{3, 2, -1, 128, 256, "2"},
std::tuple{2, 2, -1, 130, 499, "t:50,64"})));

TEST_P(ElementwiseBias, Test)
{
auto [hdims, i_perm, mode, bias_str, use_dbias, dims_mask] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

auto result = fmha_bwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
{seqlen_q},
{seqlen_k},
hdim_q,
hdim_v,
i_perm, // i_perm
false, // o_perm
0, // scale
bias_str, // bias_str
use_dbias, // use_dbias
0.0f, // p_drop
123, // drop_seed
1024, // drop_offset
true, // drop_prefs
mask_str, // mask_str
false, // deterministic
COMMON_ARGS);
CHECK_RESULT(result);
}

class Alibi : public TestWithParam<std::tuple<std::tuple<int, int>,
mode_enum,
std::string,
std::tuple<int, int, int, int, int>,
std::string>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
Alibi,
Combine(HDimValues,
ModeValues,
Values("a:0", "a:1"),
Values(std::tuple{1, 3, 3, 1024, 1000},
std::tuple{3, 5, 5, 128, 256},
std::tuple{2, 8, 4, 130, 320}),
Values("0", "t", "b", "t:50,64", "b:32,40")));

TEST_P(Alibi, Test)
{
auto [hdims, mode, bias_str, dims, mask_str] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k] = dims;

auto result = fmha_bwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
{seqlen_q},
{seqlen_k},
hdim_q,
hdim_v,
true, // i_perm
true, // o_perm
0, // scale
bias_str, // bias_str
false, // use_dbias
0.0f, // p_drop
0, // drop_seed
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
false, // deterministic
COMMON_ARGS);
CHECK_RESULT(result);
}

class Dropout : public TestWithParam<std::tuple<std::tuple<int, int>,
mode_enum,
float,
std::tuple<uint64_t, uint64_t, bool>,
std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
Dropout,
Combine(HDimValues,
ModeValues,
Values(0.123f, 0.5f),
Values(std::tuple{10, 123, false},
std::tuple{34534564645, 7876878876864, true}),
Values(std::tuple{2, 6, 2, 180, 512, "0"},
std::tuple{3, 2, 2, 256, 128, "1"},
std::tuple{4, 2, 1, 100, 768, "2"})));

TEST_P(Dropout, Test)
{
auto [hdims, mode, p_drop, drop_seed_offset_prefs, dims_mask] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [drop_seed, drop_offset, drop_prefs] = drop_seed_offset_prefs;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

auto result = fmha_bwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
{seqlen_q},
{seqlen_k},
hdim_q,
hdim_v,
true, // i_perm
true, // o_perm
0.1f, // scale
"n", // bias_str
false, // use_dbias
p_drop, // p_drop
drop_seed, // drop_seed
drop_offset, // drop_offset
drop_prefs, // drop_prefs
mask_str, // mask_str
false, // deterministic
COMMON_ARGS);
CHECK_RESULT(result);
}

class Deterministic
: public TestWithParam<std::tuple<std::tuple<int, int>,
bool,
mode_enum,
std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
Deterministic,
Combine(HDimValues,
Bool(),
ModeValues,
Values(std::tuple{2, 6, 2, 180, 512, "0"},
std::tuple{3, 3, 1, 256, 128, "1"},
std::tuple{4, 2, 2, 768, 100, "2"})));

TEST_P(Deterministic, Test)
{
auto [hdims, i_perm, mode, dims_mask] = GetParam();
auto [hdim_q, hdim_v] = hdims;
auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

auto result = fmha_bwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
{seqlen_q},
{seqlen_k},
hdim_q,
hdim_v,
i_perm, // i_perm
true, // o_perm
0, // scale
"n", // bias_str
false, // use_dbias
0.0f, // p_drop
0, // drop_seed
0, // drop_offset
false, // drop_prefs
mask_str, // mask_str
true, // deterministic
COMMON_ARGS);
CHECK_RESULT(result);
}
||||
21  test/ck_tile/fmha/test_fmha_bwd_bf16.cpp  Normal file
@@ -0,0 +1,21 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp"

#include "gtest/gtest.h"

using DataTypeConfig = FmhaBwdBf16;

using ::testing::Values;
using ::testing::ValuesIn;

const auto HDimValues =
    Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});

const auto ModeValues = Values(mode_enum::batch, mode_enum::group);

const std::string init_method = "uf";

#include "test_fmha_bwd.inc"
21  test/ck_tile/fmha/test_fmha_bwd_fp16.cpp  Normal file
@@ -0,0 +1,21 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp"

#include "gtest/gtest.h"

using DataTypeConfig = FmhaBwdFp16;

using ::testing::Values;
using ::testing::ValuesIn;

const auto HDimValues =
    Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});

const auto ModeValues = Values(mode_enum::batch, mode_enum::group);

const std::string init_method = "uf";

#include "test_fmha_bwd.inc"
628  test/ck_tile/fmha/test_fmha_fwd.inc  Normal file
@@ -0,0 +1,628 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

using ::testing::Bool;
using ::testing::Combine;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::ValuesIn;

// Random seed used for initializing input tensors. 0 for non-deterministic seed
CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)

// Whether to run long tests (from smoke_test_fwd.sh)
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_LONG_TESTS)

#define CHECK_RESULT(result)                                      \
    do                                                            \
    {                                                             \
        if(result == fwd_result::no_instance)                     \
            GTEST_SKIP() << "No instance for current parameters"; \
        ASSERT_EQ(result, fwd_result::success);                   \
    } while(0)

const ck_tile::stream_config stream_config{
    nullptr, // stream_id_
    false, // time_kernel_
    1, // log_level_
    0, // cold_niters_
    1, // nrepeat_
    true, // is_gpu_timer_
    false, // flush_cache_
    1, // rotating_count_
};

// range_q, range_k, range_v, range_p, range_o, squant
#define QUANT_ARGS 1, 1, 1, 1, 1, squant

#define COMMON_ARGS \
    init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \
    stream_config

auto EnableTestIf(bool condition)
{
    return ValuesIn(condition ? std::vector<bool>{true} : std::vector<bool>{});
}

class AllLong : public TestWithParam<
                    std::tuple<bool,
                               std::tuple<int, int>,
                               bool,
                               bool,
                               mode_enum,
                               bool,
                               std::string,
                               float,
                               std::tuple<int, int, int, int, int, int, int, int, std::string>>>
{
};

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AllLong);

// Test cases from example/ck_tile/01_fmha/script/smoke_test_fwd.sh

INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaFwd,
    AllLong,
    Combine(EnableTestIf(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_LONG_TESTS))),
            HDimValues,
            Bool(),
            IsVRowmajorValues,
            ModeValues,
            Bool(),
            Values("n", "e", "a"),
            Values(0.0f, 0.2f),
            Values(std::tuple{2, 2, 1, 16, -1, 55, 256, -1, "0"},
                   std::tuple{1, 3, -1, -1, -1, 100, 51, -1, "0"},
                   std::tuple{2, 1, -1, 16, -1, 99, 256, -1, "1"},
                   std::tuple{1, 2, 1, -1, -1, 1024, 256, -1, "2"},
                   std::tuple{2, 1, -1, -1, 24, 3, 99, -1, "2"},
                   std::tuple{3, 2, 1, -1, -1, 200, 520, -1, "t:128,30"},
                   std::tuple{2, 1, -1, -1, -1, 99, 32, -1, "b:4,35"},
                   std::tuple{1, 2, 1, -1, -1, 33, 0, -1, "2"},
                   std::tuple{1, 2, 1, -1, -1, 1, 10, 32, "2"})));

TEST_P(AllLong, Test)
{
    auto [_, hdims, perm, is_v_rowmajor, mode, lse, bias_str, p_drop, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, hdim_q_, hdim_v_, seqlen_q, seqlen_k, seqlen_kpad, mask_str] =
        dims_mask;

    hdim_q = hdim_q_ == -1 ? hdim_q : hdim_q_;
    hdim_v = hdim_v_ == -1 ? hdim_v : hdim_v_;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0, // seqlen_knew
                                               {seqlen_kpad}, // seqlen_kpads
                                               0, // rotary_dim
                                               perm, // i_perm
                                               perm, // o_perm
                                               0, // scale_s
                                               0, // logits_soft_cap
                                               is_v_rowmajor, // is_v_rowmajor
                                               lse, // lse
                                               0, // page_block_size
                                               false, // use_cache_batch_idx
                                               bias_str, // bias_str
                                               p_drop, // p_drop
                                               123, // drop_seed
                                               1024, // drop_offset
                                               false, // drop_prefs
                                               mask_str, // mask_str
                                               QUANT_ARGS,
                                               true, // is_rotary_interleaved
                                               1, // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
class HDimPadding
    : public TestWithParam<std::tuple<std::tuple<int, int>,
                                      bool,
                                      bool,
                                      mode_enum,
                                      std::tuple<int, int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         HDimPadding,
                         Combine(Values(std::tuple{24, 48},
                                        std::tuple{120, 160},
                                        std::tuple{256, 108},
                                        std::tuple{40, 64}),
                                 Bool(),
                                 IsVRowmajorValues,
                                 ModeValues,
                                 Values(std::tuple{1, 4, 2, 480, -1, -1, "0"},
                                        std::tuple{2, 2, -1, 300, 400, 512, "t:64,64"},
                                        std::tuple{1, 4, 1, 512, 201, 256, "1"},
                                        std::tuple{1, 2, -1, 900, 256, -1, "0"},
                                        std::tuple{2, 1, -1, 256, 256, -1, "1"})));

TEST_P(HDimPadding, Test)
{
    auto [hdims, perm, is_v_rowmajor, mode, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, seqlen_kpad, mask_str] = dims_mask;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0, // seqlen_knew
                                               {seqlen_kpad}, // seqlen_kpads
                                               0, // rotary_dim
                                               perm, // i_perm
                                               perm, // o_perm
                                               0, // scale_s
                                               0, // logits_soft_cap
                                               is_v_rowmajor, // is_v_rowmajor
                                               def_lse, // lse
                                               0, // page_block_size
                                               false, // use_cache_batch_idx
                                               "n", // bias_str
                                               0.0f, // p_drop
                                               0, // drop_seed
                                               0, // drop_offset
                                               false, // drop_prefs
                                               mask_str, // mask_str
                                               QUANT_ARGS,
                                               true, // is_rotary_interleaved
                                               1, // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}

class ElementwiseBias
    : public TestWithParam<std::tuple<std::tuple<int, int>,
                                      bool,
                                      mode_enum,
                                      std::string,
                                      std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         ElementwiseBias,
                         Combine(HDimValues,
                                 Bool(), // layout of bias is controlled by i_perm
                                 ModeValues,
                                 Values("e:0", "e:1", "e:2"),
                                 Values(std::tuple{1, 4, 2, 1024, 100, "0"},
                                        std::tuple{3, 2, -1, 128, 256, "2"},
                                        std::tuple{2, 2, -1, 130, 499, "t:50,64"})));

TEST_P(ElementwiseBias, Test)
{
    auto [hdims, i_perm, mode, bias_str, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0, // seqlen_knew
                                               {-1}, // seqlen_kpads
                                               0, // rotary_dim
                                               i_perm, // i_perm
                                               false, // o_perm
                                               0, // scale_s
                                               0, // logits_soft_cap
                                               def_is_v_rowmajor, // is_v_rowmajor
                                               def_lse, // lse
                                               0, // page_block_size
                                               false, // use_cache_batch_idx
                                               bias_str, // bias_str
                                               0.0f, // p_drop
                                               0, // drop_seed
                                               0, // drop_offset
                                               false, // drop_prefs
                                               mask_str, // mask_str
                                               QUANT_ARGS,
                                               true, // is_rotary_interleaved
                                               1, // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}

class Alibi : public TestWithParam<std::tuple<std::tuple<int, int>,
                                              mode_enum,
                                              std::string,
                                              std::tuple<int, int, int, int, int>,
                                              std::string>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         Alibi,
                         Combine(HDimValues,
                                 ModeValues,
                                 Values("a:0", "a:1"),
                                 Values(std::tuple{1, 3, 3, 1024, 1000},
                                        std::tuple{3, 5, 5, 128, 256},
                                        std::tuple{2, 8, 2, 300, 355}),
                                 Values("0", "t", "b", "t:50,64", "b:32,40")));

TEST_P(Alibi, Test)
{
    auto [hdims, mode, bias_str, dims, mask_str] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k] = dims;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0, // seqlen_knew
                                               {-1}, // seqlen_kpads
                                               0, // rotary_dim
                                               true, // i_perm
                                               true, // o_perm
                                               0, // scale_s
                                               0, // logits_soft_cap
                                               def_is_v_rowmajor, // is_v_rowmajor
                                               def_lse, // lse
                                               0, // page_block_size
                                               false, // use_cache_batch_idx
                                               bias_str, // bias_str
                                               0.0f, // p_drop
                                               0, // drop_seed
                                               0, // drop_offset
                                               false, // drop_prefs
                                               mask_str, // mask_str
                                               QUANT_ARGS,
                                               true, // is_rotary_interleaved
                                               1, // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}
class Dropout : public TestWithParam<std::tuple<std::tuple<int, int>,
                                                mode_enum,
                                                float,
                                                std::tuple<uint64_t, uint64_t, bool>,
                                                std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         Dropout,
                         Combine(HDimValues,
                                 ModeValues,
                                 Values(0.123f, 0.5f),
                                 Values(std::tuple{10, 123, false},
                                        std::tuple{34534564645, 7876878876864, true}),
                                 Values(std::tuple{2, 4, 2, 280, 512, "0"},
                                        std::tuple{3, 2, 2, 256, 128, "1"},
                                        std::tuple{4, 3, 1, 100, 768, "2"})));

TEST_P(Dropout, Test)
{
    auto [hdims, mode, p_drop, drop_seed_offset_prefs, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [drop_seed, drop_offset, drop_prefs] = drop_seed_offset_prefs;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0, // seqlen_knew
                                               {-1}, // seqlen_kpads
                                               0, // rotary_dim
                                               false, // i_perm
                                               false, // o_perm
                                               0, // scale_s
                                               0, // logits_soft_cap
                                               def_is_v_rowmajor, // is_v_rowmajor
                                               def_lse, // lse
                                               0, // page_block_size
                                               false, // use_cache_batch_idx
                                               "n", // bias_str
                                               p_drop, // p_drop
                                               drop_seed, // drop_seed
                                               drop_offset, // drop_offset
                                               drop_prefs, // drop_prefs
                                               mask_str, // mask_str
                                               QUANT_ARGS,
                                               true, // is_rotary_interleaved
                                               1, // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}

#if CK_TILE_FMHA_FWD_PAGEDKV_API

class PagedKV : public TestWithParam<std::tuple<std::tuple<int, int>,
                                                bool,
                                                bool,
                                                mode_enum,
                                                int,
                                                std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         PagedKV,
                         Combine(SplitKVHDimValues,
                                 Bool(), // layouts of k and v are controlled by i_perm
                                 IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
                                 ModeValues,
                                 Values(128, 256),
                                 Values(std::tuple{2, 3, 1, 200, 1024, "0"},
                                        std::tuple{3, 2, -1, 128, 768, "2"},
                                        std::tuple{2, 2, -1, 230, 899, "t:50,64"})));

TEST_P(PagedKV, Test)
{
    auto [hdims, i_perm, is_v_rowmajor, mode, page_block_size, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0, // seqlen_knew
                                               {-1}, // seqlen_kpads
                                               0, // rotary_dim
                                               i_perm, // i_perm
                                               false, // o_perm
                                               0, // scale_s
                                               0, // logits_soft_cap
                                               is_v_rowmajor, // is_v_rowmajor
                                               def_lse, // lse
                                               page_block_size, // page_block_size
                                               false, // use_cache_batch_idx
                                               "n", // bias_str
                                               0.0f, // p_drop
                                               0, // drop_seed
                                               0, // drop_offset
                                               false, // drop_prefs
                                               mask_str, // mask_str
                                               QUANT_ARGS,
                                               true, // is_rotary_interleaved
                                               1, // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}

#endif // CK_TILE_FMHA_FWD_PAGEDKV_API
#if CK_TILE_FMHA_FWD_SPLITKV_API

class SplitKV : public TestWithParam<std::tuple<std::tuple<int, int>,
                                                bool,
                                                bool,
                                                std::tuple<mode_enum, bool>,
                                                int,
                                                std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         SplitKV,
                         Combine(SplitKVHDimValues,
                                 Bool(), // layouts of k and v are controlled by i_perm
                                 IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
                                 Values(std::tuple{mode_enum::batch, false},
                                        std::tuple{mode_enum::batch, true},
                                        std::tuple{mode_enum::group, false}),
                                 Values(3, 4),
                                 Values(std::tuple{4, 3, 1, 200, 1024, "0"},
                                        std::tuple{2, 2, -1, 512, 2000, "0"},
                                        std::tuple{3, 2, -1, 230, 899, "t:128,128"})));

TEST_P(SplitKV, Test)
{
    auto [hdims, i_perm, is_v_rowmajor, mode_use_cache_batch_idx, num_splits, dims_mask] =
        GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [mode, use_cache_batch_idx] = mode_use_cache_batch_idx;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    auto result = fmha_fwd_run<DataTypeConfig>(mode,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               0, // seqlen_knew
                                               {-1}, // seqlen_kpads
                                               0, // rotary_dim
                                               i_perm, // i_perm
                                               false, // o_perm
                                               0, // scale_s
                                               0, // logits_soft_cap
                                               is_v_rowmajor, // is_v_rowmajor
                                               def_lse, // lse
                                               0, // page_block_size
                                               use_cache_batch_idx, // use_cache_batch_idx
                                               "n", // bias_str
                                               0.0f, // p_drop
                                               0, // drop_seed
                                               0, // drop_offset
                                               false, // drop_prefs
                                               mask_str, // mask_str
                                               QUANT_ARGS,
                                               true, // is_rotary_interleaved
                                               num_splits, // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}

#endif // CK_TILE_FMHA_FWD_SPLITKV_API
#if CK_TILE_FMHA_FWD_APPENDKV_API

class AppendKV : public TestWithParam<std::tuple<std::tuple<int, int>,
                                                 bool,
                                                 bool,
                                                 std::tuple<int, bool>,
                                                 int,
                                                 std::tuple<int, int, int, int, int, std::string>>>
{
};

INSTANTIATE_TEST_SUITE_P(
    TestCkTileFmhaFwd,
    AppendKV,
    Combine(AppendKVHDimValues,
            Bool(), // layouts of k and v are controlled by i_perm
            IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
            ValuesIn({std::tuple{0, true}, std::tuple{0, false}, std::tuple{128, false}}),
            Values(1, 64, -1),
            Values(std::tuple{3, 3, -1, 60, 129, "t:32,32"},
                   std::tuple{3, 2, 2, 256, 256, "0"},
                   std::tuple{2, 3, 1, 264, 265, "1"},
                   std::tuple{4, 4, 2, 71, 64, "1"})));

TEST_P(AppendKV, Test)
{
    auto [hdims,
          i_perm,
          is_v_rowmajor,
          page_block_size_use_cache_batch_idx,
          seqlen_knew,
          dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [page_block_size, use_cache_batch_idx] = page_block_size_use_cache_batch_idx;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    seqlen_knew = seqlen_knew == -1 ? seqlen_k : seqlen_knew;

    auto result = fmha_fwd_run<DataTypeConfig>(mode_enum::batch,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               seqlen_knew, // seqlen_knew
                                               {-1}, // seqlen_kpads
                                               0, // rotary_dim
                                               i_perm, // i_perm
                                               true, // o_perm
                                               0, // scale_s
                                               0, // logits_soft_cap
                                               is_v_rowmajor, // is_v_rowmajor
                                               def_lse, // lse
                                               page_block_size, // page_block_size
                                               use_cache_batch_idx, // use_cache_batch_idx
                                               "n", // bias_str
                                               0.0f, // p_drop
                                               0, // drop_seed
                                               0, // drop_offset
                                               false, // drop_prefs
                                               mask_str, // mask_str
                                               QUANT_ARGS,
                                               false, // is_rotary_interleaved
                                               1, // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}

class AppendKVRoPE
    : public TestWithParam<std::tuple<bool,
                                      std::tuple<int, int>,
                                      bool,
                                      bool,
                                      std::tuple<int, bool>,
                                      int,
                                      std::tuple<int, int, int, int, int, std::string>>>
{
};

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKVRoPE);

INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
                         AppendKVRoPE,
                         Combine(EnableTestIf(!std::is_same_v<DataTypeConfig, FmhaFwdFp8>),
                                 AppendKVHDimValues,
                                 Bool(), // layouts of k and v are controlled by i_perm
                                 IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor
                                 Values(std::tuple{0, false},
                                        std::tuple{16, true},
                                        std::tuple{32, false},
                                        std::tuple{-1, true}),
                                 Values(16, 50, -1),
                                 Values(std::tuple{2, 3, -1, 60, 129, "t:32,32"},
                                        std::tuple{1, 2, 1, 128, 55, "0"},
                                        std::tuple{3, 4, 2, 72, 128, "1"})));

TEST_P(AppendKVRoPE, Test)
{
    auto [_, hdims, i_perm, is_v_rowmajor, rotary, seqlen_knew, dims_mask] = GetParam();
    auto [hdim_q, hdim_v] = hdims;
    auto [rotary_dim, is_rotary_interleaved] = rotary;
    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;

    rotary_dim = rotary_dim == -1 ? hdim_q : rotary_dim;
    seqlen_knew = seqlen_knew == -1 ? seqlen_k : seqlen_knew;

    auto result = fmha_fwd_run<DataTypeConfig>(mode_enum::batch,
                                               batch,
                                               nhead,
                                               nhead_k,
                                               {adjust_seqlen(seqlen_q)},
                                               {adjust_seqlen(seqlen_k)},
                                               hdim_q,
                                               hdim_v,
                                               seqlen_knew, // seqlen_knew
                                               {-1}, // seqlen_kpads
                                               rotary_dim, // rotary_dim
                                               i_perm, // i_perm
                                               true, // o_perm
                                               0, // scale_s
                                               0, // logits_soft_cap
                                               is_v_rowmajor, // is_v_rowmajor
                                               true, // lse
                                               0, // page_block_size
                                               false, // use_cache_batch_idx
                                               "n", // bias_str
                                               0.0f, // p_drop
                                               0, // drop_seed
                                               0, // drop_offset
                                               false, // drop_prefs
                                               mask_str, // mask_str
                                               QUANT_ARGS,
                                               is_rotary_interleaved, // is_rotary_interleaved
                                               1, // num_splits
                                               COMMON_ARGS);
    CHECK_RESULT(result);
}

#endif // CK_TILE_FMHA_FWD_APPENDKV_API
44  test/ck_tile/fmha/test_fmha_fwd_bf16.cpp  Normal file
@@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"

#include "gtest/gtest.h"

#include <tuple>
#include <string>

using ::testing::Values;

using DataTypeConfig = FmhaFwdBf16;

const auto HDimValues = Values(std::tuple{32, -1},
                               std::tuple{64, -1},
                               std::tuple{96, 128},
                               std::tuple{128, -1},
                               std::tuple{192, 128},
                               std::tuple{192, -1},
                               std::tuple{256, -1});

const auto SplitKVHDimValues = Values(std::tuple{32, -1},
                                      std::tuple{64, -1},
                                      std::tuple{96, -1},
                                      std::tuple{128, -1},
                                      std::tuple{256, -1});

const auto AppendKVHDimValues =
    Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});

const auto ModeValues = Values(mode_enum::batch, mode_enum::group);

const auto IsVRowmajorValues = Values(false, true);

const bool squant = false;
const std::string init_method = "uf";
const bool def_lse = true;
const bool def_is_v_rowmajor = true;

int adjust_seqlen(int seqlen) { return seqlen; }

#include "test_fmha_fwd.inc"
44  test/ck_tile/fmha/test_fmha_fwd_fp16.cpp  Normal file
@@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"

#include "gtest/gtest.h"

#include <tuple>
#include <string>

using ::testing::Values;

using DataTypeConfig = FmhaFwdFp16;

const auto HDimValues = Values(std::tuple{32, -1},
                               std::tuple{64, -1},
                               std::tuple{96, 128},
                               std::tuple{128, -1},
                               std::tuple{192, 128},
                               std::tuple{192, -1},
                               std::tuple{256, -1});

const auto SplitKVHDimValues = Values(std::tuple{32, -1},
                                      std::tuple{64, -1},
                                      std::tuple{96, -1},
                                      std::tuple{128, -1},
                                      std::tuple{256, -1});

const auto AppendKVHDimValues =
    Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});

const auto ModeValues = Values(mode_enum::batch, mode_enum::group);

const auto IsVRowmajorValues = Values(false, true);

const bool squant = false;
const std::string init_method = "uf";
const bool def_lse = true;
const bool def_is_v_rowmajor = true;

int adjust_seqlen(int seqlen) { return seqlen; }

#include "test_fmha_fwd.inc"
43  test/ck_tile/fmha/test_fmha_fwd_fp8.cpp  Normal file
@@ -0,0 +1,43 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"

#include "gtest/gtest.h"

#include <tuple>
#include <string>

using ::testing::Values;

using DataTypeConfig = FmhaFwdFp8;

// Currently there are no fp8 instances for splitkv/pagedkv by default. The corresponding
// tests are deliberately kept enabled: today they are skipped (no_instance), and they will
// start running automatically (and pass) if such instances are added in the future.

const auto HDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});

const auto SplitKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});

const auto AppendKVHDimValues =
    Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});

// There are no fp8 instances with seqlen padding (mode_enum::group requires it)
const auto ModeValues = Values(mode_enum::batch);

const auto IsVRowmajorValues = Values(false);

const bool squant = true;
const std::string init_method = "ufq";
const bool def_lse = false;
const bool def_is_v_rowmajor = false;

int adjust_seqlen(int seqlen)
{
    // There are no fp8 instances with padding; pad seqlen to avoid skipping most of the tests
    return ck_tile::integer_least_multiple(seqlen, 128);
}

#include "test_fmha_fwd.inc"