add fp8bf16 data type

This commit is contained in:
ltqin
2025-08-26 11:44:12 +00:00
parent aa46050f7b
commit a22d90f71f
3 changed files with 46 additions and 10 deletions

View File

@@ -550,7 +550,7 @@ class KernelComponentFactory:
(192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
}
elif dtype == 'fp8' or dtype == 'bf8':
elif dtype == 'fp8' or dtype == 'fp8bf16':
return {
(64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
@@ -567,7 +567,7 @@ class KernelComponentFactory:
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
squant = 't' if dtype in ['fp8', 'fp8bf16'] else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
@@ -589,14 +589,14 @@ class KernelComponentFactory:
pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't'))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
elif dtype in ['fp8', 'fp8bf16']:
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
for logits, mask, bias in itertools.product(["f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', 'f', mask, 'f', 'f'))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', 'f', mask, 'f', 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
elif dtype in ['fp8fp16', 'bf8']:
# TODO
None
else:
@@ -699,7 +699,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
if not cond:
continue
elif receipt == 888:
cond = dtype in ['fp8']
cond = dtype in ['fp8', 'fp8bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= hdim == 128
if not cond:

View File

@@ -158,6 +158,14 @@ auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
template <>
auto get_elimit<FmhaFwdFp8>(std::string /*init_method*/)
{
double rtol = 0;
double atol = 16;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<FmhaFwdFp8Bf16>(std::string /*init_method*/)
{
double rtol = 1e-2;
double atol = 1.8e-1;
@@ -641,7 +649,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<int32_t> cache_batch_idx_host(use_cache_batch_idx
? std::array<ck_tile::index_t, 1>{batch}
: std::array<ck_tile::index_t, 1>{1});
float max_o = 3.0;
if(init_method == "ui" || init_method == "0")
{
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
@@ -796,6 +804,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
scale_p = p_dtype_max;
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
{
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
scale_o = scale_o * o_dtype_max / max_o;
}
}
q_buf.ToDevice(q_host.data());
@@ -1241,14 +1255,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
randval_buf.FromDevice(randval_host.data());
auto p_compute_element_func = [&]() {
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8> ||
std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>)
return ck_tile::scales{scale_p};
else
return ck_tile::identity{};
}();
auto oacc_element_func = [&]() {
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8> ||
std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>)
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o});
@@ -1650,6 +1666,10 @@ int main(int argc, char* argv[])
{
return run<FmhaFwdFp8>(arg_parser) ? 0 : -2;
}
else if(data_type == "fp8bf16")
{
return run<FmhaFwdFp8Bf16>(arg_parser) ? 0 : -2;
}
return -3;
}

View File

@@ -89,7 +89,7 @@ struct FmhaFwdTypeConfig<FmhaFwdFp8>
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf16_t;
using ODataType = ck_tile::fp8_t;
};
template <>
@@ -108,6 +108,22 @@ struct FmhaFwdTypeConfig<FmhaFwdBf8>
using ODataType = ck_tile::bf8_t;
};
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp8Bf16>
{
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf16_t;
};
struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;