fp8 to fp8 compare

This commit is contained in:
ltqin
2025-08-20 07:11:59 +00:00
parent 4b42551865
commit 8e37698fef
2 changed files with 51 additions and 157 deletions

View File

@@ -164,8 +164,8 @@ auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
// Specialization of get_elimit for the FmhaFwdFp8 configuration: builds the
// (rtol, atol) pair and returns it as a ck_tile tuple. The init_method
// argument is accepted but unused.
// NOTE(review): this span appears to come from a diff hunk rendered without
// +/- markers, so the two rtol/atol pairs below are presumably the pre-commit
// (2.5e-1) and post-commit (1.5e-1) versions of the same two lines rather than
// four declarations in one scope -- confirm against the real file.
template <>
auto get_elimit<FmhaFwdFp8>(std::string /*init_method*/)
{
unsigned rtol = 2.5e-1;
double atol = 2.5e-1;
// NOTE(review): rtol is declared unsigned, so its fractional initializer is
// truncated to 0 on conversion; only atol keeps the fractional value. Confirm
// whether a zero first tuple element is intended here.
unsigned rtol = 1.5e-1;
double atol = 1.5e-1;
return ck_tile::make_tuple(rtol, atol);
}
@@ -498,14 +498,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
using PDataType = typename TypeConfig::PDataType;
using OaccDataType = typename TypeConfig::OaccDataType;
using ODataType = typename TypeConfig::ODataType;
using QHostDataType = typename TypeConfig::QHostDataType;
using KHostDataType = typename TypeConfig::KHostDataType;
using VHostDataType = typename TypeConfig::VHostDataType;
// float range_q = arg_parser.get_float("range_q");
// float range_k = arg_parser.get_float("range_k");
// float range_v = arg_parser.get_float("range_v");
float range_p = arg_parser.get_float("range_p");
// float range_p = arg_parser.get_float("range_p");
// float range_o = arg_parser.get_float("range_o");
// accumulation numbers for performance evaluation
@@ -532,10 +529,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
num_byte += nhead * (sizeof(QHostDataType) * real_seqlen_q * hdim_q +
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
sizeof(ODataType) * real_seqlen_q * hdim_v);
num_byte += nhead_k * (sizeof(KHostDataType) * real_seqlen_k * hdim_q +
sizeof(VHostDataType) * hdim_v * real_seqlen_k);
num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q +
sizeof(VDataType) * hdim_v * real_seqlen_k);
}
}
@@ -586,25 +583,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
: seqstart_k_with_padding_host.back()));
ck_tile::HostTensor<QHostDataType> q_host(
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KHostDataType> k_host(
ck_tile::HostTensor<KDataType> k_host(
0 < page_block_size
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
ck_tile::HostTensor<KHostDataType> knew_host(
ck_tile::HostTensor<KDataType> knew_host(
0 < seqlen_knew
? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<VHostDataType> v_host(
ck_tile::HostTensor<VDataType> v_host(
0 < page_block_size
? (is_v_rowmajor
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v)
: get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
: (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
ck_tile::HostTensor<VHostDataType> vnew_host(
ck_tile::HostTensor<VDataType> vnew_host(
0 < seqlen_knew
? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
: get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew))
@@ -658,47 +655,47 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(init_method == "ui" || init_method == "0")
{
ck_tile::FillUniformDistributionIntegerValue<QHostDataType>{-3.f, 3.f, seed}(q_host);
ck_tile::FillUniformDistributionIntegerValue<KHostDataType>{-3.f, 3.f, seed}(k_host);
ck_tile::FillUniformDistributionIntegerValue<KHostDataType>{-3.f, 3.f, seed}(knew_host);
ck_tile::FillUniformDistributionIntegerValue<VHostDataType>{-3.f, 3.f, seed}(v_host);
ck_tile::FillUniformDistributionIntegerValue<VHostDataType>{-3.f, 3.f, seed}(vnew_host);
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
}
else if(init_method == "ni")
{
ck_tile::FillNormalDistributionIntegerValue<QHostDataType>{-3.f, 3.f, seed}(q_host);
ck_tile::FillNormalDistributionIntegerValue<KHostDataType>{-3.f, 3.f, seed}(k_host);
ck_tile::FillNormalDistributionIntegerValue<KHostDataType>{-3.f, 3.f, seed}(knew_host);
ck_tile::FillNormalDistributionIntegerValue<VHostDataType>{-3.f, 3.f, seed}(v_host);
ck_tile::FillNormalDistributionIntegerValue<VHostDataType>{-3.f, 3.f, seed}(vnew_host);
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
}
else if(init_method == "uf" || init_method == "1")
{
ck_tile::FillUniformDistribution<QHostDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillUniformDistribution<KHostDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillUniformDistribution<KHostDataType>{0.f, 1.f, seed}(knew_host);
ck_tile::FillUniformDistribution<VHostDataType>{0.f, 1.f, seed}(v_host);
ck_tile::FillUniformDistribution<VHostDataType>{0.f, 1.f, seed}(vnew_host);
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
}
else if(init_method == "nf")
{
ck_tile::FillNormalDistribution<QHostDataType>{0.f, 3.f, seed}(q_host);
ck_tile::FillNormalDistribution<KHostDataType>{0.f, 3.f, seed}(k_host);
ck_tile::FillNormalDistribution<KHostDataType>{0.f, 3.f, seed}(knew_host);
ck_tile::FillNormalDistribution<VHostDataType>{0.f, 3.f, seed}(v_host);
ck_tile::FillNormalDistribution<VHostDataType>{0.f, 3.f, seed}(vnew_host);
ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host);
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(k_host);
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(knew_host);
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host);
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(vnew_host);
ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
}
else if(init_method == "tf" || init_method == "2")
{
ck_tile::FillTrigValue<QHostDataType>{}(q_host);
ck_tile::FillTrigValue<KHostDataType>{}(k_host);
ck_tile::FillTrigValue<KHostDataType>{}(knew_host);
ck_tile::FillTrigValue<VHostDataType>{}(v_host);
ck_tile::FillTrigValue<VHostDataType>{}(vnew_host);
ck_tile::FillTrigValue<QDataType>{}(q_host);
ck_tile::FillTrigValue<KDataType>{}(k_host);
ck_tile::FillTrigValue<KDataType>{}(knew_host);
ck_tile::FillTrigValue<VDataType>{}(v_host);
ck_tile::FillTrigValue<VDataType>{}(vnew_host);
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
}
if(bias.type == bias_enum::alibi)
@@ -750,109 +747,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
float p_dtype_max = v_dtype_max; // assume p and v is the same type
// float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
// float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
// float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
// float p_dtype_max = v_dtype_max; // assume p and v is the same type
// // float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
std::cout << "q_dtype_max: " << q_dtype_max << " k_dtype_max: " << k_dtype_max
<< " v_dtype_max: " << v_dtype_max << std::endl;
// std::cout << "q_dtype_max: " << q_dtype_max << " k_dtype_max: " << k_dtype_max
// << " v_dtype_max: " << v_dtype_max << std::endl;
float scale_p = 1.f;
float scale_o = 1.f;
if(squant)
{
scale_p = p_dtype_max / range_p;
}
if constexpr(std::is_same_v<QDataType, QHostDataType>)
{
q_buf.ToDevice(q_host.data());
}
else
{
ck_tile::HostTensor<QDataType> q_host_device(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<QHostDataType>::min());
q_host.ForEach([&](auto& self, auto idx) {
float val = ck_tile::type_convert<float>(self(idx));
if(val > max_value)
max_value = val;
});
std::cout << "q max: " << max_value << std::endl;
float scale = q_dtype_max / max_value;
q_host.ForEach([&](auto& self, auto idx) {
float val = ck_tile::type_convert<float>(self(idx));
q_host_device(idx) = ck_tile::type_convert<QDataType>(val * scale);
});
q_buf.ToDevice(q_host_device.data());
scale_s = scale_s / scale;
}
if constexpr(std::is_same_v<KDataType, KHostDataType>)
{
k_buf.ToDevice(k_host.data());
}
else
{
ck_tile::HostTensor<KDataType> k_host_device(
0 < page_block_size
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<KHostDataType>::min());
k_host.ForEach([&](auto& self, auto idx) {
float val = ck_tile::type_convert<float>(self(idx));
if(val > max_value)
max_value = val;
});
std::cout << "k max: " << max_value << std::endl;
float scale = k_dtype_max / max_value;
k_host.ForEach([&](auto& self, auto idx) {
float val = ck_tile::type_convert<float>(self(idx));
k_host_device(idx) = ck_tile::type_convert<QDataType>(val * scale);
});
k_buf.ToDevice(k_host_device.data());
scale_s = scale_s / scale;
}
if constexpr(std::is_same_v<VDataType, VHostDataType>)
{
v_buf.ToDevice(v_host.data());
}
else
{
ck_tile::HostTensor<VDataType> v_host_device(
0 < page_block_size
? (is_v_rowmajor
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v)
: get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
: (is_v_rowmajor
? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<VHostDataType>::min());
v_host.ForEach([&](auto& self, auto idx) {
float val = ck_tile::type_convert<float>(self(idx));
if(val > max_value)
max_value = val;
});
std::cout << "v max: " << max_value << std::endl;
float scale = k_dtype_max / max_value;
v_host.ForEach([&](auto& self, auto idx) {
float val = ck_tile::type_convert<float>(self(idx));
v_host_device(idx) = ck_tile::type_convert<VDataType>(val * scale);
});
v_buf.ToDevice(v_host_device.data());
scale_o = (range_p / p_dtype_max) / scale;
}
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
v_buf.ToDevice(v_host.data());
knew_buf.ToDevice(knew_host.data());
vnew_buf.ToDevice(vnew_host.data());
bias_buf.ToDevice(bias_host.data());
@@ -1302,7 +1209,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
float rp_undrop = 1.0 / p_undrop;
bool pass = true;
float scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
@@ -1318,9 +1224,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
? 0
: (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb]));
ck_tile::HostTensor<QHostDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
ck_tile::HostTensor<KHostDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
ck_tile::HostTensor<VHostDataType> v_host_ref({nhead, hdim_v, real_seqlen_k});
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k});
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});
ck_tile::HostTensor<SMPLComputeDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k});
@@ -1467,13 +1373,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
// reference
ck_tile::
reference_batched_gemm<QHostDataType, KHostDataType, SaccDataType, SMPLComputeDataType>(
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
q_host_ref,
k_host_ref,
s_host_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales(scale));
ck_tile::scales(scale_s));
// std::cout << "q_host_ref: " << std::endl;
// show(std::cout, q_host_ref, nhead, real_seqlen_q, hdim_v);
@@ -1617,7 +1523,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
}
ck_tile::reference_batched_gemm<PDataType, VHostDataType, OaccDataType, ODataType>(
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
p_host_ref,
v_host_ref,
o_host_ref,

View File

@@ -49,9 +49,6 @@ struct FmhaFwdTypeConfig<FmhaFwdFp16>
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using QHostDataType = QDataType;
using KHostDataType = KDataType;
using VHostDataType = VDataType;
using BiasDataType = ck_tile::half_t;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
@@ -68,9 +65,6 @@ struct FmhaFwdTypeConfig<FmhaFwdBf16>
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using QHostDataType = QDataType;
using KHostDataType = KDataType;
using VHostDataType = VDataType;
using BiasDataType = ck_tile::bf16_t;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
@@ -87,9 +81,6 @@ struct FmhaFwdTypeConfig<FmhaFwdFp8>
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t;
using QHostDataType = ck_tile::bf16_t;
using KHostDataType = ck_tile::bf16_t;
using VHostDataType = ck_tile::bf16_t;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
@@ -106,9 +97,6 @@ struct FmhaFwdTypeConfig<FmhaFwdBf8>
using QDataType = ck_tile::bf8_t;
using KDataType = ck_tile::bf8_t;
using VDataType = ck_tile::bf8_t;
using QHostDataType = ck_tile::bf16_t;
using KHostDataType = ck_tile::bf16_t;
using VHostDataType = ck_tile::bf16_t;
using BiasDataType = ck_tile::bf8_t;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))