mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
add fp8bf16 data type
This commit is contained in:
@@ -550,7 +550,7 @@ class KernelComponentFactory:
|
||||
(192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
elif dtype == 'fp8' or dtype == 'fp8bf16':
|
||||
return {
|
||||
(64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
@@ -567,7 +567,7 @@ class KernelComponentFactory:
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
squant = 't' if dtype in ['fp8', 'fp8bf16'] else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
@@ -589,14 +589,14 @@ class KernelComponentFactory:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't'))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
elif dtype in ['fp8', 'fp8bf16']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
for logits, mask, bias in itertools.product(["f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', 'f', mask, 'f', 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', 'f', mask, 'f', 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
elif dtype in ['fp8fp16', 'bf8']:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
@@ -699,7 +699,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 888:
|
||||
cond = dtype in ['fp8']
|
||||
cond = dtype in ['fp8', 'fp8bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= hdim == 128
|
||||
if not cond:
|
||||
|
||||
@@ -158,6 +158,14 @@ auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
|
||||
|
||||
template <>
|
||||
auto get_elimit<FmhaFwdFp8>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 0;
|
||||
double atol = 16;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<FmhaFwdFp8Bf16>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1.8e-1;
|
||||
@@ -641,7 +649,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<int32_t> cache_batch_idx_host(use_cache_batch_idx
|
||||
? std::array<ck_tile::index_t, 1>{batch}
|
||||
: std::array<ck_tile::index_t, 1>{1});
|
||||
|
||||
float max_o = 3.0;
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
{
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
|
||||
@@ -796,6 +804,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
scale_p = p_dtype_max;
|
||||
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
|
||||
{
|
||||
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
|
||||
scale_o = scale_o * o_dtype_max / max_o;
|
||||
}
|
||||
}
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
@@ -1241,14 +1255,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
auto p_compute_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8> ||
|
||||
std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>)
|
||||
return ck_tile::scales{scale_p};
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8> ||
|
||||
std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>)
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o});
|
||||
@@ -1650,6 +1666,10 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return run<FmhaFwdFp8>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8bf16")
|
||||
{
|
||||
return run<FmhaFwdFp8Bf16>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
|
||||
@@ -89,7 +89,7 @@ struct FmhaFwdTypeConfig<FmhaFwdFp8>
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using ODataType = ck_tile::fp8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -108,6 +108,22 @@ struct FmhaFwdTypeConfig<FmhaFwdBf8>
|
||||
using ODataType = ck_tile::bf8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdFp8Bf16>
|
||||
{
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using BiasDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
struct FmhaMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
|
||||
Reference in New Issue
Block a user