diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index c64a19104e..741ef4062d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -164,6 +164,8 @@ FMHA_FWD_API_HEADER = """ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n // auto generated by generate.py #include +#include +#include #include @@ -220,17 +222,14 @@ float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fw }} }} // namespace """ -FMHA_FWD_API_FOOTER_TEMPLATE = """ -float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{ -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wunreachable-code" - if ({F_is_v3_enabled}) {{ +FMHA_FWD_API_FOOTER = """ +float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) { + if (args.max_seqlen_q > 1) { float r = fmha_fwd_v3(traits, args, config); if (r >= 0) return r; - }} -#pragma clang diagnostic pop + } return fmha_fwd_v2(traits, args, config); -}} +} """ FMHA_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{ @@ -1566,13 +1565,7 @@ def write_fwd_api( FMHA_FWD_API_HEADER, api_pool.render("fmha_fwd_v2", filter_fn=accept_only_v2), api_pool.render("fmha_fwd_v3", filter_fn=accept_only_v3), - FMHA_FWD_API_FOOTER_TEMPLATE.format( - F_is_v3_enabled=BOOL_MAP[ - # NOTE: enable v3 pipelines when ready - 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) - # False - ] - ), + FMHA_FWD_API_FOOTER, ] ) update_file(autogen_dir / FMHA_FWD_API_FILENAME, content)