From 8e37698fef0e01ef5ad4afdd89dae2ce3ebecc68 Mon Sep 17 00:00:00 2001 From: ltqin Date: Wed, 20 Aug 2025 07:11:59 +0000 Subject: [PATCH] fp8 to fp8 compare --- example/ck_tile/01_fmha/fmha_fwd.cpp | 196 +++++++-------------------- example/ck_tile/01_fmha/fmha_fwd.hpp | 12 -- 2 files changed, 51 insertions(+), 157 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 83fd956591..61c76238fa 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -164,8 +164,8 @@ auto get_elimit(std::string /*init_method*/) template <> auto get_elimit(std::string /*init_method*/) { - unsigned rtol = 2.5e-1; - double atol = 2.5e-1; + unsigned rtol = 1.5e-1; + double atol = 1.5e-1; return ck_tile::make_tuple(rtol, atol); } @@ -498,14 +498,11 @@ bool run(const ck_tile::ArgParser& arg_parser) using PDataType = typename TypeConfig::PDataType; using OaccDataType = typename TypeConfig::OaccDataType; using ODataType = typename TypeConfig::ODataType; - using QHostDataType = typename TypeConfig::QHostDataType; - using KHostDataType = typename TypeConfig::KHostDataType; - using VHostDataType = typename TypeConfig::VHostDataType; // float range_q = arg_parser.get_float("range_q"); // float range_k = arg_parser.get_float("range_k"); // float range_v = arg_parser.get_float("range_v"); - float range_p = arg_parser.get_float("range_p"); + // float range_p = arg_parser.get_float("range_p"); // float range_o = arg_parser.get_float("range_o"); // accumulation numbers for performance evaluation @@ -532,10 +529,10 @@ bool run(const ck_tile::ArgParser& arg_parser) flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); - num_byte += nhead * (sizeof(QHostDataType) * real_seqlen_q * hdim_q + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + sizeof(ODataType) * real_seqlen_q * hdim_v); - num_byte += nhead_k * (sizeof(KHostDataType) * real_seqlen_k * hdim_q + - sizeof(VHostDataType) * hdim_v * real_seqlen_k); + num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * hdim_v * real_seqlen_k); } } @@ -586,25 +583,25 @@ bool run(const ck_tile::ArgParser& arg_parser) : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() : seqstart_k_with_padding_host.back())); - ck_tile::HostTensor q_host( + ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); - ck_tile::HostTensor k_host( + ck_tile::HostTensor k_host( 0 < page_block_size ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode - ck_tile::HostTensor knew_host( + ck_tile::HostTensor knew_host( 0 < seqlen_knew ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); - ck_tile::HostTensor v_host( + ck_tile::HostTensor v_host( 0 < page_block_size ? (is_v_rowmajor ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v) : get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size)) : (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k))); - ck_tile::HostTensor vnew_host( + ck_tile::HostTensor vnew_host( 0 < seqlen_knew ? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v) : get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew)) @@ -658,47 +655,47 @@ bool run(const ck_tile::ArgParser& arg_parser) if(init_method == "ui" || init_method == "0") { - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(k_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(v_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); } else if(init_method == "ni") { - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(q_host); - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(k_host); - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(v_host); - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(q_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); } else if(init_method == "uf" || init_method == "1") { - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(knew_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); - ck_tile::FillUniformDistribution{0.f, 1.f, seed}(vnew_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(knew_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(vnew_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); } else if(init_method == "nf") { - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q_host); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k_host); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(knew_host); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v_host); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(vnew_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(knew_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(vnew_host); ck_tile::FillNormalDistribution{0.f, 3.f, seed}(bias_host); } else if(init_method == "tf" || init_method == "2") { - ck_tile::FillTrigValue{}(q_host); - ck_tile::FillTrigValue{}(k_host); - ck_tile::FillTrigValue{}(knew_host); - ck_tile::FillTrigValue{}(v_host); - ck_tile::FillTrigValue{}(vnew_host); + ck_tile::FillTrigValue{}(q_host); + ck_tile::FillTrigValue{}(k_host); + ck_tile::FillTrigValue{}(knew_host); + ck_tile::FillTrigValue{}(v_host); + ck_tile::FillTrigValue{}(vnew_host); ck_tile::FillTrigValue{}(bias_host); } if(bias.type == bias_enum::alibi) @@ -750,109 +747,19 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes()); - float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float p_dtype_max = v_dtype_max; // assume p and v is the same type + // float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + // float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + // float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + // float p_dtype_max = v_dtype_max; // assume p and v is the same type // // float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - std::cout << "q_dtype_max: " << q_dtype_max << " k_dtype_max: " << k_dtype_max - << " v_dtype_max: " << v_dtype_max << std::endl; + // std::cout << "q_dtype_max: " << q_dtype_max << " k_dtype_max: " << k_dtype_max + // << " v_dtype_max: " << v_dtype_max << std::endl; float scale_p = 1.f; float scale_o = 1.f; - if(squant) - { - scale_p = p_dtype_max / range_p; - } - - if constexpr(std::is_same_v) - { - q_buf.ToDevice(q_host.data()); - } - else - { - ck_tile::HostTensor q_host_device( - get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); - float max_value = ck_tile::type_convert(ck_tile::numeric::min()); - q_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - if(val > max_value) - max_value = val; - }); - - std::cout << "q max: " << max_value << std::endl; - float scale = q_dtype_max / max_value; - - q_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - q_host_device(idx) = ck_tile::type_convert(val * scale); - }); - - q_buf.ToDevice(q_host_device.data()); - scale_s = scale_s / scale; - } - - if constexpr(std::is_same_v) - { - k_buf.ToDevice(k_host.data()); - } - else - { - ck_tile::HostTensor k_host_device( - 0 < page_block_size - ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) - : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); - - float max_value = ck_tile::type_convert(ck_tile::numeric::min()); - k_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - if(val > max_value) - max_value = val; - }); - std::cout << "k max: " << max_value << std::endl; - float scale = k_dtype_max / max_value; - k_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - k_host_device(idx) = ck_tile::type_convert(val * scale); - }); - - k_buf.ToDevice(k_host_device.data()); - scale_s = scale_s / scale; - } - - if constexpr(std::is_same_v) - { - v_buf.ToDevice(v_host.data()); - } - else - { - ck_tile::HostTensor v_host_device( - 0 < page_block_size - ? (is_v_rowmajor - ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v) - : get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size)) - : (is_v_rowmajor - ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) - : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k))); - - float max_value = ck_tile::type_convert(ck_tile::numeric::min()); - v_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - if(val > max_value) - max_value = val; - }); - std::cout << "v max: " << max_value << std::endl; - - float scale = k_dtype_max / max_value; - v_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - v_host_device(idx) = ck_tile::type_convert(val * scale); - }); - - v_buf.ToDevice(v_host_device.data()); - scale_o = (range_p / p_dtype_max) / scale; - } - + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); knew_buf.ToDevice(knew_host.data()); vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); @@ -1302,7 +1209,6 @@ bool run(const ck_tile::ArgParser& arg_parser) float rp_undrop = 1.0 / p_undrop; bool pass = true; - float scale = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); for(ck_tile::index_t wb = 0; wb < batch; ++wb) { const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; @@ -1318,9 +1224,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ? 0 : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb])); - ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); - ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); - ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); + ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); + ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); + ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); ck_tile::HostTensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); @@ -1467,13 +1373,13 @@ bool run(const ck_tile::ArgParser& arg_parser) // reference ck_tile:: - reference_batched_gemm( + reference_batched_gemm( q_host_ref, k_host_ref, s_host_ref, ck_tile::identity{}, ck_tile::identity{}, - ck_tile::scales(scale)); + ck_tile::scales(scale_s)); // std::cout << "q_host_ref: " << std::endl; // show(std::cout, q_host_ref, nhead, real_seqlen_q, hdim_v); @@ -1617,7 +1523,7 @@ bool run(const ck_tile::ArgParser& arg_parser) p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); } - ck_tile::reference_batched_gemm( + ck_tile::reference_batched_gemm( p_host_ref, v_host_ref, o_host_ref, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index c1b09656b6..bd5e110214 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -49,9 +49,6 @@ struct FmhaFwdTypeConfig using QDataType = ck_tile::half_t; using KDataType = ck_tile::half_t; using VDataType = ck_tile::half_t; - using QHostDataType = QDataType; - using KHostDataType = KDataType; - using VHostDataType = VDataType; using BiasDataType = ck_tile::half_t; using RandValOutputDataType = uint8_t; using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) @@ -68,9 +65,6 @@ struct FmhaFwdTypeConfig using QDataType = ck_tile::bf16_t; using KDataType = ck_tile::bf16_t; using VDataType = ck_tile::bf16_t; - using QHostDataType = QDataType; - using KHostDataType = KDataType; - using VHostDataType = VDataType; using BiasDataType = ck_tile::bf16_t; using RandValOutputDataType = uint8_t; using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) @@ -87,9 +81,6 @@ struct FmhaFwdTypeConfig using QDataType = ck_tile::fp8_t; using KDataType = ck_tile::fp8_t; using VDataType = ck_tile::fp8_t; - using QHostDataType = ck_tile::bf16_t; - using KHostDataType = ck_tile::bf16_t; - using VHostDataType = ck_tile::bf16_t; using BiasDataType = float; using RandValOutputDataType = uint8_t; using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) @@ -106,9 +97,6 @@ struct FmhaFwdTypeConfig using QDataType = ck_tile::bf8_t; using KDataType = ck_tile::bf8_t; using VDataType = ck_tile::bf8_t; - using QHostDataType = ck_tile::bf16_t; - using KHostDataType = ck_tile::bf16_t; - using VHostDataType = ck_tile::bf16_t; using BiasDataType = ck_tile::bf8_t; using RandValOutputDataType = uint8_t; using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))