mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK TILE] Fix FA build filter (#2369)
* Fix for fwd/bwd kernel build filter * fix bwd code * cmake depends & bwd filter order fix * revert unexpected reformat * Avoid change fmha bwd filter order for downstream compatibility * Revert unexpected changes --------- Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com> Co-authored-by: Ding, Yi <yi.ding@amd.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# validate user-specified fmha_fwd API list
|
||||
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill")
|
||||
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
|
||||
set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
|
||||
endif()
|
||||
@@ -17,11 +17,30 @@ if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
|
||||
endif()
|
||||
|
||||
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
|
||||
)
|
||||
# re-run execute_process `generate.py --list_blobs` if any of the codegen scripts change
|
||||
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")
|
||||
|
||||
string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
|
||||
set(FMHA_FWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS}
|
||||
# --filter fmha_fwd...
|
||||
)
|
||||
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api bwd
|
||||
--receipt 3
|
||||
# --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd...
|
||||
)
|
||||
|
||||
# generate a list of kernels, but not actually emit files at config sta
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
|
||||
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
@@ -29,8 +48,8 @@ if(ret AND NOT ret EQUAL 0)
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
|
||||
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
@@ -44,14 +63,16 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
|
||||
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${CODE_GEN_SCRIPTS}
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_BWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3
|
||||
COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
|
||||
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${CODE_GEN_SCRIPTS}
|
||||
)
|
||||
|
||||
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
|
||||
@@ -73,7 +94,7 @@ target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
|
||||
# NOTE: this is dangerous since will change the whole kernel to flush denormals
|
||||
# WIP with compiler team for an exp2 intrinsic..., then remove this
|
||||
if(NOT DEFINED FMHA_FWD_FAST_EXP2)
|
||||
set(FMHA_FWD_FAST_EXP2 true)
|
||||
set(FMHA_FWD_FAST_EXP2 true)
|
||||
endif()
|
||||
|
||||
set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS)
|
||||
@@ -82,9 +103,9 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
# ... because they are auto-generated
|
||||
if(FMHA_FWD_FAST_EXP2)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
|
||||
|
||||
|
||||
@@ -169,7 +169,7 @@ template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_d
|
||||
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -1874,9 +1874,12 @@ struct FmhaBwdConvertQGradKernel
|
||||
if (kPadHeadDimQ) n += "d";
|
||||
return n.empty() ? n : std::string("p") + n; }();
|
||||
return
|
||||
_SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_" + _SS_(t2s<QGradDataType>::name) +
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_" + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) +
|
||||
(kIsDeterministic ? "_deterministic" : "_ndeterministic") ;
|
||||
_SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_"
|
||||
+ _SS_(t2s<QGradDataType>::name) + "_"
|
||||
+ "b" + _TS_(kM0) + "x" + _TS_(kN0) + "_"
|
||||
+ (kIsGroupMode ? "group" : "batch") + "_"
|
||||
+ ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn)
|
||||
+ (kIsDeterministic ? "_deterministic" : "_ndeterministic") ;
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -100,7 +100,7 @@ struct FmhaFwdKernel
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
|
||||
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" );
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
|
||||
Reference in New Issue
Block a user