Merge commit '2aec38f9ec67bfbdccbdb3a5c25913e5a9ba6136' into develop
@@ -131,4 +131,4 @@ TBD

## FP8 experimental support

As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have experimental support for fp8 fmha kernels. You can evaluate the performance by passing `-prec=fp8` to `tile_example_fmha_fwd` on a gfx942 machine with ROCm 6.0+.

-Currently we only support `-vlayout=c` (`hdim*seqlen` for the V matrix) and `-squant=1` (static quantization) with `hdim=128` for fp8. Full feature support will come later.
+Currently we only support `-vlayout=r` (`seqlen*hdim` for the V matrix) for fp8 and fp8bf16. Full feature support will come later.
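For reference, a minimal invocation mirroring the smoke tests added in this change might look like the following (the binary path and argument values are illustrative, not part of the commit):

```bash
# fp8 forward pass: row-major V, static quantization, hdim=128
./tile_example_fmha_fwd -prec=fp8 -b=1 -h=1 -d=128 -s=128 -vlayout=r -squant=1 -init=0

# same shape, but with bf16 output (fp8bf16) or fp32 output (fp8fp32)
./tile_example_fmha_fwd -prec=fp8bf16 -b=1 -h=1 -d=128 -s=128 -vlayout=r -squant=1 -init=0
./tile_example_fmha_fwd -prec=fp8fp32 -b=1 -h=1 -d=128 -s=128 -vlayout=r -squant=1 -init=0
```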
@@ -7,7 +7,8 @@ FWD_DTYPE_MAP = {
    "bf16"   : "FmhaFwdBf16",
    "fp8"    : "FmhaFwdFp8",
    "fp8fp16": "FmhaFwdFp8Fp16",
-   "fp8bf16": "FmhaFwdFp8Bf16"
+   "fp8bf16": "FmhaFwdFp8Bf16",
+   "fp8fp32": "FmhaFwdFp8Fp32"
}

BWD_DTYPE_MAP = {
@@ -163,7 +163,7 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config&
    [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
        return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
    }};

    const bool has_load_tr = ck_tile::is_load_tr_supported();

    {F_dispatch}
@@ -248,11 +248,11 @@ class FmhaFwdApiTrait:
        if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
        else                : return f'a.seqlen_q % {self.bm0} == 0'
    else: assert False

    @property
    def seqtune(self) -> str:
        if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generates spad/skpad == true
        else:
            return f'a.seqlen_q <= {self.bm0}'

    @property
@@ -351,7 +351,7 @@ class FmhaFwdPipeline:

        if self.F_squant == 't' : n += '_squant'
        else                    : n += '_nsquant'

        if self.F_trload == 't' : n += '_trload'
        else                    : n += '_ntrload'
@@ -378,7 +378,7 @@ class FmhaFwdApiPool:
        "t": "has_load_tr",
        "f": "true"
    }

    per_tr_load = str()
    for tr_load in ["t", "f"]:
        per_dtypes = str()
@@ -550,12 +550,16 @@ class KernelComponentFactory:
            (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
            (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
        }
-   elif dtype == 'fp8' or dtype == 'bf8':
+   elif dtype == 'fp8' or dtype == 'fp8bf16':
        return {
            (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
            (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
            (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
        }
+   elif dtype == 'fp8fp32':
+       return {
+           (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
+       }
    else:
        return None
@@ -567,9 +571,9 @@ class KernelComponentFactory:
        # TODO: the order of this list matters! entries later in the list will also be checked later
        # TODO: currently for the qr pipeline, let 't' padding appear later!!
        # TODO: how to design this more generically?
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            squant = 'f'
            for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
                if hdim == 256 and hdim_v == 256:
                    pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
@@ -589,11 +593,12 @@ class KernelComponentFactory:
                pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't'))
                if receipt == 1 and bias != "bias":
                    pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitrary hdim
-       elif dtype in ['fp8', 'bf8']:
+       elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']:
            # no lse/dropout kernels needed
            for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
                pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
        elif dtype in ['fp8fp16', 'fp8bf16']:
            for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
                pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
                pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
        elif dtype in ['fp8fp16', 'bf8']:
            # TODO
            None
        else:
@@ -674,25 +679,34 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
                continue
        # Aiter(mha_fwd) integration
        elif receipt == 100:
-           cond = dtype in ['fp16', 'bf16']
+           cond = dtype in ['fp16', 'bf16', 'fp8bf16']
            cond &= mode == 'batch'
            cond &= pipeline.F_vlayout == 'row'
            cond &= pipeline.F_squant == 'f'
+           if dtype == 'fp8bf16':
+               cond &= hdim == 128
            if not cond:
                continue
        # Aiter(mha_varlen_fwd) integration
        elif receipt == 200:
-           cond = dtype in ['fp16', 'bf16']
+           cond = dtype in ['fp16', 'bf16', 'fp8bf16']
            cond &= mode == 'group'
            cond &= pipeline.F_vlayout == 'row'
            cond &= pipeline.F_squant == 'f'
+           if dtype == 'fp8bf16':
+               cond &= hdim == 128
            if not cond:
                continue
        # aiter::mha_fwd C++ api integration
        elif receipt == 600:
-           cond = dtype in ['fp16', 'bf16']
+           cond = dtype in ['fp16', 'bf16', 'fp8bf16']
            cond &= pipeline.F_vlayout == 'row'
            cond &= pipeline.F_squant == 'f'
+           if dtype == 'fp8bf16':
+               cond &= hdim == 128
            if not cond:
                continue
+       elif receipt == 888:
+           cond = dtype in ['fp8', 'fp8bf16', 'fp8fp32']
+           cond &= pipeline.F_vlayout == 'row'
+           cond &= hdim == 128
+           if not cond:
+               continue
@@ -645,7 +645,6 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
        return {
            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
        }
    else:
        return None
@@ -465,14 +465,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
                pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
                pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
            for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]):
                pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
                pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
        elif dtype in ['fp8', 'bf8']:
-           # TODO
-           None
+           # no lse/dropout kernels needed
+           for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
+               pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
+               pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
        elif dtype in ['fp8fp16', 'fp8bf16']:
            # TODO
            None
@@ -44,21 +44,15 @@ auto create_args(int argc, char* argv[])
    .insert("scale_s",
            "0",
            "scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
-           "note when squant=1, this value will be modified by range_q/k")
+           "note when squant=1, this value will be modified")
    .insert("logits_soft_cap", "0", "attention logits soft capping value.")
-   .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.")
-   .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.")
-   .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.")
-   .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.")
-   .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.")
    .insert("squant",
            "auto",
            "if using static quantization fusion or not. auto: fp8 will default use squant, "
            "other will not\n"
            "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to "
            "P and O.\n"
-           "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, "
-           "range_p, range_o")
+           "calculate scale_s, scale_p, scale_o automatically")
    .insert("iperm",
            "1",
            "permute input\n"
@@ -89,7 +83,7 @@ auto create_args(int argc, char* argv[])
            "uf",
            "init method:\n ui or 0 - uniform random int\n ni - normalized random int"
            "\n uf or 1 - uniform random float\n nf - normalized random float"
-           "\n tf or 2 - trig float\n uf:q or ufq or 3 - fp8 quantization")
+           "\n tf or 2 - trig float\n")
    .insert("seed",
            "11939",
            "random seed used for initializing input tensors. 0 for "
@@ -148,11 +142,6 @@ auto run(const ck_tile::ArgParser& arg_parser)
    uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
    bool drop_prefs      = arg_parser.get_bool("drop_prefs");
    std::string mask_str = arg_parser.get_str("mask");
-   float range_q = arg_parser.get_float("range_q");
-   float range_k = arg_parser.get_float("range_k");
-   float range_v = arg_parser.get_float("range_v");
-   float range_p = arg_parser.get_float("range_p");
-   float range_o = arg_parser.get_float("range_o");
    bool is_rotary_interleaved  = arg_parser.get_bool("rotary_interleaved");
    ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
    std::string init_method     = arg_parser.get_str("init");
@@ -201,11 +190,6 @@ auto run(const ck_tile::ArgParser& arg_parser)
                  drop_offset,
                  drop_prefs,
                  mask_str,
-                 range_q,
-                 range_k,
-                 range_v,
-                 range_p,
-                 range_o,
                  squant,
                  is_rotary_interleaved,
                  num_splits,
@@ -237,6 +221,14 @@ int main(int argc, char* argv[])
    {
        return run<FmhaFwdFp8>(arg_parser) == fwd_result::success ? 0 : -2;
    }
+   else if(data_type == "fp8bf16")
+   {
+       return run<FmhaFwdFp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
+   }
+   else if(data_type == "fp8fp32")
+   {
+       return run<FmhaFwdFp8Fp32>(arg_parser) == fwd_result::success ? 0 : -2;
+   }
    std::cerr << "Unsupported precision: " << data_type << std::endl;
    return -1;
}
@@ -41,6 +41,10 @@ struct FmhaFwdFp8Bf16
{
};

+struct FmhaFwdFp8Fp32
+{
+};
+
template <typename DataType>
struct FmhaFwdTypeConfig;
@@ -108,6 +112,38 @@ struct FmhaFwdTypeConfig<FmhaFwdBf8>
    using ODataType = ck_tile::bf8_t;
};

+template <>
+struct FmhaFwdTypeConfig<FmhaFwdFp8Bf16>
+{
+    using QDataType             = ck_tile::fp8_t;
+    using KDataType             = ck_tile::fp8_t;
+    using VDataType             = ck_tile::fp8_t;
+    using BiasDataType          = float;
+    using RandValOutputDataType = uint8_t;
+    using LSEDataType           = float; // data type for lse (logsumexp L_j = max_j + log(l_j))
+    using SaccDataType          = float; // data type for first gemm accumulation
+    using SMPLComputeDataType   = float; // data type for reduction, softmax
+    using PDataType             = ck_tile::fp8_t; // data type for A matrix of second gemm
+    using OaccDataType          = float; // data type for second gemm accumulation
+    using ODataType             = ck_tile::bf16_t;
+};
+
+template <>
+struct FmhaFwdTypeConfig<FmhaFwdFp8Fp32>
+{
+    using QDataType             = ck_tile::fp8_t;
+    using KDataType             = ck_tile::fp8_t;
+    using VDataType             = ck_tile::fp8_t;
+    using BiasDataType          = float;
+    using RandValOutputDataType = uint8_t;
+    using LSEDataType           = float; // data type for lse (logsumexp L_j = max_j + log(l_j))
+    using SaccDataType          = float; // data type for first gemm accumulation
+    using SMPLComputeDataType   = float; // data type for reduction, softmax
+    using PDataType             = ck_tile::fp8_t; // data type for A matrix of second gemm
+    using OaccDataType          = float; // data type for second gemm accumulation
+    using ODataType             = float;
+};
+
struct FmhaMasks
{
    using NoMask = ck_tile::GenericAttentionMask<false>;
@@ -50,20 +50,30 @@ auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
}

template <>
-auto get_elimit<FmhaFwdFp8>(std::string init_method)
+auto get_elimit<FmhaFwdFp8>(std::string /*init_method*/)
{
-   if(init_method == "ui" || init_method == "ni")
-   {
-       unsigned max_rounding_point_distance = 0;
-       double atol                          = 2e-3;
-       return ck_tile::make_tuple(max_rounding_point_distance, atol);
-   }
-   else
-   {
-       unsigned max_rounding_point_distance = 1;
-       double atol                          = 0.0625;
-       return ck_tile::make_tuple(max_rounding_point_distance, atol);
-   }
+   using TypeConfig = FmhaFwdTypeConfig<FmhaFwdFp8>;
+   using ODataType  = typename TypeConfig::ODataType;
+   float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
+   double rtol = 0;
+   double atol = 16 * (o_dtype_max > 240 ? 2 : 1);
+   return ck_tile::make_tuple(rtol, atol);
}

+template <>
+auto get_elimit<FmhaFwdFp8Bf16>(std::string /*init_method*/)
+{
+   double rtol = 1e-2;
+   double atol = 1.8e-1;
+   return ck_tile::make_tuple(rtol, atol);
+}
+
+template <>
+auto get_elimit<FmhaFwdFp8Fp32>(std::string /*init_method*/)
+{
+   double rtol = 1e-2;
+   double atol = 1.8e-1;
+   return ck_tile::make_tuple(rtol, atol);
+}

int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits)
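On the new fp8 tolerance above: 240 is the maximum of the fnuz fp8 format used on gfx94x, while the OCP e4m3 maximum is 448, so the `o_dtype_max > 240` check doubles `atol` for the wider-range format (an inference from the constant; the commit does not spell this out).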
@@ -157,11 +167,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
                        uint64_t drop_offset,
                        bool drop_prefs,
                        std::string mask_str,
-                       float range_q,
-                       float range_k,
-                       float range_v,
-                       float range_p,
-                       float range_o,
                        bool squant,
                        bool is_rotary_interleaved,
                        ck_tile::index_t num_splits,
@@ -180,6 +185,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
        return "fp8";
    else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdBf8>)
        return "bf8";
+   else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>)
+       return "fp8bf16";
+   else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Fp32>)
+       return "fp8fp32";
    else
        static_assert(false);
}();
@@ -367,22 +376,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
    using OaccDataType = typename TypeConfig::OaccDataType;
    using ODataType    = typename TypeConfig::ODataType;

-   float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
-   float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
-   float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
-   float p_dtype_max = v_dtype_max; // assume p and v is the same type
-   float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
-
-   float scale_p = 1.f;
-   float scale_o = 1.f;
-
-   if(squant)
-   {
-       scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max);
-       scale_p = p_dtype_max / range_p;
-       scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max);
-   }

    // accumulation numbers for performance evaluation
    std::size_t flop = 0, num_byte = 0;
    auto max_seqlen_q =
@@ -528,7 +521,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
    ck_tile::HostTensor<int32_t> cache_batch_idx_host(use_cache_batch_idx
                                                          ? std::array<ck_tile::index_t, 1>{batch}
                                                          : std::array<ck_tile::index_t, 1>{1});

+   float max_o = 5.0;
    if(init_method == "ui" || init_method == "0")
    {
        ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
@@ -576,32 +569,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
        ck_tile::FillTrigValue<VDataType>{}(vnew_host);
        ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
    }
-   else if(init_method == "ufq" || init_method == "uf:q" || init_method == "3")
-   {
-       // suitable for fp8 quantization
-       if(!squant)
-       {
-           std::cerr << "init method " << init_method << " cannot be used without quantization"
-                     << std::endl;
-           return fwd_result::invalid_args;
-       }
-       ck_tile::FillUniformDistribution<QDataType>{0.f, q_dtype_max, next_seed()}(q_host);
-       ck_tile::FillUniformDistribution<KDataType>{0.f, k_dtype_max, next_seed()}(k_host);
-       ck_tile::FillUniformDistribution<KDataType>{0.f, k_dtype_max, next_seed()}(knew_host);
-       ck_tile::FillUniformDistribution<VDataType>{0.f, v_dtype_max, next_seed()}(v_host);
-       ck_tile::FillUniformDistribution<VDataType>{0.f, v_dtype_max, next_seed()}(vnew_host);
-
-       // bias_fp8 = qscale_bias * bias_fp32
-       float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k);
-       // Assume bias is in [0.f, 1.f] in original fp32
-       ck_tile::FillUniformDistribution<BiasDataType>{0.f, qscale_bias, next_seed()}(bias_host);
-   }
    else
    {
        std::cerr << "Unknown value for init argument: " << init_method << std::endl;
        return fwd_result::invalid_args;
    }

    if(bias.type == bias_enum::alibi)
    {
        auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
@@ -625,8 +592,8 @@ fwd_result fmha_fwd_run(mode_enum mode,

    ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
-   ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
+   ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
@@ -650,10 +617,79 @@ fwd_result fmha_fwd_run(mode_enum mode,
    ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());

+   float scale_p = 1.f;
+   float scale_o = 1.f;
+   if(squant)
+   {
+       float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
+       float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
+       float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
+       float p_dtype_max = v_dtype_max; // assume p and v are the same type
+       // Q tensor
+       {
+           float max_value = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::min());
+           q_host.ForEach([&](auto& self, auto idx) {
+               float val = ck_tile::type_convert<float>(self(idx));
+               if(val > max_value)
+                   max_value = val;
+           });
+
+           float scale = q_dtype_max / max_value;
+
+           q_host.ForEach([&](auto& self, auto idx) {
+               float val = ck_tile::type_convert<float>(self(idx));
+               self(idx) = ck_tile::type_convert<QDataType>(val * scale);
+           });
+           scale_s = scale_s / scale;
+       }
+
+       // K tensor
+       {
+           float max_value = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::min());
+           k_host.ForEach([&](auto& self, auto idx) {
+               float val = ck_tile::type_convert<float>(self(idx));
+               if(val > max_value)
+                   max_value = val;
+           });
+           float scale = k_dtype_max / max_value;
+           k_host.ForEach([&](auto& self, auto idx) {
+               float val = ck_tile::type_convert<float>(self(idx));
+               self(idx) = ck_tile::type_convert<KDataType>(val * scale);
+           });
+           scale_s = scale_s / scale;
+       }
+
+       // V tensor
+       {
+           float max_value = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::min());
+           v_host.ForEach([&](auto& self, auto idx) {
+               float val = ck_tile::type_convert<float>(self(idx));
+               if(val > max_value)
+                   max_value = val;
+           });
+
+           float scale = v_dtype_max / max_value;
+           v_host.ForEach([&](auto& self, auto idx) {
+               float val = ck_tile::type_convert<float>(self(idx));
+               self(idx) = ck_tile::type_convert<VDataType>(val * scale);
+           });
+
+           scale_o = (1.0 / p_dtype_max) / scale;
+       }
+
+       scale_p = p_dtype_max;
+
+       if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
+       {
+           float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
+           scale_o = scale_o * o_dtype_max / max_o;
+       }
+   }

    q_buf.ToDevice(q_host.data());
    k_buf.ToDevice(k_host.data());
-   knew_buf.ToDevice(knew_host.data());
    v_buf.ToDevice(v_host.data());
+   knew_buf.ToDevice(knew_host.data());
    vnew_buf.ToDevice(vnew_host.data());
    bias_buf.ToDevice(bias_host.data());
    seqstart_q.ToDevice(seqstart_q_host.data());
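A note on the rescaling above: each of Q, K and V is stretched so its observed maximum maps onto the fp8 maximum, and the inverse factors are folded back into `scale_s` (Q and K both feed the first gemm S = Q*K^T) and `scale_o` (for V), while `scale_p = p_dtype_max` is cancelled by the `1.0 / p_dtype_max` factor inside `scale_o`. The attention output is therefore numerically unchanged up to fp8 rounding.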
@@ -1103,7 +1139,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
    lse_buf.FromDevice(lse_host.data());
    randval_buf.FromDevice(randval_host.data());

-   constexpr bool supports_squant = std::is_same_v<DataTypeConfig, FmhaFwdFp8>;
+   constexpr bool supports_squant = std::is_same_v<DataTypeConfig, FmhaFwdFp8> ||
+                                    std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16> ||
+                                    std::is_same_v<DataTypeConfig, FmhaFwdFp8Fp32>;

    auto p_compute_element_func = [&]() {
        if constexpr(supports_squant)
@@ -1113,9 +1151,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
    }();

    auto oacc_element_func = [&]() {
-       if constexpr(supports_squant)
+       if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t> && supports_squant)
            return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
                                     ck_tile::scales{scale_o});
+       else if constexpr(supports_squant)
+           return ck_tile::scales{scale_o};
        else
            return ck_tile::identity{};
    }();
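In the element function above (and in its kernel-side counterpart further down), saturation is only composed in when the output type is itself fp8; for the wider bf16/fp32 outputs the plain `scale_o` multiply suffices, since those types can represent the scaled range.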
@@ -94,7 +94,30 @@ run_fp8_tests() {
    for b in 1 2 ; do
    for hdim in 64 128 256 ; do

-   run_exe -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
+   $EXE -prec=fp8 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS

    done ; done ; done ; done
}

+run_fp8bf16_tests() {
+    for perm in 0 1 ; do
+    for bias in "n" "e" "a" ; do
+    for b in 1 2 ; do
+    for hdim in 64 128 256 ; do
+
+    $EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
+
+    done ; done ; done ; done
+}
+
+run_fp8fp32_tests() {
+    for perm in 0 1 ; do
+    for bias in "n" "e" "a" ; do
+    for b in 1 2 ; do
+    for hdim in 64 128 256 ; do
+
+    $EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
+
+    done ; done ; done ; done
+}

@@ -117,7 +140,9 @@ run_fp16_appendkv_tests() {
set -x

run_fp16_bf16_tests
-# run_fp8_tests
+run_fp8_tests
+run_fp8bf16_tests
+run_fp8fp32_tests

if [ $TEST_APPENDKV -eq 1 ] ; then
    run_fp16_appendkv_tests
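To reproduce a single case from the new test functions by hand, expand the script variables into a direct invocation (the `$KNAME` value and the extra `$COMMON_ARGS` flags are script-defined; the literal values below are illustrative guesses):

```bash
# one fp8bf16 smoke-test case from run_fp8bf16_tests, expanded manually
./tile_example_fmha_fwd -prec=fp8bf16 -init=0 -b=2 -h=1 -d=128 -s=128 \
    -bias=n -iperm=1 -operm=1 -vlayout=r -squant=1 -kname=1
```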
@@ -1446,29 +1446,35 @@ struct FmhaFwdKernel
    auto o_acc_tile = [&]() {
        if constexpr(kDoFp8StaticQuant)
        {
-           return FmhaPipeline{}(
-               q_dram_window,
-               identity{}, // q_element_func
-               k_dram_window,
-               identity{}, // k_element_func
-               v_dram_window,
-               identity{}, // v_element_func
-               bias_dram_window,
-               identity{}, // bias_element_func
-               randval_dram_window,
-               lse_dram_window,
-               identity{}, // lse_element_func
-               identity{}, // s_acc_element_func
-               scales{kargs.scale_p}, // p_compute_element_func
-               composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
-               mask,
-               position_encoding,
-               kargs.scale_s,
-               variant,
-               variant_params,
-               block_indices,
-               smem_ptr,
-               dropout);
+           auto o_acc_element_func = [&]() {
+               if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
+                   return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
+                                            ck_tile::scales{kargs.scale_o});
+               else
+                   return ck_tile::scales{kargs.scale_o};
+           }();
+           return FmhaPipeline{}(q_dram_window,
+                                 identity{}, // q_element_func
+                                 k_dram_window,
+                                 identity{}, // k_element_func
+                                 v_dram_window,
+                                 identity{}, // v_element_func
+                                 bias_dram_window,
+                                 identity{}, // bias_element_func
+                                 randval_dram_window,
+                                 lse_dram_window,
+                                 identity{}, // lse_element_func
+                                 identity{}, // s_acc_element_func
+                                 scales{kargs.scale_p}, // p_compute_element_func
+                                 o_acc_element_func, // o_acc_element_func
+                                 mask,
+                                 position_encoding,
+                                 kargs.scale_s,
+                                 variant,
+                                 variant_params,
+                                 block_indices,
+                                 smem_ptr,
+                                 dropout);
        }
        else
        {
@@ -559,6 +559,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
                auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
                    Policy::template MakeShuffledBiasTileDistribution<Problem>());
                shuffle_tile(shuffled_bias_tile, bias_tile);
+               // SGrad and Bias use the same address in LDS; finish loading ds on the previous
+               // iteration to reuse the LDS.
+               block_sync_lds();
                store_tile(bias_lds_write_window, shuffled_bias_tile);
                block_sync_lds();
                auto bias_s_tile = load_tile(bias_s_lds_read_window);
@@ -814,6 +817,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
                auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
                    Policy::template MakeShuffledBiasTileDistribution<Problem>());
                shuffle_tile(shuffled_bias_tile, bias_tile);
+               // SGrad and Bias use the same address in LDS; finish loading ds in the hot loop to
+               // reuse the LDS.
+               block_sync_lds();
                store_tile(bias_lds_write_window, shuffled_bias_tile);
                block_sync_lds();
                auto bias_s_tile = load_tile(bias_s_lds_read_window);
@@ -956,6 +962,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
                    return cast_tile<BiasGradDataType>(ds);
                }
            }();
+           // Finish loading bias_s to reuse the LDS.
+           block_sync_lds();
            store_tile(bias_lds_write_window, dbias);
            block_sync_lds();
            auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
@@ -975,11 +983,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP

        gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);

-       if constexpr(kHasBiasGrad)
-       {
-           // SGrad and BiasGrad use the same address in LDS.
-           block_sync_lds();
-       }
+       // SGrad and Bias/BiasGrad use the same address in LDS; finish loading bias/dbias or, when
+       // bias is not used, loading ds in the hot loop to reuse the LDS.
+       block_sync_lds();
        store_tile(ds_lds_window, ds_gemm);

        block_sync_lds();
@@ -698,6 +698,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
            dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
            gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);

+           if constexpr(kHasBiasGrad)
+           {
+               // SGrad and BiasGrad use the same address in LDS; finish loading dbias to reuse
+               // the LDS.
+               block_sync_lds();
+           }
            store_tile(ds_lds_window, ds_gemm);
        }
        s_waitcnt</*vmcnt=*/0>();
@@ -656,6 +656,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
            dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
            dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor);

+           if constexpr(kHasBiasGrad)
+           {
+               // SGrad and BiasGrad use the same address in LDS; finish loading dbias to reuse
+               // the LDS.
+               block_sync_lds();
+           }
            store_tile(ds_lds_window, ds_gemm);
        }
        __builtin_amdgcn_s_waitcnt(3952);
@@ -1941,7 +1941,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy

    constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
    constexpr index_t smem_size_stage0_1 = smem_size_v;
-   constexpr index_t smem_size_stage1   = smem_size_qt + smem_size_q + +smem_size_dot +
+   constexpr index_t smem_size_stage1   = smem_size_qt + smem_size_q + smem_size_dot +
                                           smem_size_do + smem_size_lse + smem_size_d +
                                           max(smem_size_bias, smem_size_ds);
@@ -49,7 +49,7 @@ with open('$TEST_FILE', 'r') as f:
if tests:
    # Extract just the filename after the last '/'
    clean_tests = [os.path.basename(test) for test in tests]
-   print('ctest -R \"' + '|'.join(clean_tests) + '\"')
+   print('ctest --output-on-failure -R \"' + '|'.join(clean_tests) + '\"')
else:
    print('# No tests to run')
")
@@ -57,5 +57,3 @@ with open('$TEST_FILE', 'r') as f:
echo "$command"

eval "$command"
@@ -32,9 +32,6 @@ const ck_tile::stream_config stream_config{
    1, // rotating_count_
};

-// range_q, range_k, range_v, range_p, range_o, squant
-#define QUANT_ARGS 1, 1, 1, 1, 1, squant

#define COMMON_ARGS \
    init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \
    stream_config
@@ -117,7 +114,7 @@ TEST_P(AllLong, Test)
        1024,     // drop_offset
        false,    // drop_prefs
        mask_str, // mask_str
-       QUANT_ARGS,
+       squant,
        true, // is_rotary_interleaved
        1,    // num_splits
        COMMON_ARGS);
@@ -179,7 +176,7 @@ TEST_P(HDimPadding, Test)
        0,        // drop_offset
        false,    // drop_prefs
        mask_str, // mask_str
-       QUANT_ARGS,
+       squant,
        true, // is_rotary_interleaved
        1,    // num_splits
        COMMON_ARGS);
@@ -236,7 +233,7 @@ TEST_P(ElementwiseBias, Test)
        0,        // drop_offset
        false,    // drop_prefs
        mask_str, // mask_str
-       QUANT_ARGS,
+       squant,
        true, // is_rotary_interleaved
        1,    // num_splits
        COMMON_ARGS);
@@ -292,7 +289,7 @@ TEST_P(Alibi, Test)
        0,        // drop_offset
        false,    // drop_prefs
        mask_str, // mask_str
-       QUANT_ARGS,
+       squant,
        true, // is_rotary_interleaved
        1,    // num_splits
        COMMON_ARGS);
@@ -350,7 +347,7 @@ TEST_P(Dropout, Test)
        drop_offset, // drop_offset
        drop_prefs,  // drop_prefs
        mask_str,    // mask_str
-       QUANT_ARGS,
+       squant,
        true, // is_rotary_interleaved
        1,    // num_splits
        COMMON_ARGS);
@@ -410,7 +407,7 @@ TEST_P(PagedKV, Test)
        0,        // drop_offset
        false,    // drop_prefs
        mask_str, // mask_str
-       QUANT_ARGS,
+       squant,
        true, // is_rotary_interleaved
        1,    // num_splits
        COMMON_ARGS);
@@ -476,7 +473,7 @@ TEST_P(SplitKV, Test)
        0,        // drop_offset
        false,    // drop_prefs
        mask_str, // mask_str
-       QUANT_ARGS,
+       squant,
        true,       // is_rotary_interleaved
        num_splits, // num_splits
        COMMON_ARGS);
@@ -548,7 +545,7 @@ TEST_P(AppendKV, Test)
        0,        // drop_offset
        false,    // drop_prefs
        mask_str, // mask_str
-       QUANT_ARGS,
+       squant,
        false, // is_rotary_interleaved
        1,     // num_splits
        COMMON_ARGS);
@@ -618,7 +615,7 @@ TEST_P(AppendKVRoPE, Test)
        0,        // drop_offset
        false,    // drop_prefs
        mask_str, // mask_str
-       QUANT_ARGS,
+       squant,
        is_rotary_interleaved, // is_rotary_interleaved
        1,                     // num_splits
        COMMON_ARGS);
@@ -17,22 +17,21 @@ using DataTypeConfig = FmhaFwdFp8;
// instances are added), however the corresponding tests are not disabled (they will be skipped)
// in case such instances are added in the future.

-const auto HDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});
+const auto HDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1});

-const auto SplitKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});
+const auto SplitKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1});

-const auto AppendKVHDimValues =
-    Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1});
+const auto AppendKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1});

// There are no fp8 instances with seqlen padding (mode_enum::group requires it)
const auto ModeValues = Values(mode_enum::batch);

const auto IsVRowmajorValues = Values(false);

-const bool squant             = true;
-const std::string init_method = "ufq";
+const auto squant             = true;
+const std::string init_method = "uf";
const bool def_lse = false;
-const bool def_is_v_rowmajor = false;
+const bool def_is_v_rowmajor = true;

int adjust_seqlen(int seqlen)
{