diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 972eb1208c..8e8e19f089 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -372,90 +372,50 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); - if constexpr(std::is_same_v) - return fmha_fwd_args{ - q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - bias_buf.GetDeviceBuffer(), - lse_buf.GetDeviceBuffer(), - o_buf.GetDeviceBuffer(), - seqstart_q.GetDeviceBuffer(), - seqstart_k.GetDeviceBuffer(), - nullptr, - shape_seqlen_q, - shape_seqlen_k, - batch, - max_seqlen_q, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_lse, - batch_stride_o, - mask.left, - mask.right, - static_cast(mask.type), - pcompute_element_func, - oacc_element_func}; - else - return fmha_fwd_args{ - q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - bias_buf.GetDeviceBuffer(), - lse_buf.GetDeviceBuffer(), - o_buf.GetDeviceBuffer(), - seqstart_q.GetDeviceBuffer(), - seqstart_k.GetDeviceBuffer(), - nullptr, - shape_seqlen_q, - shape_seqlen_k, - batch, - max_seqlen_q, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_lse, - batch_stride_o, - mask.left, - mask.right, - static_cast(mask.type), - pcompute_element_func, - oacc_element_func}; + using ElementFunctions = std::conditional_t, + FmhaF8StaticQuantizationElementFunctions, + FmhaDefaultElementFunctions>; + + return fmha_fwd_args{q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias_buf.GetDeviceBuffer(), + lse_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + shape_seqlen_q, + shape_seqlen_k, + batch, + max_seqlen_q, + hdim_q, + hdim_v, + nhead, + nhead_k, + scale, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_lse, + batch_stride_o, + mask.left, + mask.right, + static_cast(mask.type), + pcompute_element_func, + oacc_element_func}; }(); float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);