mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 20:27:42 +00:00
fp8 to fp8 compare
This commit is contained in:
@@ -164,8 +164,8 @@ auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
|
||||
template <>
|
||||
auto get_elimit<FmhaFwdFp8>(std::string /*init_method*/)
|
||||
{
|
||||
unsigned rtol = 2.5e-1;
|
||||
double atol = 2.5e-1;
|
||||
unsigned rtol = 1.5e-1;
|
||||
double atol = 1.5e-1;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
@@ -498,14 +498,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using PDataType = typename TypeConfig::PDataType;
|
||||
using OaccDataType = typename TypeConfig::OaccDataType;
|
||||
using ODataType = typename TypeConfig::ODataType;
|
||||
using QHostDataType = typename TypeConfig::QHostDataType;
|
||||
using KHostDataType = typename TypeConfig::KHostDataType;
|
||||
using VHostDataType = typename TypeConfig::VHostDataType;
|
||||
|
||||
// float range_q = arg_parser.get_float("range_q");
|
||||
// float range_k = arg_parser.get_float("range_k");
|
||||
// float range_v = arg_parser.get_float("range_v");
|
||||
float range_p = arg_parser.get_float("range_p");
|
||||
// float range_p = arg_parser.get_float("range_p");
|
||||
// float range_o = arg_parser.get_float("range_o");
|
||||
|
||||
// accumulation numbers for performance evaluation
|
||||
@@ -532,10 +529,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
|
||||
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
|
||||
|
||||
num_byte += nhead * (sizeof(QHostDataType) * real_seqlen_q * hdim_q +
|
||||
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
|
||||
sizeof(ODataType) * real_seqlen_q * hdim_v);
|
||||
num_byte += nhead_k * (sizeof(KHostDataType) * real_seqlen_k * hdim_q +
|
||||
sizeof(VHostDataType) * hdim_v * real_seqlen_k);
|
||||
num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q +
|
||||
sizeof(VDataType) * hdim_v * real_seqlen_k);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -586,25 +583,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
|
||||
: seqstart_k_with_padding_host.back()));
|
||||
|
||||
ck_tile::HostTensor<QHostDataType> q_host(
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<KHostDataType> k_host(
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
0 < page_block_size
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
|
||||
ck_tile::HostTensor<KHostDataType> knew_host(
|
||||
ck_tile::HostTensor<KDataType> knew_host(
|
||||
0 < seqlen_knew
|
||||
? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<VHostDataType> v_host(
|
||||
ck_tile::HostTensor<VDataType> v_host(
|
||||
0 < page_block_size
|
||||
? (is_v_rowmajor
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v)
|
||||
: get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
|
||||
: (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
|
||||
ck_tile::HostTensor<VHostDataType> vnew_host(
|
||||
ck_tile::HostTensor<VDataType> vnew_host(
|
||||
0 < seqlen_knew
|
||||
? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
|
||||
: get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew))
|
||||
@@ -658,47 +655,47 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
{
|
||||
ck_tile::FillUniformDistributionIntegerValue<QHostDataType>{-3.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KHostDataType>{-3.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KHostDataType>{-3.f, 3.f, seed}(knew_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VHostDataType>{-3.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VHostDataType>{-3.f, 3.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "ni")
|
||||
{
|
||||
ck_tile::FillNormalDistributionIntegerValue<QHostDataType>{-3.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<KHostDataType>{-3.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<KHostDataType>{-3.f, 3.f, seed}(knew_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<VHostDataType>{-3.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<VHostDataType>{-3.f, 3.f, seed}(vnew_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "uf" || init_method == "1")
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QHostDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KHostDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<KHostDataType>{0.f, 1.f, seed}(knew_host);
|
||||
ck_tile::FillUniformDistribution<VHostDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<VHostDataType>{0.f, 1.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "nf")
|
||||
{
|
||||
ck_tile::FillNormalDistribution<QHostDataType>{0.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistribution<KHostDataType>{0.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistribution<KHostDataType>{0.f, 3.f, seed}(knew_host);
|
||||
ck_tile::FillNormalDistribution<VHostDataType>{0.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillNormalDistribution<VHostDataType>{0.f, 3.f, seed}(vnew_host);
|
||||
ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(knew_host);
|
||||
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(vnew_host);
|
||||
ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "tf" || init_method == "2")
|
||||
{
|
||||
ck_tile::FillTrigValue<QHostDataType>{}(q_host);
|
||||
ck_tile::FillTrigValue<KHostDataType>{}(k_host);
|
||||
ck_tile::FillTrigValue<KHostDataType>{}(knew_host);
|
||||
ck_tile::FillTrigValue<VHostDataType>{}(v_host);
|
||||
ck_tile::FillTrigValue<VHostDataType>{}(vnew_host);
|
||||
ck_tile::FillTrigValue<QDataType>{}(q_host);
|
||||
ck_tile::FillTrigValue<KDataType>{}(k_host);
|
||||
ck_tile::FillTrigValue<KDataType>{}(knew_host);
|
||||
ck_tile::FillTrigValue<VDataType>{}(v_host);
|
||||
ck_tile::FillTrigValue<VDataType>{}(vnew_host);
|
||||
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
}
|
||||
if(bias.type == bias_enum::alibi)
|
||||
@@ -750,109 +747,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
|
||||
|
||||
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
||||
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
||||
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
|
||||
float p_dtype_max = v_dtype_max; // assume p and v is the same type
|
||||
// float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
||||
// float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
||||
// float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
|
||||
// float p_dtype_max = v_dtype_max; // assume p and v is the same type
|
||||
// // float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
|
||||
std::cout << "q_dtype_max: " << q_dtype_max << " k_dtype_max: " << k_dtype_max
|
||||
<< " v_dtype_max: " << v_dtype_max << std::endl;
|
||||
// std::cout << "q_dtype_max: " << q_dtype_max << " k_dtype_max: " << k_dtype_max
|
||||
// << " v_dtype_max: " << v_dtype_max << std::endl;
|
||||
float scale_p = 1.f;
|
||||
float scale_o = 1.f;
|
||||
|
||||
if(squant)
|
||||
{
|
||||
scale_p = p_dtype_max / range_p;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<QDataType, QHostDataType>)
|
||||
{
|
||||
q_buf.ToDevice(q_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::HostTensor<QDataType> q_host_device(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<QHostDataType>::min());
|
||||
q_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
if(val > max_value)
|
||||
max_value = val;
|
||||
});
|
||||
|
||||
std::cout << "q max: " << max_value << std::endl;
|
||||
float scale = q_dtype_max / max_value;
|
||||
|
||||
q_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
q_host_device(idx) = ck_tile::type_convert<QDataType>(val * scale);
|
||||
});
|
||||
|
||||
q_buf.ToDevice(q_host_device.data());
|
||||
scale_s = scale_s / scale;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<KDataType, KHostDataType>)
|
||||
{
|
||||
k_buf.ToDevice(k_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::HostTensor<KDataType> k_host_device(
|
||||
0 < page_block_size
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
|
||||
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<KHostDataType>::min());
|
||||
k_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
if(val > max_value)
|
||||
max_value = val;
|
||||
});
|
||||
std::cout << "k max: " << max_value << std::endl;
|
||||
float scale = k_dtype_max / max_value;
|
||||
k_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
k_host_device(idx) = ck_tile::type_convert<QDataType>(val * scale);
|
||||
});
|
||||
|
||||
k_buf.ToDevice(k_host_device.data());
|
||||
scale_s = scale_s / scale;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<VDataType, VHostDataType>)
|
||||
{
|
||||
v_buf.ToDevice(v_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::HostTensor<VDataType> v_host_device(
|
||||
0 < page_block_size
|
||||
? (is_v_rowmajor
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v)
|
||||
: get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
|
||||
: (is_v_rowmajor
|
||||
? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
|
||||
|
||||
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<VHostDataType>::min());
|
||||
v_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
if(val > max_value)
|
||||
max_value = val;
|
||||
});
|
||||
std::cout << "v max: " << max_value << std::endl;
|
||||
|
||||
float scale = k_dtype_max / max_value;
|
||||
v_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
v_host_device(idx) = ck_tile::type_convert<VDataType>(val * scale);
|
||||
});
|
||||
|
||||
v_buf.ToDevice(v_host_device.data());
|
||||
scale_o = (range_p / p_dtype_max) / scale;
|
||||
}
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
v_buf.ToDevice(v_host.data());
|
||||
knew_buf.ToDevice(knew_host.data());
|
||||
vnew_buf.ToDevice(vnew_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
@@ -1302,7 +1209,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
float rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
bool pass = true;
|
||||
float scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
@@ -1318,9 +1224,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
? 0
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb]));
|
||||
|
||||
ck_tile::HostTensor<QHostDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
|
||||
ck_tile::HostTensor<KHostDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
|
||||
ck_tile::HostTensor<VHostDataType> v_host_ref({nhead, hdim_v, real_seqlen_k});
|
||||
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
|
||||
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
|
||||
ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k});
|
||||
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});
|
||||
|
||||
ck_tile::HostTensor<SMPLComputeDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
||||
@@ -1467,13 +1373,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
// reference
|
||||
ck_tile::
|
||||
reference_batched_gemm<QHostDataType, KHostDataType, SaccDataType, SMPLComputeDataType>(
|
||||
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
|
||||
q_host_ref,
|
||||
k_host_ref,
|
||||
s_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale));
|
||||
ck_tile::scales(scale_s));
|
||||
// std::cout << "q_host_ref: " << std::endl;
|
||||
// show(std::cout, q_host_ref, nhead, real_seqlen_q, hdim_v);
|
||||
|
||||
@@ -1617,7 +1523,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
}
|
||||
|
||||
ck_tile::reference_batched_gemm<PDataType, VHostDataType, OaccDataType, ODataType>(
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
v_host_ref,
|
||||
o_host_ref,
|
||||
|
||||
@@ -49,9 +49,6 @@ struct FmhaFwdTypeConfig<FmhaFwdFp16>
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using QHostDataType = QDataType;
|
||||
using KHostDataType = KDataType;
|
||||
using VHostDataType = VDataType;
|
||||
using BiasDataType = ck_tile::half_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
@@ -68,9 +65,6 @@ struct FmhaFwdTypeConfig<FmhaFwdBf16>
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using QHostDataType = QDataType;
|
||||
using KHostDataType = KDataType;
|
||||
using VHostDataType = VDataType;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
@@ -87,9 +81,6 @@ struct FmhaFwdTypeConfig<FmhaFwdFp8>
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using QHostDataType = ck_tile::bf16_t;
|
||||
using KHostDataType = ck_tile::bf16_t;
|
||||
using VHostDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
@@ -106,9 +97,6 @@ struct FmhaFwdTypeConfig<FmhaFwdBf8>
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using KDataType = ck_tile::bf8_t;
|
||||
using VDataType = ck_tile::bf8_t;
|
||||
using QHostDataType = ck_tile::bf16_t;
|
||||
using KHostDataType = ck_tile::bf16_t;
|
||||
using VHostDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = ck_tile::bf8_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
|
||||
Reference in New Issue
Block a user