From 89e33ed5ade9afde4dabf95b507ce500eeeaca69 Mon Sep 17 00:00:00 2001 From: Clement Lin Date: Fri, 25 Apr 2025 01:35:45 +0800 Subject: [PATCH] Add new instances --- .../CMakeLists.txt | 14 ++++- .../flash_attention_fwd_impl.hpp | 1 - .../generate.py | 52 ++++++++++++++++--- 3 files changed, 59 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt index e9a697fde3..7b8f2a33b0 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt @@ -5,6 +5,8 @@ if(FLASH_ATTENTION_FWD_ENABLE_APIS STREQUAL "all") set(FLASH_ATTENTION_FWD_ENABLE_APIS ${FLASH_ATTENTION_FWD_KNOWN_APIS}) endif() +option(TOY_FA_FWD_OPT "Enable toy flash attention forward optimization" ON) + execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py --api ${FLASH_ATTENTION_FWD_ENABLE_APIS} @@ -34,6 +36,11 @@ add_executable(${EXAMPLE_REDUCE} flash_attention_fwd.cpp ) +target_compile_definitions(${EXAMPLE_REDUCE} + PRIVATE + $<$:TOY_FA_FWD_OPT=1> +) + target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR} @@ -43,7 +50,6 @@ target_sources(${EXAMPLE_REDUCE} PRIVATE ${FLASH_ATTENTION_FWD_GEN_BLOBS}) message("FLASH_ATTENTION_FWD_GEN_BLOBS = ${FLASH_ATTENTION_FWD_GEN_BLOBS}") - set(EXAMPLE_REDUCE_COMPILE_OPTIONS) list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template @@ -51,6 +57,12 @@ list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS --offload-compress ) +# option(ENABLE_TOY_FA_FWD_OPT "Enable toy FA fwd optimization" ON) +# if(ENABLE_TOY_FA_FWD_OPT) +# message("Compiling with toy FA fwd optimization") +# target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE TOY_FA_FWD_OPT) +# endif() + target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS} diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp index fbca3a95ac..bd8c209383 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -209,7 +209,6 @@ struct FlashAttentionFwdImpl auto v_lds_window = make_tile_window( v_lds, make_tuple(number{}, number{}), {0, 0}); #endif - // reduction function for softmax const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py index 00bc91cadc..ce2f9d32c8 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py @@ -12,7 +12,7 @@ import copy from dataclasses import dataclass def get_if_str(size_, total, last_else=True): - if size_ == "head_dim_256_seq_4096": + if size_ == "head_dim_128_seq_16384": return 'if' else: return 'else if' @@ -368,7 +368,15 @@ float flash_attention_fwd_(const FlashAttnArgs