diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 265dc5bb51..5970aa13cb 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -49,9 +49,6 @@ auto create_args(int argc, char* argv[]) .insert("d", "128", "head dim for q, k") .insert("d_v", "0", "head dim for v, 0 means equal to d") .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)") - .insert("descale_q", "1", "scale factor for fp8 quantization") - .insert("descale_k", "1", "scale factor for fp8 quantization") - .insert("descale_v", "1", "scale factor for fp8 quantization") .insert("iperm", "1", "permute input\n" @@ -140,10 +137,6 @@ bool run(const ck_tile::ArgParser& arg_parser) if(scale == .0f) scale = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? - float descale_q = arg_parser.get_float("descale_q"); - float descale_k = arg_parser.get_float("descale_k"); - float descale_v = arg_parser.get_float("descale_v"); - std::string vlayout = arg_parser.get_str("vlayout"); bool use_bias = arg_parser.get_bool("bias"); bool lse = arg_parser.get_bool("lse"); @@ -384,9 +377,7 @@ bool run(const ck_tile::ArgParser& arg_parser) mask.y, mask.x, ck_tile::identity{}, - ck_tile::identity{}, - descale_q * descale_k, - descale_v}; + ck_tile::identity{}}; }(); float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 65be4538e3..fef65395f7 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -125,8 +125,6 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t max_seqlen_q, float scale, - float descale_qk, - float descale_sv, bool i_perm, bool o_perm, ck_tile::index_t mask_y, @@ -199,9 +197,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, nhead_stride_lse, nhead_stride_o, mask_y, - mask_x, - descale_qk, - descale_sv); + mask_x); } else { // create batch mode kernel arguments @@ -235,9 +231,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, batch_stride_lse, batch_stride_o, mask_y, - mask_x, - descale_qk, - descale_sv); + mask_x); } }(); @@ -266,8 +260,6 @@ struct fmha_fwd_args ck_tile::index_t hdim_v; ck_tile::index_t max_seqlen_q; float scale; - float descale_qk; - float descale_sv; bool i_perm; bool o_perm; ck_tile::index_t mask_y; @@ -324,8 +316,6 @@ struct fmha_fwd_args // typename ElementFunctions::SAccElementFunction s_acc_element_func; typename ElementFunctions::PComputeElementFunction p_compute_element_func; typename ElementFunctions::OAccElementFunction o_acc_element_func; - float descale_qk; - float descale_sv; }; template @@ -369,9 +359,7 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args) // args.lse_element_func, // args.s_acc_element_func, args.p_compute_element_func, - args.o_acc_element_func, - args.descale_qk, - args.descale_sv); + args.o_acc_element_func); } else { // create batch mode kernel arguments @@ -413,9 +401,7 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args) // args.lse_element_func, // args.s_acc_element_func, args.p_compute_element_func, - args.o_acc_element_func, - args.descale_qk, - args.descale_sv); + args.o_acc_element_func); } }(); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 74086b61ab..07ee405174 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -42,7 +42,6 @@ LAYOUT_MAP = { PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", - "qr_fp8" : "ck_tile::BlockFmhaPipelineQRKSVSFp8", "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", } @@ -212,7 +211,7 @@ class FmhaFwdApiTrait: if self.pipeline_tag == 'qr_async': if self.spad == 't' : return 'true' # always support else : return 'true' - elif self.pipeline_tag in ['qr', 'qr_fp8']: + elif self.pipeline_tag in ['qr']: if self.spad == 't' : return f'a.seqlen_q % {self.bm0} != 0' else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False @@ -228,7 +227,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr', 'qr_fp8']: + elif self.pipeline_tag in ['qr']: if self.dpad == 't': return f'a.hdim_q % {self.bk0blen} != 0' else : return f'a.hdim_q % {self.bk0blen} == 0' else: assert False @@ -239,7 +238,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr', 'qr_fp8']: + elif self.pipeline_tag in ['qr']: if self.dvpad == 't': return f'a.hdim_v % {self.bk0blen} != 0' else : return f'a.hdim_v % {self.bk0blen} == 0' else: assert False @@ -450,7 +449,7 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF elif dtype in ['fp8', 'bf8']: # no need lse kernels for mask, bias in itertools.product(MASK_MAP.keys(), ["t", "f"]): - pipelines.append(FmhaFwdPipeline('qr_fp8', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask)) else: assert False return pipelines diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index bc35525a71..7e3cfae9db 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -166,13 +166,6 @@ struct FmhaFwdKernel ck_tile::index_t mask_y, mask_x; }; - struct FmhaFwdFP8Kargs - { - float descale_qk; // q*k - float descale_sv; // s*v - // float * o_amax_ptr; - }; - struct FmhaFwdCommonLSEKargs { void* lse_ptr = nullptr; @@ -188,8 +181,7 @@ struct FmhaFwdKernel : FmhaFwdCommonKargs, std::conditional_t>, std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -201,8 +193,7 @@ struct FmhaFwdKernel : FmhaFwdCommonKargs, std::conditional_t>, std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -251,9 +242,7 @@ struct FmhaFwdKernel // LSEElementFunction lse_element_func, // SAccElementFunction s_acc_element_func, PComputeElementFunction p_compute_element_func, - OAccElementFunction o_acc_element_func, - float descale_qk, - float descale_sv) + OAccElementFunction o_acc_element_func) { Kargs kargs{{q_ptr, k_ptr, @@ -288,7 +277,6 @@ struct FmhaFwdKernel {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for lse - {}, // placeholder for fp8 args batch_stride_q, batch_stride_k, batch_stride_v, @@ -312,11 +300,6 @@ struct FmhaFwdKernel kargs.nhead_stride_lse = nhead_stride_lse; kargs.batch_stride_lse = batch_stride_lse; } - if constexpr(kIsFp8) - { - kargs.descale_qk = descale_qk; - kargs.descale_sv = descale_sv; - } return kargs; } @@ -356,9 +339,7 @@ struct FmhaFwdKernel // LSEElementFunction lse_element_func, // SAccElementFunction s_acc_element_func, PComputeElementFunction p_compute_element_func, - OAccElementFunction o_acc_element_func, - float descale_qk, - float descale_sv) + OAccElementFunction o_acc_element_func) { Kargs kargs{{q_ptr, k_ptr, @@ -393,7 +374,6 @@ struct FmhaFwdKernel {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for lse - {}, // placeholder for fp8 args reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; @@ -414,11 +394,6 @@ struct FmhaFwdKernel kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; } - if constexpr(kIsFp8) - { - kargs.descale_qk = descale_qk; - kargs.descale_sv = descale_sv; - } return kargs; } @@ -702,38 +677,22 @@ struct FmhaFwdKernel }(); auto o_acc_tile = [&]() { - if constexpr(kIsFp8) - { - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - lse_dram_window, - mask, - kargs.scale, - kargs.descale_qk, - kargs.descale_sv, - smem_ptr); - } - else - { - return FmhaPipeline{}(q_dram_window, - identity{}, - k_dram_window, - identity{}, - v_dram_window, - identity{}, - bias_dram_window, - identity{}, - lse_dram_window, - identity{}, - identity{}, - kargs.p_compute_element_func, - kargs.o_acc_element_func, - mask, - kargs.scale, - smem_ptr); - } + return FmhaPipeline{}(q_dram_window, + identity{}, + k_dram_window, + identity{}, + v_dram_window, + identity{}, + bias_dram_window, + identity{}, + lse_dram_window, + identity{}, + identity{}, + kargs.p_compute_element_func, + kargs.o_acc_element_func, + mask, + kargs.scale, + smem_ptr); }(); // O DRAM and O DRAM window diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 5476282e04..78527f4633 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -9,6 +9,7 @@ namespace ck_tile { +// deprecated pipeline // This pipeline is qkv all located in LDS template struct BlockFmhaPipelineQRKSVSFp8