mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
rename & ensure thread safety
This commit is contained in:
@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
|
||||
# to be included in "make all/install/check"
|
||||
message("adding example ${EXAMPLE_FMHA_BWD}")
|
||||
|
||||
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_causal_a16.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_nocoex_a32.cpp hsaco/bwd_bf16_nocoex_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_nocoex_a32.cpp hsaco/bwd_fp16_nocoex_causal_a32.cpp fmha_bwd.cpp)
|
||||
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_causal_a16.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp)
|
||||
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
|
||||
|
||||
|
||||
@@ -188,7 +188,7 @@ struct p2
|
||||
unsigned int _p0;
|
||||
unsigned int _p1;
|
||||
}};
|
||||
struct __attribute__((packed)) fmha_bwd_asm_args
|
||||
struct __attribute__((packed)) fmha_bwd_v3_args
|
||||
{{
|
||||
void* ptr_dq;
|
||||
p2 _p0;
|
||||
@@ -224,7 +224,7 @@ struct __attribute__((packed)) fmha_bwd_asm_args
|
||||
p3 _p15;
|
||||
}};
|
||||
|
||||
struct __attribute__((packed)) fmha_bwd_xqa_asm_args
|
||||
struct __attribute__((packed)) fmha_bwd_xqa_v3_args
|
||||
{{
|
||||
void* ptr_dq;
|
||||
p2 _p0;
|
||||
@@ -270,7 +270,7 @@ struct __attribute__((packed)) fmha_bwd_xqa_asm_args
|
||||
p3 _p20;
|
||||
}};
|
||||
|
||||
struct fmha_bwd_ext_traits
|
||||
struct fmha_bwd_v3_traits
|
||||
{{
|
||||
int b;
|
||||
int h;
|
||||
@@ -283,17 +283,17 @@ struct fmha_bwd_ext_traits
|
||||
int ts_kv;
|
||||
}};
|
||||
|
||||
class fmha_bwd_ext_kernel
|
||||
class fmha_bwd_v3_kernel
|
||||
{{
|
||||
public:
|
||||
fmha_bwd_ext_kernel(const std::string& name, unsigned char buffer[])
|
||||
fmha_bwd_v3_kernel(const std::string& name, unsigned char buffer[])
|
||||
{{
|
||||
HIP_CALL(hipModuleLoadData(&module, buffer));
|
||||
HIP_CALL(hipModuleGetFunction(&kernel_func, module, name.c_str()));
|
||||
}}
|
||||
|
||||
void
|
||||
launch_kernel(fmha_bwd_ext_traits fmha_ext_traits, fmha_bwd_asm_args args, const ck_tile::stream_config& s) const
|
||||
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_args args, const ck_tile::stream_config& s) const
|
||||
{{
|
||||
size_t arg_size = sizeof(args);
|
||||
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
|
||||
@@ -303,12 +303,12 @@ class fmha_bwd_ext_kernel
|
||||
HIP_LAUNCH_PARAM_END}};
|
||||
|
||||
int bdx = 256;
|
||||
int gdx = fmha_ext_traits.s / fmha_ext_traits.ts_kv;
|
||||
int gdy = fmha_ext_traits.h;
|
||||
int gdz = fmha_ext_traits.b;
|
||||
if(fmha_ext_traits.mask > 0)
|
||||
int gdx = fmha_v3_traits.s / fmha_v3_traits.ts_kv;
|
||||
int gdy = fmha_v3_traits.h;
|
||||
int gdz = fmha_v3_traits.b;
|
||||
if(fmha_v3_traits.mask > 0)
|
||||
{{
|
||||
int num_tg = fmha_ext_traits.s / fmha_ext_traits.ts_kv;
|
||||
int num_tg = fmha_v3_traits.s / fmha_v3_traits.ts_kv;
|
||||
gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2);
|
||||
}}
|
||||
HIP_CALL(hipModuleLaunchKernel(kernel_func,
|
||||
@@ -325,7 +325,7 @@ class fmha_bwd_ext_kernel
|
||||
}}
|
||||
|
||||
void
|
||||
launch_kernel(fmha_bwd_ext_traits fmha_ext_traits, fmha_bwd_xqa_asm_args args, const ck_tile::stream_config& s) const
|
||||
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_xqa_v3_args args, const ck_tile::stream_config& s) const
|
||||
{{
|
||||
size_t arg_size = sizeof(args);
|
||||
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
|
||||
@@ -335,12 +335,12 @@ class fmha_bwd_ext_kernel
|
||||
HIP_LAUNCH_PARAM_END}};
|
||||
|
||||
int bdx = 256;
|
||||
int gdx = fmha_ext_traits.s / fmha_ext_traits.ts_kv;
|
||||
int gdy = fmha_ext_traits.h;
|
||||
int gdz = fmha_ext_traits.b;
|
||||
if(fmha_ext_traits.mask > 0)
|
||||
int gdx = fmha_v3_traits.s / fmha_v3_traits.ts_kv;
|
||||
int gdy = fmha_v3_traits.h;
|
||||
int gdz = fmha_v3_traits.b;
|
||||
if(fmha_v3_traits.mask > 0)
|
||||
{{
|
||||
int num_tg = fmha_ext_traits.s / fmha_ext_traits.ts_kv;
|
||||
int num_tg = fmha_v3_traits.s / fmha_v3_traits.ts_kv;
|
||||
gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2);
|
||||
}}
|
||||
HIP_CALL(hipModuleLaunchKernel(kernel_func,
|
||||
@@ -374,11 +374,11 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
}}
|
||||
|
||||
template <typename dot_do_o_trait_>
|
||||
float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name, bool io_perm)
|
||||
float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_ext_name << std::flush;
|
||||
fmha_bwd_asm_args args;
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_v3_name << std::flush;
|
||||
fmha_bwd_v3_args args;
|
||||
args.ptr_dq = a.dq_ptr;
|
||||
args.ptr_dk = a.dk_ptr;
|
||||
args.ptr_dv = a.dv_ptr;
|
||||
@@ -406,15 +406,15 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
|
||||
args.Hs = stride_head;
|
||||
args.BAs = stride_batch;
|
||||
args.Seqs = stride_seqlen;
|
||||
auto traits = fmha_bwd_ext_traits{{a.batch,
|
||||
a.nhead_q,
|
||||
a.seqlen_q,
|
||||
a.hdim_q,
|
||||
1,
|
||||
a.mask_type,
|
||||
32,
|
||||
128}};
|
||||
fmha_bwd_ext_kernel impl(HSA_KERNEL, bwd_ext_asm);
|
||||
auto traits = fmha_bwd_v3_traits{{a.batch,
|
||||
a.nhead_q,
|
||||
a.seqlen_q,
|
||||
a.hdim_q,
|
||||
1,
|
||||
a.mask_type,
|
||||
32,
|
||||
128}};
|
||||
static fmha_bwd_v3_kernel impl(HSA_KERNEL, bwd_v3_buf); // static here is for thread safety.
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
|
||||
@@ -422,11 +422,11 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
|
||||
}}
|
||||
|
||||
template <typename dot_do_o_trait_>
|
||||
float fmha_ext_bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name, bool io_perm)
|
||||
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_ext_name << std::flush;
|
||||
fmha_bwd_xqa_asm_args args;
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_v3_name << std::flush;
|
||||
fmha_bwd_xqa_v3_args args;
|
||||
args.ptr_dq = a.dq_ptr;
|
||||
args.ptr_dk = a.dk_ptr;
|
||||
args.ptr_dv = a.dv_ptr;
|
||||
@@ -469,15 +469,15 @@ float fmha_ext_bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsig
|
||||
args.BAs_kv = stride_batch_kv;
|
||||
args.Seqs_kv = stride_seqlen_kv;
|
||||
args.Seqs_dkv = stride_seqlen_dkv;
|
||||
auto traits = fmha_bwd_ext_traits{{a.batch,
|
||||
a.nhead_q,
|
||||
a.seqlen_q,
|
||||
a.hdim_q,
|
||||
1,
|
||||
a.mask_type,
|
||||
32,
|
||||
128}};
|
||||
fmha_bwd_ext_kernel impl(HSA_KERNEL, bwd_ext_asm);
|
||||
auto traits = fmha_bwd_v3_traits{{a.batch,
|
||||
a.nhead_q,
|
||||
a.seqlen_q,
|
||||
a.hdim_q,
|
||||
1,
|
||||
a.mask_type,
|
||||
32,
|
||||
128}};
|
||||
static fmha_bwd_v3_kernel impl(HSA_KERNEL, bwd_v3_buf); // static here is for thread safety.
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
|
||||
@@ -485,11 +485,11 @@ float fmha_ext_bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsig
|
||||
}}
|
||||
|
||||
template <typename dot_do_o_trait_, typename convert_dq_trait_>
|
||||
float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name, bool io_perm)
|
||||
float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_ext_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
|
||||
fmha_bwd_asm_args args;
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
|
||||
fmha_bwd_v3_args args;
|
||||
args.ptr_dq = a.dq_acc_ptr;
|
||||
args.ptr_dk = a.dk_ptr;
|
||||
args.ptr_dv = a.dv_ptr;
|
||||
@@ -517,15 +517,15 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
|
||||
args.Hs = stride_head;
|
||||
args.BAs = stride_batch;
|
||||
args.Seqs = stride_seqlen;
|
||||
auto traits = fmha_bwd_ext_traits{{a.batch,
|
||||
a.nhead_q,
|
||||
a.seqlen_q,
|
||||
a.hdim_q,
|
||||
1,
|
||||
a.mask_type,
|
||||
32,
|
||||
128}};
|
||||
fmha_bwd_ext_kernel impl(HSA_KERNEL, bwd_ext_asm);
|
||||
auto traits = fmha_bwd_v3_traits{{a.batch,
|
||||
a.nhead_q,
|
||||
a.seqlen_q,
|
||||
a.hdim_q,
|
||||
1,
|
||||
a.mask_type,
|
||||
32,
|
||||
128}};
|
||||
static fmha_bwd_v3_kernel impl(HSA_KERNEL, bwd_v3_buf); // static here is for thread safety.
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},
|
||||
@@ -536,139 +536,139 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
|
||||
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
|
||||
if (t.uses_ext_asm == true){{
|
||||
if (t.uses_bwd_v3 == true){{
|
||||
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
|
||||
(a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false) &&
|
||||
(a.stride_q == a.stride_o /*i_perm == o_perm*/)) {{
|
||||
if(t.data_type.compare("fp16") == 0){{
|
||||
if(t.mask_type == mask_enum::no_mask){{
|
||||
if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
|
||||
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
|
||||
if(t.is_asm_no_coex == true){{
|
||||
if(t.is_v3_spec == true){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
|
||||
const std::string bwd_ext_name = "bwd_ext_fp16_nocoex_a32";
|
||||
const std::string bwd_v3_name = "bwd_v3_fp16_spec_a32";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_nocoex_a32, bwd_ext_name, io_perm);
|
||||
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_a32, bwd_v3_name, io_perm);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
|
||||
const std::string bwd_ext_name = "bwd_ext_fp16_a32";
|
||||
const std::string bwd_v3_name = "bwd_v3_fp16_a32";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_ext_name, io_perm);
|
||||
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_v3_name, io_perm);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
|
||||
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
|
||||
const std::string bwd_ext_name = "bwd_ext_fp16_a16";
|
||||
const std::string bwd_v3_name = "bwd_v3_fp16_a16";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_ext_name, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_v3_name, io_perm);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
|
||||
if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
|
||||
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
|
||||
if(t.is_asm_no_coex == true){{
|
||||
if(t.is_v3_spec == true){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
|
||||
const std::string bwd_ext_name = "bwd_ext_fp16_nocoex_causal_a32";
|
||||
const std::string bwd_v3_name = "bwd_v3_fp16_spec_causal_a32";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_nocoex_causal_a32, bwd_ext_name, io_perm);
|
||||
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_causal_a32, bwd_v3_name, io_perm);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
|
||||
const std::string bwd_ext_name = "bwd_ext_fp16_causal_a32";
|
||||
const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_ext_name, io_perm);
|
||||
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_v3_name, io_perm);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
|
||||
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
|
||||
const std::string bwd_ext_name = "bwd_ext_fp16_causal_a16";
|
||||
const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_ext_name, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_v3_name, io_perm);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if(t.data_type.compare("bf16") == 0){{
|
||||
if(t.mask_type == mask_enum::no_mask){{
|
||||
if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
|
||||
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
|
||||
if(t.is_asm_no_coex == true){{
|
||||
if(t.is_v3_spec == true){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
const std::string bwd_ext_name = "bwd_ext_bf16_nocoex_a32";
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_spec_a32";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_nocoex_a32, bwd_ext_name, io_perm);
|
||||
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_a32, bwd_v3_name, io_perm);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
const std::string bwd_ext_name = "bwd_ext_bf16_a32";
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_a32";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_ext_name, io_perm);
|
||||
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_v3_name, io_perm);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
|
||||
if(t.is_asm_rtz_cvt == true){{
|
||||
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
|
||||
if(t.is_v3_rtz_cvt == true){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_ext_name = "bwd_ext_bf16_a16_rtz";
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtz, bwd_ext_name, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtz, bwd_v3_name, io_perm);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_ext_name = "bwd_ext_bf16_a16";
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_a16";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_ext_name, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_v3_name, io_perm);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
|
||||
if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
|
||||
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
|
||||
if(t.is_asm_no_coex == true){{
|
||||
if(t.is_v3_spec == true){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
const std::string bwd_ext_name = "bwd_ext_bf16_nocoex_causal_a32";
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_spec_causal_a32";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_nocoex_causal_a32, bwd_ext_name, io_perm);
|
||||
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_causal_a32, bwd_v3_name, io_perm);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a32";
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32, bwd_ext_name, io_perm);
|
||||
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32, bwd_v3_name, io_perm);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
|
||||
if(t.is_asm_rtz_cvt == true){{
|
||||
else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
|
||||
if(t.is_v3_rtz_cvt == true){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a16_rtz";
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtz, bwd_ext_name, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtz, bwd_v3_name, io_perm);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a16";
|
||||
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_ext_name, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_v3_name, io_perm);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
|
||||
@@ -92,17 +92,17 @@ auto create_args(int argc, char* argv[])
|
||||
"0",
|
||||
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
|
||||
"will not be used")
|
||||
.insert("ext_asm", "0", "if set to 1, some cases will call the ext asm dqdkdv kernel")
|
||||
.insert("bwd_v3", "0", "if set to 1, some cases will call the bwd v3 dqdkdv kernel")
|
||||
.insert(
|
||||
"asm_atomic_fp32",
|
||||
"v3_atomic_fp32",
|
||||
"1",
|
||||
"if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) when ext_asm is set to 1")
|
||||
.insert("asm_no_coex",
|
||||
"if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) when bwd_v3 is set to 1")
|
||||
.insert("v3_spec",
|
||||
"0",
|
||||
"if set to 1 will use non-coexectuion kernel when ext_asm is set to 1")
|
||||
.insert("asm_rtz_cvt",
|
||||
"if set to 1 will call the specialized v3 kernel when bwd_v3 is set to 1")
|
||||
.insert("v3_rtz_cvt",
|
||||
"0",
|
||||
"if set to 1 will use float to bf16 RTZ convert when ext_asm is set to 1");
|
||||
"if set to 1 will use float to bf16 RTZ convert when bwd_v3 is set to 1");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -187,14 +187,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
seed.reset();
|
||||
}
|
||||
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
bool deterministic = arg_parser.get_bool("deterministic");
|
||||
bool ext_asm = arg_parser.get_bool("ext_asm");
|
||||
bool asm_atomic_fp32 = arg_parser.get_bool("asm_atomic_fp32");
|
||||
bool asm_no_coex = arg_parser.get_bool("asm_no_coex");
|
||||
bool asm_rtz_cvt = arg_parser.get_bool("asm_rtz_cvt");
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
bool deterministic = arg_parser.get_bool("deterministic");
|
||||
bool bwd_v3 = arg_parser.get_bool("bwd_v3");
|
||||
bool v3_atomic_fp32 = arg_parser.get_bool("v3_atomic_fp32");
|
||||
bool v3_spec = arg_parser.get_bool("v3_spec");
|
||||
bool v3_rtz_cvt = arg_parser.get_bool("v3_rtz_cvt");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
@@ -430,10 +430,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic,
|
||||
ext_asm,
|
||||
asm_atomic_fp32,
|
||||
asm_no_coex,
|
||||
asm_rtz_cvt};
|
||||
bwd_v3,
|
||||
v3_atomic_fp32,
|
||||
v3_spec,
|
||||
v3_rtz_cvt};
|
||||
auto fmha_args = [&]() {
|
||||
assert(nhead % nhead_k == 0);
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
|
||||
@@ -438,10 +438,10 @@ struct fmha_bwd_traits
|
||||
bool has_dropout;
|
||||
bool is_store_randval;
|
||||
bool is_deterministic;
|
||||
bool uses_ext_asm;
|
||||
bool is_asm_atomic_fp32;
|
||||
bool is_asm_no_coex;
|
||||
bool is_asm_rtz_cvt;
|
||||
bool uses_bwd_v3;
|
||||
bool is_v3_atomic_fp32;
|
||||
bool is_v3_spec;
|
||||
bool is_v3_rtz_cvt;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "fmha_hsaco.hpp"
|
||||
|
||||
unsigned char bwd_bf16_nocoex_a32[] = {
|
||||
unsigned char bwd_bf16_spec_a32[] = {
|
||||
0x7F, 0x45, 0x4C, 0x46, 0x02, 0x01, 0x01, 0x40, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x03, 0x00, 0xE0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xB0, 0x7D, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "fmha_hsaco.hpp"
|
||||
|
||||
unsigned char bwd_bf16_nocoex_causal_a32[] = {
|
||||
unsigned char bwd_bf16_spec_causal_a32[] = {
|
||||
0x7F, 0x45, 0x4C, 0x46, 0x02, 0x01, 0x01, 0x40, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x03, 0x00, 0xE0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x85, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "fmha_hsaco.hpp"
|
||||
|
||||
unsigned char bwd_fp16_nocoex_a32[] = {
|
||||
unsigned char bwd_fp16_spec_a32[] = {
|
||||
0x7F, 0x45, 0x4C, 0x46, 0x02, 0x01, 0x01, 0x40, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x03, 0x00, 0xE0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x98, 0x5B, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "fmha_hsaco.hpp"
|
||||
|
||||
unsigned char bwd_fp16_nocoex_causal_a32[] = {
|
||||
unsigned char bwd_fp16_spec_causal_a32[] = {
|
||||
0x7F, 0x45, 0x4C, 0x46, 0x02, 0x01, 0x01, 0x40, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x03, 0x00, 0xE0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF0, 0x62, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
@@ -9,11 +9,11 @@ extern unsigned char bwd_bf16_a32[];
|
||||
extern unsigned char bwd_bf16_causal_a16[];
|
||||
extern unsigned char bwd_bf16_causal_a16_rtz[];
|
||||
extern unsigned char bwd_bf16_causal_a32[];
|
||||
extern unsigned char bwd_bf16_nocoex_a32[];
|
||||
extern unsigned char bwd_bf16_nocoex_causal_a32[];
|
||||
extern unsigned char bwd_bf16_spec_a32[];
|
||||
extern unsigned char bwd_bf16_spec_causal_a32[];
|
||||
extern unsigned char bwd_fp16_a16[];
|
||||
extern unsigned char bwd_fp16_a32[];
|
||||
extern unsigned char bwd_fp16_causal_a16[];
|
||||
extern unsigned char bwd_fp16_causal_a32[];
|
||||
extern unsigned char bwd_fp16_nocoex_a32[];
|
||||
extern unsigned char bwd_fp16_nocoex_causal_a32[];
|
||||
extern unsigned char bwd_fp16_spec_a32[];
|
||||
extern unsigned char bwd_fp16_spec_causal_a32[];
|
||||
|
||||
@@ -9,23 +9,23 @@ for hdim in 128 ; do
|
||||
|
||||
nhead=$((2048 / $hdim)) # follow fav2 setup
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -asm_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -asm_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -asm_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -asm_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -asm_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -asm_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
@@ -11,12 +11,12 @@ set -x
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 128 ; do
|
||||
for asm_atomic_fp32 in 0 1 ; do
|
||||
for asm_no_coex in 0 1 ; do
|
||||
for v3_atomic_fp32 in 0 1 ; do
|
||||
for v3_spec in 0 1 ; do
|
||||
for mask in 0 1 ; do
|
||||
|
||||
$EXE -prec=$prec -b=4 -h=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=$asm_atomic_fp32 -asm_no_coex=$asm_no_coex -mode=0 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=3 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=$asm_atomic_fp32 -asm_no_coex=$asm_no_coex -mode=0 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=4 -h=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -v3_spec=$v3_spec -mode=0 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=3 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -v3_spec=$v3_spec -mode=0 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
@@ -13,8 +13,8 @@ for perm in 0 1 ; do
|
||||
for hdim in 128 ; do
|
||||
for mask in 0 1 ; do
|
||||
|
||||
$EXE -prec=$prec -b=2 -h=4 -h_k=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=0 -asm_rtz_cvt=1 -mode=0 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=3 -h_k=1 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=0 -asm_rtz_cvt=1 -mode=0 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=4 -h_k=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=0 -v3_rtz_cvt=1 -mode=0 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=3 -h_k=1 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=0 -v3_rtz_cvt=1 -mode=0 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user