mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
enable hd64 bf16 causal
This commit is contained in:
@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
|
||||
# to be included in "make all/install/check"
|
||||
message("adding example ${EXAMPLE_FMHA_BWD}")
|
||||
|
||||
# NOTE(review): two add_executable() calls are shown because this is a diff view; the first is
# the pre-change line and the second replaces it, adding the remaining hsaco/bwd_hd64_bf16_*
# sources (a16 rtne/rtz and the causal a16 rtna/rtne/rtz variants).
# EXCLUDE_FROM_ALL keeps this target out of the default "all" build.
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_hd64_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtne.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32_rtna.cpp hsaco/bwd_bf16_a32_rtne.cpp hsaco/bwd_bf16_a32_rtz.cpp hsaco/bwd_bf16_causal_a16_rtna.cpp hsaco/bwd_bf16_causal_a16_rtne.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32_rtna.cpp hsaco/bwd_bf16_causal_a32_rtne.cpp hsaco/bwd_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp)
|
||||
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_hd64_bf16_a16_rtna.cpp hsaco/bwd_hd64_bf16_a16_rtne.cpp hsaco/bwd_hd64_bf16_a16_rtz.cpp hsaco/bwd_hd64_bf16_causal_a16_rtna.cpp hsaco/bwd_hd64_bf16_causal_a16_rtne.cpp hsaco/bwd_hd64_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtne.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32_rtna.cpp hsaco/bwd_bf16_a32_rtne.cpp hsaco/bwd_bf16_a32_rtz.cpp hsaco/bwd_bf16_causal_a16_rtna.cpp hsaco/bwd_bf16_causal_a16_rtne.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32_rtna.cpp hsaco/bwd_bf16_causal_a32_rtne.cpp hsaco/bwd_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp)
|
||||
# Include paths for this target are rooted at the directory of this CMake list file.
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
# The sources listed in FMHA_BWD_GEN_BLOBS are added to the same target.
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
|
||||
|
||||
|
||||
@@ -741,13 +741,52 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
// Dispatch for the bf16 v3 backward kernels named "hd64" (dot_do_o traits use 64 as first argument).
// A kernel blob symbol is chosen by mask mode and by t.how_v3_bf16_cvt, where the kernel names
// below map 0 -> "rtne", 1 -> "rtna", 2 -> "rtz". The chosen symbol and its name are passed to
// fmha_bwd_v3_xqa_ and the result r is returned. When no branch matches, control falls through
// this block without returning.
if(t.data_type.compare("bf16") == 0){{
|
||||
// Unmasked case.
if(t.mask_type == mask_enum::no_mask){{
|
||||
// Taken only when t.is_v3_atomic_fp32 is false and a.nhead_q is a multiple of a.nhead_k.
if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
|
||||
// NOTE(review): the next condition is the pre-change line kept by the diff view; the "== 0"
// check that follows it is its replacement, so only one of the two exists in the real file.
if(t.how_v3_bf16_cvt == 1){{
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne";
|
||||
// io_perm is true when a.nhead_stride_q exceeds a.stride_q; presumably a layout flag for the
// q tensor -- confirm against fmha_bwd_v3_xqa_.
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
// NOTE(review): the meaning of the trailing 32 and 192 arguments is not visible here; every
// call in this block passes the same pair -- confirm against fmha_bwd_v3_xqa_.
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_a16_rtne, bwd_v3_name, io_perm, 32, 192);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_a16_rtna, bwd_v3_name, io_perm, 32, 192);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_a16_rtz, bwd_v3_name, io_perm, 32, 192);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
// Masked case, restricted to a.window_size_left == -1 and a.window_size_right == 0; the kernels
// selected here carry "causal" in their names.
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
|
||||
// Same precondition as the unmasked branch.
if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_causal_a16_rtne, bwd_v3_name, io_perm, 32, 192);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_causal_a16_rtna, bwd_v3_name, io_perm, 32, 192);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_hd64_bf16_causal_a16_rtz, bwd_v3_name, io_perm, 32, 192);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
|
||||
2370
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a16_rtne.cpp
Normal file
2370
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a16_rtne.cpp
Normal file
File diff suppressed because it is too large
Load Diff
1877
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a16_rtz.cpp
Normal file
1877
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a16_rtz.cpp
Normal file
File diff suppressed because it is too large
Load Diff
2599
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_causal_a16_rtna.cpp
Normal file
2599
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_causal_a16_rtna.cpp
Normal file
File diff suppressed because it is too large
Load Diff
2744
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_causal_a16_rtne.cpp
Normal file
2744
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_causal_a16_rtne.cpp
Normal file
File diff suppressed because it is too large
Load Diff
2251
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_causal_a16_rtz.cpp
Normal file
2251
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_causal_a16_rtz.cpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -24,3 +24,8 @@ extern unsigned char bwd_fp16_causal_a32[];
|
||||
extern unsigned char bwd_fp16_spec_a32[];
|
||||
extern unsigned char bwd_fp16_spec_causal_a32[];
|
||||
// Byte-array symbols for the bf16 "hd64" backward kernels, one per bf16 conversion suffix
// (rtna / rtne / rtz), without and with "causal". The names match the hsaco/bwd_hd64_bf16_*.cpp
// sources, which presumably define them as embedded kernel images -- confirm.
extern unsigned char bwd_hd64_bf16_a16_rtna[];
|
||||
extern unsigned char bwd_hd64_bf16_a16_rtne[];
|
||||
extern unsigned char bwd_hd64_bf16_a16_rtz[];
|
||||
extern unsigned char bwd_hd64_bf16_causal_a16_rtna[];
|
||||
extern unsigned char bwd_hd64_bf16_causal_a16_rtne[];
|
||||
extern unsigned char bwd_hd64_bf16_causal_a16_rtz[];
|
||||
|
||||
Reference in New Issue
Block a user