diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 139a7f6f1a..3aa5c6d7d5 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -550,7 +550,7 @@ class KernelComponentFactory: (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } - elif dtype == 'fp8' or dtype == 'bf8': + elif dtype == 'fp8' or dtype == 'fp8bf16': return { (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], @@ -567,7 +567,7 @@ class KernelComponentFactory: # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' + squant = 't' if dtype in ['fp8', 'fp8bf16'] else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): @@ -589,14 +589,14 @@ class KernelComponentFactory: pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'bf8']: + elif dtype in ['fp8', 'fp8bf16']: # no need lse/dropout kernels - for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + for logits, mask, bias in itertools.product(["f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', 'f', mask, 'f', 'f')) pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', 'f', mask, 'f', 'f')) - elif dtype in ['fp8fp16', 'fp8bf16']: + elif dtype in ['fp8fp16', 'bf8']: # TODO None else: @@ -699,7 +699,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl if not cond: continue elif receipt == 888: - cond = dtype in ['fp8'] + cond = dtype in ['fp8', 'fp8bf16'] cond &= pipeline.F_vlayout == 'row' cond &= hdim == 128 if not cond: diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 7e5f19dc24..a4380a2d60 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -158,6 +158,14 @@ auto get_elimit(std::string /*init_method*/) template <> auto get_elimit(std::string /*init_method*/) +{ + double rtol = 0; + double atol = 16; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) { double rtol = 1e-2; double atol = 1.8e-1; @@ -641,7 +649,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor cache_batch_idx_host(use_cache_batch_idx ? std::array{batch} : std::array{1}); - + float max_o = 3.0; if(init_method == "ui" || init_method == "0") { ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); @@ -796,6 +804,12 @@ bool run(const ck_tile::ArgParser& arg_parser) } scale_p = p_dtype_max; + + if constexpr(std::is_same_v) + { + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + scale_o = scale_o * o_dtype_max / max_o; + } } q_buf.ToDevice(q_host.data()); @@ -1241,14 +1255,16 @@ bool run(const ck_tile::ArgParser& arg_parser) randval_buf.FromDevice(randval_host.data()); auto p_compute_element_func = [&]() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) return ck_tile::scales{scale_p}; else return ck_tile::identity{}; }(); auto oacc_element_func = [&]() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) if constexpr(std::is_same_v) return ck_tile::composes(ck_tile::saturates{}, ck_tile::scales{scale_o}); @@ -1650,6 +1666,10 @@ int main(int argc, char* argv[]) { return run(arg_parser) ? 0 : -2; } + else if(data_type == "fp8bf16") + { + return run(arg_parser) ? 0 : -2; + } return -3; } diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 8c712b0aa7..b6455dd039 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -89,7 +89,7 @@ struct FmhaFwdTypeConfig using SMPLComputeDataType = float; // data type for reduction, softmax using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bf16_t; + using ODataType = ck_tile::fp8_t; }; template <> @@ -108,6 +108,22 @@ struct FmhaFwdTypeConfig using ODataType = ck_tile::bf8_t; }; +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + struct FmhaMasks { using NoMask = ck_tile::GenericAttentionMask;