Add new instances

This commit is contained in:
Clement Lin
2025-04-25 01:35:45 +08:00
parent e0bfe71854
commit 89e33ed5ad
3 changed files with 59 additions and 8 deletions

View File

@@ -5,6 +5,8 @@ if(FLASH_ATTENTION_FWD_ENABLE_APIS STREQUAL "all")
set(FLASH_ATTENTION_FWD_ENABLE_APIS ${FLASH_ATTENTION_FWD_KNOWN_APIS})
endif()
# Boolean cache option, ON by default. Override at configure time with
# -DTOY_FA_FWD_OPT=OFF to build without the toy flash-attention forward optimization.
option(TOY_FA_FWD_OPT "Enable toy flash attention forward optimization" ON)
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FLASH_ATTENTION_FWD_ENABLE_APIS}
@@ -34,6 +36,11 @@ add_executable(${EXAMPLE_REDUCE}
flash_attention_fwd.cpp
)
# Define TOY_FA_FWD_OPT=1 for this target's own sources (PRIVATE) only when the
# TOY_FA_FWD_OPT variable is truthy; otherwise the expression expands to nothing.
target_compile_definitions(
  ${EXAMPLE_REDUCE} PRIVATE "$<$<BOOL:${TOY_FA_FWD_OPT}>:TOY_FA_FWD_OPT=1>"
)
target_include_directories(${EXAMPLE_REDUCE}
PRIVATE
${CMAKE_CURRENT_LIST_DIR}
@@ -43,7 +50,6 @@ target_sources(${EXAMPLE_REDUCE} PRIVATE ${FLASH_ATTENTION_FWD_GEN_BLOBS})
# Log the list of generated source blobs added to the target. STATUS keeps this
# on the normal configure log; a mode-less message() is emitted as a NOTICE on
# stderr, which many IDEs and CI logs render like a warning.
message(STATUS "FLASH_ATTENTION_FWD_GEN_BLOBS = ${FLASH_ATTENTION_FWD_GEN_BLOBS}")
set(EXAMPLE_REDUCE_COMPILE_OPTIONS)
list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS
-Wno-undefined-func-template
@@ -51,6 +57,12 @@ list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS
--offload-compress
)
# NOTE: the TOY_FA_FWD_OPT compile definition is controlled by the TOY_FA_FWD_OPT
# option and attached to this target by the target_compile_definitions() call
# earlier in this file; no separate ENABLE_TOY_FA_FWD_OPT switch is needed here.
target_compile_options(${EXAMPLE_REDUCE}
PRIVATE
${EXAMPLE_REDUCE_COMPILE_OPTIONS}

View File

@@ -209,7 +209,6 @@ struct FlashAttentionFwdImpl
auto v_lds_window = make_tile_window(
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
#endif
// Reduction functions for softmax.
// f_max: binary callable that returns max(e0, e1) of its two operands.
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
// f_sum: binary callable that returns the sum e0 + e1 of its two operands.
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };

View File

@@ -12,7 +12,7 @@ import copy
from dataclasses import dataclass
def get_if_str(size_, total, last_else=True):
if size_ == "head_dim_256_seq_4096":
if size_ == "head_dim_128_seq_16384":
return 'if'
else:
return 'else if'
@@ -368,7 +368,15 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
len_in_size = sum(len(b.instance_list) for b in blob_per_size)
size_cond = ""
if size_ == "head_dim_256_seq_4096":
if size_ == "head_dim_128_seq_16384":
size_cond = "(a.M0 <= 16384 && a.N0 <= 16384 && a.M0 > 8192 && a.N0 > 8192 && a.K0 == 128 && a.N1 == 128)"
elif size_ == "head_dim_64_seq_16384":
size_cond = "(a.M0 <= 16384 && a.N0 <= 16384 && a.M0 > 8192 && a.N0 > 8192 && a.K0 == 64 && a.N1 == 64)"
elif size_ == "head_dim_128_seq_8192":
size_cond = "(a.M0 <= 8192 && a.N0 <= 8192 && a.M0 > 4096 && a.N0 > 4096 && a.K0 == 128 && a.N1 == 128)"
elif size_ == "head_dim_64_seq_8192":
size_cond = "(a.M0 <= 8192 && a.N0 <= 8192 && a.M0 > 4096 && a.N0 > 4096 && a.K0 == 64 && a.N1 == 64)"
elif size_ == "head_dim_256_seq_4096":
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 256 && a.N1 == 256)"
elif size_ == "head_dim_128_seq_4096":
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)"
@@ -377,9 +385,17 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
elif size_ == "head_dim_32_seq_4096":
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 32 && a.N1 == 32)"
elif size_ == "head_dim_128_seq_2048":
size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 512 && a.N0 > 512 && a.K0 == 128 && a.N1 == 128)"
size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 1024 && a.N0 > 1024 && a.K0 == 128 && a.N1 == 128)"
elif size_ == "head_dim_64_seq_2048":
size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 1024 && a.N0 > 1024 && a.K0 == 64 && a.N1 == 64)"
elif size_ == "head_dim_128_seq_1024":
size_cond = "(a.M0 <= 1024 && a.N0 <= 1024 && a.M0 > 512 && a.N0 > 512 && a.K0 == 128 && a.N1 == 128)"
elif size_ == "head_dim_64_seq_1024":
size_cond = "(a.M0 <= 1024 && a.N0 <= 1024 && a.M0 > 512 && a.N0 > 512 && a.K0 == 64 && a.N1 == 64)"
elif size_ == "head_dim_128_seq_512":
size_cond = "(a.M0 <= 512 && a.N0 <= 512 && a.K0 == 128 && a.N1 == 128)"
elif size_ == "head_dim_64_seq_512":
size_cond = "(a.M0 <= 512 && a.N0 <= 512 && a.K0 == 64 && a.N1 == 64)"
else:
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)"
@@ -404,11 +420,23 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
# Define kernel configurations for different size categories
trait_dict = {
"head_dim_128_seq_16384": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
],
"head_dim_64_seq_16384": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
],
"head_dim_128_seq_8192": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
],
"head_dim_64_seq_8192": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
],
"head_dim_256_seq_4096": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 256, 128, 128, 64, 128, 64),
],
"head_dim_128_seq_4096": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 64, 128, 64),
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
],
"head_dim_64_seq_4096": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 64, 64, 64),
@@ -417,10 +445,22 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
h_traits('fp32', 'fp32', 'fp32', 'fp32', 128, 32, 32, 32, 32, 32, 32),
],
"head_dim_128_seq_2048": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 64, 128, 64),
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
],
"head_dim_64_seq_2048": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
],
"head_dim_128_seq_1024": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
],
"head_dim_64_seq_1024": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
],
"head_dim_128_seq_512": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 128, 128, 128),
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
],
"head_dim_64_seq_512": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 64, 64, 64),
],
}