diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 27209a8b79..54bcf2c1ee 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -58,26 +58,16 @@ struct FmhaFwdTypeConfig using ODataType = ck_tile::fp8_t; }; +/// TODO: support specifying more elementwise functions for input/output tensors struct FmhaDefaultElementFunctions { - // using QElementFunction = ck_tile::identity; - // using KElementFunction = ck_tile::identity; - // using VElementFunction = ck_tile::identity; - // using BiasElementFunction = ck_tile::identity; - // using LSEElementFunction = ck_tile::identity; - // using SAccElementFunction = ck_tile::identity; using PComputeElementFunction = ck_tile::identity; using OAccElementFunction = ck_tile::identity; }; +/// TODO: support specifying more elementwise functions for input/output tensors struct FmhaF8StaticQuantizationElementFunctions { - // using QElementFunction = ck_tile::identity; - // using KElementFunction = ck_tile::identity; - // using VElementFunction = ck_tile::identity; - // using BiasElementFunction = ck_tile::identity; - // using LSEElementFunction = ck_tile::identity; - // using SAccElementFunction = ck_tile::identity; using PComputeElementFunction = ck_tile::scale; using OAccElementFunction = ck_tile::composer; }; @@ -146,12 +136,6 @@ struct fmha_fwd_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; - // typename ElementFunctions::QElementFunction q_element_func; - // typename ElementFunctions::KElementFunction k_element_func; - // typename ElementFunctions::VElementFunction v_element_func; - // typename ElementFunctions::BiasElementFunction bias_element_func; - // typename ElementFunctions::LSEElementFunction lse_element_func; - // typename ElementFunctions::SAccElementFunction s_acc_element_func; typename ElementFunctions::PComputeElementFunction p_compute_element_func; typename ElementFunctions::OAccElementFunction o_acc_element_func; }; @@ -191,12 +175,6 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args) args.window_size_left, args.window_size_right, args.mask_type, - // args.q_element_func, - // args.k_element_func, - // args.v_element_func, - // args.bias_element_func, - // args.lse_element_func, - // args.s_acc_element_func, args.p_compute_element_func, args.o_acc_element_func); } @@ -234,12 +212,6 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args) args.window_size_left, args.window_size_right, args.mask_type, - // args.q_element_func, - // args.k_element_func, - // args.v_element_func, - // args.bias_element_func, - // args.lse_element_func, - // args.s_acc_element_func, args.p_compute_element_func, args.o_acc_element_func); }