diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index cb6cd44f64..7f55d7412f 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -131,4 +131,4 @@ TBD ## FP8 experimental support As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+. -Currently we only support `-vlayout=c`( `hdim*seqlen` for V matrix) and `-squant=1`(static quantization) with `hdim=128` for fp8 now. Full feature support will come later. +Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later. diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 42a9d5148a..802c9e51d7 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -7,7 +7,8 @@ FWD_DTYPE_MAP = { "bf16" : "FmhaFwdBf16", "fp8" : "FmhaFwdFp8", "fp8fp16": "FmhaFwdFp8Fp16", - "fp8bf16": "FmhaFwdFp8Bf16" + "fp8bf16": "FmhaFwdFp8Bf16", + "fp8fp32": "FmhaFwdFp8Fp32" } BWD_DTYPE_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index d9452206e7..cfb96b7d53 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -163,7 +163,7 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); }}; - + const bool has_load_tr = ck_tile::is_load_tr_supported(); {F_dispatch} @@ -248,11 +248,11 @@ class FmhaFwdApiTrait: if self.spad == 't' : return 
f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False - + @property def seqtune(self) -> str: if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true - else: + else: return f'a.seqlen_q <= {self.bm0}' @property @@ -351,7 +351,7 @@ class FmhaFwdPipeline: if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' - + if self.F_trload == 't' : n += '_trload' else: n += '_ntrload' @@ -378,7 +378,7 @@ class FmhaFwdApiPool: "t": "has_load_tr", "f": "true" } - + per_tr_load =str() for tr_load in ["t", "f"]: per_dtypes=str() @@ -550,12 +550,16 @@ class KernelComponentFactory: (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } - elif dtype == 'fp8' or dtype == 'bf8': + elif dtype == 'fp8' or dtype == 'fp8bf16': return { (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } + elif dtype == 'fp8fp32': + return { + (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } else: return None @@ -567,9 +571,9 @@ class KernelComponentFactory: # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? 
- squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: + squant = 'f' for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256 and hdim_v == 256: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) @@ -589,11 +593,12 @@ class KernelComponentFactory: pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'bf8']: + elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']: # no need lse/dropout kernels - for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) - elif dtype in ['fp8fp16', 'fp8bf16']: + for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + elif dtype in ['fp8fp16', 'bf8']: # TODO None else: @@ -674,25 +679,34 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue # Aiter(mha_fwd) integration elif receipt == 100: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= mode == 'batch' cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 if not cond: continue # 
Aiter(mha_varlen_fwd) integration elif receipt == 200: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= mode == 'group' cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 if not cond: continue # aiter::mha_fwd C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 + if not cond: + continue + elif receipt == 888: + cond = dtype in ['fp8', 'fp8bf16', 'fp8fp32'] + cond &= pipeline.F_vlayout == 'row' + cond &= hdim == 128 if not cond: continue diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 3b48b3d005..cee1505486 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -645,7 +645,6 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: return { '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 7b93e9654c..df6b422981 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -465,14 +465,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for logits, mask, bias, pagedkv, skip in itertools.product(["t", 
"f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]): pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) elif dtype in ['fp8', 'bf8']: - # TODO - None + # no need lse/dropout kernels + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index c3bbb7a558..91cb9f55be 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -44,21 +44,15 @@ auto create_args(int argc, char* argv[]) .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" - "note when squant=1, this value will be modified by range_q/k") + "note when squant=1, this value will be modified") .insert("logits_soft_cap", "0", "attention logits soft capping value.") - .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.") - .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.") - .insert("range_v", "16", "per-tensor quantization range of v. 
used if squant=1.") - .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.") - .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.") .insert("squant", "auto", "if using static quantization fusion or not. auto: fp8 will default use squant, " "other will not\n" "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to " "P and O.\n" - "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, " - "range_p, range_o") + "calculate scale_s, scale_p, scale_o auto") .insert("iperm", "1", "permute input\n" @@ -89,7 +83,7 @@ auto create_args(int argc, char* argv[]) "uf", "init method:\n ui or 0 - uniform random int\n ni - normalized random int" "\n uf or 1 - uniform random float\n nf - normalized random float" - "\n tf or 2 - trig float\n uf:q or ufq or 3 - fp8 quantization") + "\n tf or 2 - trig float\n") .insert("seed", "11939", "random seed used for initializing input tensors. 0 for " @@ -148,11 +142,6 @@ auto run(const ck_tile::ArgParser& arg_parser) uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); bool drop_prefs = arg_parser.get_bool("drop_prefs"); std::string mask_str = arg_parser.get_str("mask"); - float range_q = arg_parser.get_float("range_q"); - float range_k = arg_parser.get_float("range_k"); - float range_v = arg_parser.get_float("range_v"); - float range_p = arg_parser.get_float("range_p"); - float range_o = arg_parser.get_float("range_o"); bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved"); ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); std::string init_method = arg_parser.get_str("init"); @@ -201,11 +190,6 @@ auto run(const ck_tile::ArgParser& arg_parser) drop_offset, drop_prefs, mask_str, - range_q, - range_k, - range_v, - range_p, - range_o, squant, is_rotary_interleaved, num_splits, @@ -237,6 +221,14 @@ int main(int argc, char* argv[]) { return run(arg_parser) == fwd_result::success ? 
0 : -2; } + else if(data_type == "fp8bf16") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "fp8fp32") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } std::cerr << "Unsupported precision: " << data_type << std::endl; return -1; } diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index df1e9e5699..c41e48e6aa 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -41,6 +41,10 @@ struct FmhaFwdFp8Bf16 { }; +struct FmhaFwdFp8Fp32 +{ +}; + template struct FmhaFwdTypeConfig; @@ -108,6 +112,38 @@ struct FmhaFwdTypeConfig using ODataType = ck_tile::bf8_t; }; +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = float; +}; + struct FmhaMasks { using NoMask 
= ck_tile::GenericAttentionMask; diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 397245ab32..43f484fe14 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -50,20 +50,30 @@ auto get_elimit(std::string /*init_method*/) } template <> -auto get_elimit(std::string init_method) +auto get_elimit(std::string /*init_method*/) { - if(init_method == "ui" || init_method == "ni") - { - unsigned max_rounding_point_distance = 0; - double atol = 2e-3; - return ck_tile::make_tuple(max_rounding_point_distance, atol); - } - else - { - unsigned max_rounding_point_distance = 1; - double atol = 0.0625; - return ck_tile::make_tuple(max_rounding_point_distance, atol); - } + using TypeConfig = FmhaFwdTypeConfig; + using ODataType = typename TypeConfig::ODataType; + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + double rtol = 0; + double atol = 16 * (o_dtype_max > 240 ? 2 : 1); + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); } int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits) @@ -157,11 +167,6 @@ fwd_result fmha_fwd_run(mode_enum mode, uint64_t drop_offset, bool drop_prefs, std::string mask_str, - float range_q, - float range_k, - float range_v, - float range_p, - float range_o, bool squant, bool is_rotary_interleaved, ck_tile::index_t num_splits, @@ -180,6 +185,10 @@ fwd_result fmha_fwd_run(mode_enum mode, return "fp8"; else if constexpr(std::is_same_v) return "bf8"; + else if constexpr(std::is_same_v) + return "fp8bf16"; + else if constexpr(std::is_same_v) + return "fp8fp32"; else static_assert(false); }(); @@ -367,22 
+376,6 @@ fwd_result fmha_fwd_run(mode_enum mode, using OaccDataType = typename TypeConfig::OaccDataType; using ODataType = typename TypeConfig::ODataType; - float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float p_dtype_max = v_dtype_max; // assume p and v is the same type - float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - - float scale_p = 1.f; - float scale_o = 1.f; - - if(squant) - { - scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max); - scale_p = p_dtype_max / range_p; - scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max); - } - // accumulation numbers for performance evaluation std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = @@ -528,7 +521,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::HostTensor cache_batch_idx_host(use_cache_batch_idx ? 
std::array{batch} : std::array{1}); - + float max_o = 5.0; if(init_method == "ui" || init_method == "0") { ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); @@ -576,32 +569,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::FillTrigValue{}(vnew_host); ck_tile::FillTrigValue{}(bias_host); } - else if(init_method == "ufq" || init_method == "uf:q" || init_method == "3") - { - // suitable for fp8 quantization - if(!squant) - { - std::cerr << "init method " << init_method << " can not be used without quantization" - << std::endl; - return fwd_result::invalid_args; - } - ck_tile::FillUniformDistribution{0.f, q_dtype_max, next_seed()}(q_host); - ck_tile::FillUniformDistribution{0.f, k_dtype_max, next_seed()}(k_host); - ck_tile::FillUniformDistribution{0.f, k_dtype_max, next_seed()}(knew_host); - ck_tile::FillUniformDistribution{0.f, v_dtype_max, next_seed()}(v_host); - ck_tile::FillUniformDistribution{0.f, v_dtype_max, next_seed()}(vnew_host); - - // bias_fp8 = qscale_bias * bias_fp32 - float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k); - // Assume bias is in [0.f, 1.f] in original fp32 - ck_tile::FillUniformDistribution{0.f, qscale_bias, next_seed()}(bias_host); - } - else - { - std::cerr << "Unknown value for init argument: " << init_method << std::endl; - return fwd_result::invalid_args; - } - if(bias.type == bias_enum::alibi) { auto slopes = ck_tile::get_alibi_slopes(nhead); @@ -625,8 +592,8 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem 
bias_buf(bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); @@ -650,10 +617,79 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes()); + float scale_p = 1.f; + float scale_o = 1.f; + if(squant) + { + float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float p_dtype_max = v_dtype_max; // assume p and v is the same type + // Q tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + q_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + + float scale = q_dtype_max / max_value; + + q_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + scale_s = scale_s / scale; + } + + // K tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + k_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + float scale = k_dtype_max / max_value; + k_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + scale_s = scale_s / scale; + } + + // V tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + v_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + + float scale = v_dtype_max / max_value; // V is scaled to V's own dtype max, not K's + v_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + 
self(idx) = ck_tile::type_convert(val * scale); + }); + + scale_o = (1.0 / p_dtype_max) / scale; + } + + scale_p = p_dtype_max; + + if constexpr(std::is_same_v) + { + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + scale_o = scale_o * o_dtype_max / max_o; + } + } + q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); - knew_buf.ToDevice(knew_host.data()); v_buf.ToDevice(v_host.data()); + knew_buf.ToDevice(knew_host.data()); vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); @@ -1103,7 +1139,9 @@ fwd_result fmha_fwd_run(mode_enum mode, lse_buf.FromDevice(lse_host.data()); randval_buf.FromDevice(randval_host.data()); - constexpr bool supports_squant = std::is_same_v; + constexpr bool supports_squant = std::is_same_v || + std::is_same_v || + std::is_same_v; auto p_compute_element_func = [&]() { if constexpr(supports_squant) @@ -1113,9 +1151,11 @@ fwd_result fmha_fwd_run(mode_enum mode, }(); auto oacc_element_func = [&]() { - if constexpr(supports_squant) + if constexpr(std::is_same_v && supports_squant) return ck_tile::composes(ck_tile::saturates{}, ck_tile::scales{scale_o}); + else if constexpr(supports_squant) + return ck_tile::scales{scale_o}; else return ck_tile::identity{}; }(); diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index c087a1fb3e..afd0c728c6 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -94,7 +94,30 @@ run_fp8_tests() { for b in 1 2 ; do for hdim in 64 128 256 ; do - run_exe -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp8bf16_tests() { + for perm in 0 1 ; do + 
for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp8fp32_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + done ; done ; done ; done } @@ -117,7 +140,9 @@ run_fp16_appendkv_tests() { set -x run_fp16_bf16_tests -# run_fp8_tests +run_fp8_tests +run_fp8bf16_tests +run_fp8fp32_tests if [ $TEST_APPENDKV -eq 1 ] ; then run_fp16_appendkv_tests diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 6405ca50df..58fdad149a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1446,29 +1446,35 @@ struct FmhaFwdKernel auto o_acc_tile = [&]() { if constexpr(kDoFp8StaticQuant) { - return FmhaPipeline{}( - q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - bias_dram_window, - identity{}, // bias_element_func - randval_dram_window, - lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales{kargs.scale_p}, // p_compute_element_func - composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); + auto o_acc_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::composes(ck_tile::saturates{}, + ck_tile::scales{kargs.scale_o}); + else + return ck_tile::scales{kargs.scale_o}; + }(); + return FmhaPipeline{}(q_dram_window, + identity{}, // q_element_func + 
k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{kargs.scale_p}, // p_compute_element_func + o_acc_element_func, // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index b883aad155..c402eaeac4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -559,6 +559,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto shuffled_bias_tile = make_static_distributed_tensor( Policy::template MakeShuffledBiasTileDistribution()); shuffle_tile(shuffled_bias_tile, bias_tile); + // SGrad and Bias use the same address in LDS, finish loading ds on the previous + // iteration to reuse LDS. + block_sync_lds(); store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); @@ -814,6 +817,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto shuffled_bias_tile = make_static_distributed_tensor( Policy::template MakeShuffledBiasTileDistribution()); shuffle_tile(shuffled_bias_tile, bias_tile); + // SGrad and Bias use the same address in LDS, finish loading ds in the hot loop to + // reuse LDS. + block_sync_lds(); store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); @@ -956,6 +962,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP return cast_tile(ds); } }(); + // Finish loading bias_s to reuse LDS. 
+ block_sync_lds(); store_tile(bias_lds_write_window, dbias); block_sync_lds(); auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); @@ -975,11 +983,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); - if constexpr(kHasBiasGrad) - { - // SGrad and BiasGrad use the same address in LDS. - block_sync_lds(); - } + // SGrad and Bias/BiasGrad use the same address in LDS, finish loading bias/dbias or, when + // bias is not used, loading ds in the hot loop to reuse LDS. + block_sync_lds(); store_tile(ds_lds_window, ds_gemm); block_sync_lds(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 81950bd30a..41cb4fc306 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -698,6 +698,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse + // LDS. 
+ block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); } s_waitcnt(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 16d9f695df..8c8d2af486 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -656,6 +656,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor); + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse + // LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); } __builtin_amdgcn_s_waitcnt(3952); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 68ead7c765..ad9e2959f5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1941,7 +1941,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt; constexpr index_t smem_size_stage0_1 = smem_size_v; - constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot + + constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot + smem_size_do + smem_size_lse + smem_size_d + max(smem_size_bias, smem_size_ds); diff --git a/script/launch_tests.sh b/script/launch_tests.sh index 5e71e25478..17a99e62a3 100755 --- a/script/launch_tests.sh +++ b/script/launch_tests.sh @@ -49,7 +49,7 @@ with open('$TEST_FILE', 'r') as f: if tests: # Extract just the filename after the last '/' clean_tests = 
[os.path.basename(test) for test in tests] - print('ctest -R \"' + '|'.join(clean_tests) + '\"') + print('ctest --output-on-failure -R \"' + '|'.join(clean_tests) + '\"') else: print('# No tests to run') ") @@ -57,5 +57,3 @@ with open('$TEST_FILE', 'r') as f: echo "$command" eval "$command" - - diff --git a/test/ck_tile/fmha/test_fmha_fwd.inc b/test/ck_tile/fmha/test_fmha_fwd.inc index f02ef1e55e..08abd3358d 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.inc +++ b/test/ck_tile/fmha/test_fmha_fwd.inc @@ -32,9 +32,6 @@ const ck_tile::stream_config stream_config{ 1, // rotating_count_ }; -// range_q, range_k, range_v, range_p, range_o, squant -#define QUANT_ARGS 1, 1, 1, 1, 1, squant - #define COMMON_ARGS \ init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \ stream_config @@ -117,7 +114,7 @@ TEST_P(AllLong, Test) 1024, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -179,7 +176,7 @@ TEST_P(HDimPadding, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -236,7 +233,7 @@ TEST_P(ElementwiseBias, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -292,7 +289,7 @@ TEST_P(Alibi, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -350,7 +347,7 @@ TEST_P(Dropout, Test) drop_offset, // drop_offset drop_prefs, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -410,7 +407,7 @@ TEST_P(PagedKV, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -476,7 +473,7 @@ 
TEST_P(SplitKV, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved num_splits, // num_splits COMMON_ARGS); @@ -548,7 +545,7 @@ TEST_P(AppendKV, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, false, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -618,7 +615,7 @@ TEST_P(AppendKVRoPE, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, is_rotary_interleaved, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); diff --git a/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp b/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp index 46ed8f4125..b99c304d1f 100644 --- a/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp +++ b/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp @@ -17,22 +17,21 @@ using DataTypeConfig = FmhaFwdFp8; // instances are added), however the corresponding tests are not disabled (they will be skipped) // in case such instances will be added in the future. -const auto HDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}); +const auto HDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}); -const auto SplitKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}); +const auto SplitKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}); -const auto AppendKVHDimValues = - Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}); +const auto AppendKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}); // There are no fp8 instances with seqlen padding (mode_enum::group requires it) const auto ModeValues = Values(mode_enum::batch); const auto IsVRowmajorValues = Values(false); -const bool squant = true; -const std::string init_method = "ufq"; +const auto squant = true; +const std::string init_method = "uf"; const bool def_lse = false; -const bool def_is_v_rowmajor = false; +const bool def_is_v_rowmajor = true; int adjust_seqlen(int seqlen) {