mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Add new instances
This commit is contained in:
@@ -5,6 +5,8 @@ if(FLASH_ATTENTION_FWD_ENABLE_APIS STREQUAL "all")
|
||||
set(FLASH_ATTENTION_FWD_ENABLE_APIS ${FLASH_ATTENTION_FWD_KNOWN_APIS})
|
||||
endif()
|
||||
|
||||
option(TOY_FA_FWD_OPT "Enable toy flash attention forward optimization" ON)
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FLASH_ATTENTION_FWD_ENABLE_APIS}
|
||||
@@ -34,6 +36,11 @@ add_executable(${EXAMPLE_REDUCE}
|
||||
flash_attention_fwd.cpp
|
||||
)
|
||||
|
||||
target_compile_definitions(${EXAMPLE_REDUCE}
|
||||
PRIVATE
|
||||
$<$<BOOL:${TOY_FA_FWD_OPT}>:TOY_FA_FWD_OPT=1>
|
||||
)
|
||||
|
||||
target_include_directories(${EXAMPLE_REDUCE}
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
@@ -43,7 +50,6 @@ target_sources(${EXAMPLE_REDUCE} PRIVATE ${FLASH_ATTENTION_FWD_GEN_BLOBS})
|
||||
|
||||
message("FLASH_ATTENTION_FWD_GEN_BLOBS = ${FLASH_ATTENTION_FWD_GEN_BLOBS}")
|
||||
|
||||
|
||||
set(EXAMPLE_REDUCE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS
|
||||
-Wno-undefined-func-template
|
||||
@@ -51,6 +57,12 @@ list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS
|
||||
--offload-compress
|
||||
)
|
||||
|
||||
# option(ENABLE_TOY_FA_FWD_OPT "Enable toy FA fwd optimization" ON)
|
||||
# if(ENABLE_TOY_FA_FWD_OPT)
|
||||
# message("Compiling with toy FA fwd optimization")
|
||||
# target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE TOY_FA_FWD_OPT)
|
||||
# endif()
|
||||
|
||||
target_compile_options(${EXAMPLE_REDUCE}
|
||||
PRIVATE
|
||||
${EXAMPLE_REDUCE_COMPILE_OPTIONS}
|
||||
|
||||
@@ -209,7 +209,6 @@ struct FlashAttentionFwdImpl
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
|
||||
#endif
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
@@ -12,7 +12,7 @@ import copy
|
||||
from dataclasses import dataclass
|
||||
|
||||
def get_if_str(size_, total, last_else=True):
|
||||
if size_ == "head_dim_256_seq_4096":
|
||||
if size_ == "head_dim_128_seq_16384":
|
||||
return 'if'
|
||||
else:
|
||||
return 'else if'
|
||||
@@ -368,7 +368,15 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
len_in_size = sum(len(b.instance_list) for b in blob_per_size)
|
||||
|
||||
size_cond = ""
|
||||
if size_ == "head_dim_256_seq_4096":
|
||||
if size_ == "head_dim_128_seq_16384":
|
||||
size_cond = "(a.M0 <= 16384 && a.N0 <= 16384 && a.M0 > 8192 && a.N0 > 8192 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_64_seq_16384":
|
||||
size_cond = "(a.M0 <= 16384 && a.N0 <= 16384 && a.M0 > 8192 && a.N0 > 8192 && a.K0 == 64 && a.N1 == 64)"
|
||||
elif size_ == "head_dim_128_seq_8192":
|
||||
size_cond = "(a.M0 <= 8192 && a.N0 <= 8192 && a.M0 > 4096 && a.N0 > 4096 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_64_seq_8192":
|
||||
size_cond = "(a.M0 <= 8192 && a.N0 <= 8192 && a.M0 > 4096 && a.N0 > 4096 && a.K0 == 64 && a.N1 == 64)"
|
||||
elif size_ == "head_dim_256_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 256 && a.N1 == 256)"
|
||||
elif size_ == "head_dim_128_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)"
|
||||
@@ -377,9 +385,17 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
elif size_ == "head_dim_32_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 32 && a.N1 == 32)"
|
||||
elif size_ == "head_dim_128_seq_2048":
|
||||
size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 512 && a.N0 > 512 && a.K0 == 128 && a.N1 == 128)"
|
||||
size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 1024 && a.N0 > 1024 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_64_seq_2048":
|
||||
size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 1024 && a.N0 > 1024 && a.K0 == 64 && a.N1 == 64)"
|
||||
elif size_ == "head_dim_128_seq_1024":
|
||||
size_cond = "(a.M0 <= 1024 && a.N0 <= 1024 && a.M0 > 512 && a.N0 > 512 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_64_seq_1024":
|
||||
size_cond = "(a.M0 <= 1024 && a.N0 <= 1024 && a.M0 > 512 && a.N0 > 512 && a.K0 == 64 && a.N1 == 64)"
|
||||
elif size_ == "head_dim_128_seq_512":
|
||||
size_cond = "(a.M0 <= 512 && a.N0 <= 512 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_64_seq_512":
|
||||
size_cond = "(a.M0 <= 512 && a.N0 <= 512 && a.K0 == 64 && a.N1 == 64)"
|
||||
else:
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)"
|
||||
|
||||
@@ -404,11 +420,23 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
|
||||
# Define kernel configurations for different size categories
|
||||
trait_dict = {
|
||||
"head_dim_128_seq_16384": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
|
||||
],
|
||||
"head_dim_64_seq_16384": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
|
||||
],
|
||||
"head_dim_128_seq_8192": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
|
||||
],
|
||||
"head_dim_64_seq_8192": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
|
||||
],
|
||||
"head_dim_256_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 256, 128, 128, 64, 128, 64),
|
||||
],
|
||||
"head_dim_128_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 64, 128, 64),
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
|
||||
],
|
||||
"head_dim_64_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 64, 64, 64),
|
||||
@@ -417,10 +445,22 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 128, 32, 32, 32, 32, 32, 32),
|
||||
],
|
||||
"head_dim_128_seq_2048": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 64, 128, 64),
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
|
||||
],
|
||||
"head_dim_64_seq_2048": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
|
||||
],
|
||||
"head_dim_128_seq_1024": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
|
||||
],
|
||||
"head_dim_64_seq_1024": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
|
||||
],
|
||||
"head_dim_128_seq_512": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 128, 128, 128),
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
|
||||
],
|
||||
"head_dim_64_seq_512": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 64, 64, 64),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user