mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
add bf16 rtne kernels
This commit is contained in:
@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
|
||||
# to be included in "make all/install/check"
|
||||
message("adding example ${EXAMPLE_FMHA_BWD}")
|
||||
|
||||
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_a32_rtz.cpp hsaco/bwd_bf16_causal_a16.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp)
|
||||
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16_rtna.cpp hsaco/bwd_bf16_a16_rtne.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32_rtna.cpp hsaco/bwd_bf16_a32_rtne.cpp hsaco/bwd_bf16_a32_rtz.cpp hsaco/bwd_bf16_causal_a16_rtna.cpp hsaco/bwd_bf16_causal_a16_rtne.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32_rtna.cpp hsaco/bwd_bf16_causal_a32_rtne.cpp hsaco/bwd_bf16_causal_a32_rtz.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp)
|
||||
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
|
||||
|
||||
|
||||
@@ -624,7 +624,23 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
return r;
|
||||
}}
|
||||
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
|
||||
if(t.is_v3_rtz_cvt == true){{
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtne, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtna, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
|
||||
@@ -632,31 +648,30 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32_rtz, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_a32";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
|
||||
if(t.is_v3_rtz_cvt == true){{
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtne, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtna, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtz, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_a16";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
|
||||
@@ -670,7 +685,23 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
return r;
|
||||
}}
|
||||
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
|
||||
if(t.is_v3_rtz_cvt == true){{
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtne, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtna, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
|
||||
@@ -678,31 +709,30 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32_rtz, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
|
||||
if(t.is_v3_rtz_cvt == true){{
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtne, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtna, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtz, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_v3_name, io_perm, 16, 192);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
@@ -982,14 +1012,14 @@ class FmhaBwdDQDKDVKernel:
|
||||
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
# '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# "kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
# '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# "kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"]
|
||||
# '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# "kr_ktr_vr_iglp", "kr_ktr_vr"]
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -1032,7 +1062,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
continue
|
||||
if receipt == 3:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= bias in ['no']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
|
||||
@@ -100,9 +100,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("v3_spec",
|
||||
"0",
|
||||
"if set to 1 will call the specialized v3 kernel when bwd_v3 is set to 1")
|
||||
.insert("v3_rtz_cvt",
|
||||
"0",
|
||||
"if set to 1 will use float to bf16 RTZ convert when bwd_v3 is set to 1");
|
||||
.insert("v3_bf16_cvt",
|
||||
"1",
|
||||
"float to bf16 convert type when bwd_v3 is set to 1, 0:RTNE; 1:RTNA; 2:RTZ");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -194,7 +194,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
bool bwd_v3 = arg_parser.get_bool("bwd_v3");
|
||||
bool v3_atomic_fp32 = arg_parser.get_bool("v3_atomic_fp32");
|
||||
bool v3_spec = arg_parser.get_bool("v3_spec");
|
||||
bool v3_rtz_cvt = arg_parser.get_bool("v3_rtz_cvt");
|
||||
int v3_bf16_cvt = arg_parser.get_int("v3_bf16_cvt");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
@@ -433,7 +433,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
bwd_v3,
|
||||
v3_atomic_fp32,
|
||||
v3_spec,
|
||||
v3_rtz_cvt};
|
||||
v3_bf16_cvt};
|
||||
auto fmha_args = [&]() {
|
||||
assert(nhead % nhead_k == 0);
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
|
||||
@@ -441,7 +441,7 @@ struct fmha_bwd_traits
|
||||
bool uses_bwd_v3;
|
||||
bool is_v3_atomic_fp32;
|
||||
bool is_v3_spec;
|
||||
bool is_v3_rtz_cvt;
|
||||
int how_v3_bf16_cvt;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "fmha_hsaco.hpp"
|
||||
|
||||
unsigned char bwd_bf16_a16[] = {
|
||||
unsigned char bwd_bf16_a16_rtna[] = {
|
||||
0x7F, 0x45, 0x4C, 0x46, 0x02, 0x01, 0x01, 0x40, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x03, 0x00, 0xE0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x90, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
2451
example/ck_tile/01_fmha/hsaco/bwd_bf16_a16_rtne.cpp
Normal file
2451
example/ck_tile/01_fmha/hsaco/bwd_bf16_a16_rtne.cpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "fmha_hsaco.hpp"
|
||||
|
||||
unsigned char bwd_bf16_a32[] = {
|
||||
unsigned char bwd_bf16_a32_rtna[] = {
|
||||
0x7F, 0x45, 0x4C, 0x46, 0x02, 0x01, 0x01, 0x40, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x03, 0x00, 0xE0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xE0, 0x8C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
2375
example/ck_tile/01_fmha/hsaco/bwd_bf16_a32_rtne.cpp
Normal file
2375
example/ck_tile/01_fmha/hsaco/bwd_bf16_a32_rtne.cpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "fmha_hsaco.hpp"
|
||||
|
||||
unsigned char bwd_bf16_causal_a16[] = {
|
||||
unsigned char bwd_bf16_causal_a16_rtna[] = {
|
||||
0x7F, 0x45, 0x4C, 0x46, 0x02, 0x01, 0x01, 0x40, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x03, 0x00, 0xE0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x9C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
2651
example/ck_tile/01_fmha/hsaco/bwd_bf16_causal_a16_rtne.cpp
Normal file
2651
example/ck_tile/01_fmha/hsaco/bwd_bf16_causal_a16_rtne.cpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "fmha_hsaco.hpp"
|
||||
|
||||
unsigned char bwd_bf16_causal_a32[] = {
|
||||
unsigned char bwd_bf16_causal_a32_rtna[] = {
|
||||
0x7F, 0x45, 0x4C, 0x46, 0x02, 0x01, 0x01, 0x40, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x03, 0x00, 0xE0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x58, 0x99, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
2575
example/ck_tile/01_fmha/hsaco/bwd_bf16_causal_a32_rtne.cpp
Normal file
2575
example/ck_tile/01_fmha/hsaco/bwd_bf16_causal_a32_rtne.cpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,13 +3,17 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
extern unsigned char bwd_bf16_a16[];
|
||||
extern unsigned char bwd_bf16_a16_rtna[];
|
||||
extern unsigned char bwd_bf16_a16_rtne[];
|
||||
extern unsigned char bwd_bf16_a16_rtz[];
|
||||
extern unsigned char bwd_bf16_a32[];
|
||||
extern unsigned char bwd_bf16_a32_rtna[];
|
||||
extern unsigned char bwd_bf16_a32_rtne[];
|
||||
extern unsigned char bwd_bf16_a32_rtz[];
|
||||
extern unsigned char bwd_bf16_causal_a16[];
|
||||
extern unsigned char bwd_bf16_causal_a16_rtna[];
|
||||
extern unsigned char bwd_bf16_causal_a16_rtne[];
|
||||
extern unsigned char bwd_bf16_causal_a16_rtz[];
|
||||
extern unsigned char bwd_bf16_causal_a32[];
|
||||
extern unsigned char bwd_bf16_causal_a32_rtna[];
|
||||
extern unsigned char bwd_bf16_causal_a32_rtne[];
|
||||
extern unsigned char bwd_bf16_causal_a32_rtz[];
|
||||
extern unsigned char bwd_bf16_spec_a32[];
|
||||
extern unsigned char bwd_bf16_spec_causal_a32[];
|
||||
|
||||
@@ -9,35 +9,47 @@ for hdim in 128 ; do
|
||||
|
||||
nhead=$((2048 / $hdim)) # follow fav2 setup
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_rtz_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_rtz_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=2 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=2 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_rtz_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_rtz_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=2 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=2 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_rtz_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_rtz_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=2 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=2 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_rtz_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_rtz_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=2 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=2 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_rtz_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_rtz_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=2 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=2 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_rtz_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_rtz_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_bf16_cvt=2 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v3_bf16_cvt=2 -v=$VALID ; sleep 3
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
@@ -12,11 +12,11 @@ for prec in "fp16" "bf16" ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 128 ; do
|
||||
for v3_atomic_fp32 in 0 1 ; do
|
||||
for v3_rtz_cvt in 0 1 ; do
|
||||
for v3_bf16_cvt in 0 1 2 ; do
|
||||
for mask in 0 1 ; do
|
||||
|
||||
$EXE -prec=$prec -b=2 -h=4 -h_k=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -v3_rtz_cvt=$v3_rtz_cvt -mode=0 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=3 -h_k=1 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -v3_rtz_cvt=$v3_rtz_cvt -mode=0 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=4 -h_k=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -v3_bf16_cvt=$v3_bf16_cvt -mode=0 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=3 -h_k=1 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -v3_bf16_cvt=$v3_bf16_cvt -mode=0 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user