mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4368 (commit 17f7dfc)
[CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on gfx950 (#4368) ## Motivation Microscaling types (mxfp8 and mxfp4) for fwd qr pipeline ## Technical Details The microscaling is used when quant scale mode is `BlockAttentionQuantScaleEnum::MX` and `Q/K/P/VDataType` are fp8/bf8/fp4. Supported features: * only "qr" pipeline is implemented * hdim 128 and 256 (smaller hdim are not possible due to restrictions of "qr" pipeline, but they can be computed using instances with padding) * both 32x32x64 and 16x16x128 scale MFMAs are supported * Q and K scales are applied in hdim, V scales - in seqlen dimension * column-major V only * batch and group mode * bias, Alibi (tested but no instances by default, just like fp8) * masking etc. Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008 ## Test Plan ``` ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8 ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4 ``` ## Test Result The tests must pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c85c272c39
commit
2312eef6c3
@@ -9,6 +9,8 @@ FWD_DTYPE_MAP = {
|
||||
"fp8fp16": "FmhaFwdFp8Fp16",
|
||||
"fp8bf16": "FmhaFwdFp8Bf16",
|
||||
"fp8fp32": "FmhaFwdFp8Fp32",
|
||||
"mxfp8": "FmhaFwdMxFp8",
|
||||
"mxfp4": "FmhaFwdMxFp4",
|
||||
}
|
||||
|
||||
BWD_DTYPE_MAP = {"fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"}
|
||||
@@ -79,6 +81,7 @@ QSCALE_MAP = {
|
||||
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
|
||||
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
|
||||
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
|
||||
"mx": "ck_tile::BlockAttentionQuantScaleEnum::MX",
|
||||
}
|
||||
|
||||
QSCALE_CHECK_MAP = {
|
||||
@@ -86,6 +89,7 @@ QSCALE_CHECK_MAP = {
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
"blockscale": "quant_scale_enum::blockscale",
|
||||
"kv_blockscale": "quant_scale_enum::kv_blockscale",
|
||||
"mx": "quant_scale_enum::mx",
|
||||
}
|
||||
|
||||
BIAS_MAP = {
|
||||
|
||||
@@ -38,6 +38,8 @@ DTYPE_BITS = {
|
||||
"fp8bf16": 8,
|
||||
"fp8fp32": 8,
|
||||
"bf8": 8,
|
||||
"mxfp8": 8,
|
||||
"mxfp4": 4,
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
@@ -836,7 +838,8 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
|
||||
def check_hdim_tile(
|
||||
problem_ctx: ProblemContext, kernel_ctx: KernelContext
|
||||
) -> bool:
|
||||
if problem_ctx.dtype != "fp32":
|
||||
# FIX: too confusing that it has to know about mx types
|
||||
if problem_ctx.dtype not in ("fp32", "mxfp8", "mxfp4"):
|
||||
# TODO: update if >=gfx11 archs get qr_async and qr_async_trload support
|
||||
if kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES and (
|
||||
(
|
||||
@@ -966,8 +969,6 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
return {
|
||||
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
} # fmt: skip
|
||||
else:
|
||||
raise ValueError(f"unsupported dtype={dtype}")
|
||||
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
@@ -1035,9 +1036,6 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
elif dtype in ["fp8", "fp8fp16", "bf8"]:
|
||||
# TODO
|
||||
pass
|
||||
return pipelines
|
||||
|
||||
|
||||
@@ -1046,6 +1044,17 @@ class KernelComponentFactoryGfx950(
|
||||
):
|
||||
arch = ArchTrait("gfx950")
|
||||
|
||||
_DT_MXFP8 = ("mxfp8",)
|
||||
_DT_MXFP4 = ("mxfp4",)
|
||||
|
||||
@classmethod
|
||||
def supported_dtypes(cls) -> Tuple[str]:
|
||||
return (
|
||||
KernelComponentFactoryGfx9.supported_dtypes()
|
||||
+ cls._DT_MXFP8
|
||||
+ cls._DT_MXFP4
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
|
||||
result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
|
||||
@@ -1054,6 +1063,18 @@ class KernelComponentFactoryGfx950(
|
||||
if (128, 128) in result.keys():
|
||||
result[(128, 128)].append(
|
||||
FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip
|
||||
elif dtype in cls._DT_MXFP8:
|
||||
return {
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
(128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)],
|
||||
} # fmt: skip
|
||||
elif dtype in cls._DT_MXFP4:
|
||||
return {
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
(128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)],
|
||||
} # fmt: skip
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@@ -1091,6 +1112,19 @@ class KernelComponentFactoryGfx950(
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
|
||||
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
|
||||
|
||||
elif dtype in cls._DT_MXFP8 or dtype in cls._DT_MXFP4:
|
||||
# no need dropout kernels
|
||||
lse = "t"
|
||||
dropout = "f"
|
||||
for logits, qscale, mask, bias, sink in itertools.product(
|
||||
["f"],
|
||||
["mx"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
["f", "t"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "col", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
return pipelines
|
||||
|
||||
|
||||
|
||||
@@ -48,8 +48,12 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("qscale",
|
||||
"n",
|
||||
"n or 0, no scale\n"
|
||||
"pt or 1, per-tensor scale\n")
|
||||
"quant scale:\n"
|
||||
" n or 0, no scale\n"
|
||||
" pt or 1, per-tensor scale\n"
|
||||
" bs or 2, block scale\n"
|
||||
" kvbs or 3, Q per-tensor, K/V per-page block scale\n"
|
||||
" mx or 4, microscaling (exclusively for data types like mxfp8 and mxfp4)")
|
||||
.insert("logits_soft_cap", "0", "attention logits soft capping value.")
|
||||
.insert("iperm",
|
||||
"1",
|
||||
@@ -61,7 +65,7 @@ auto create_args(int argc, char* argv[])
|
||||
"n or 0, no bias\n"
|
||||
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
|
||||
"a(libi) or 2, alibi with 1*h. a:1, b*h")
|
||||
.insert("prec", "fp16", "data type. fp32/fp16/bf16/fp8/bf8")
|
||||
.insert("prec", "fp16", "data type: fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4")
|
||||
.insert("mask",
|
||||
"0",
|
||||
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
|
||||
@@ -231,6 +235,10 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return run<FmhaFwdBf16>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run<FmhaFwdFp8>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8bf16")
|
||||
{
|
||||
return run<FmhaFwdFp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
@@ -239,6 +247,14 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return run<FmhaFwdFp8Fp32>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "mxfp8")
|
||||
{
|
||||
return run<FmhaFwdMxFp8>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "mxfp4")
|
||||
{
|
||||
return run<FmhaFwdMxFp4>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
std::cerr << "Unsupported precision: " << data_type << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
@@ -50,6 +50,14 @@ struct FmhaFwdFp8Fp32
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdMxFp8
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdMxFp4
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaFwdTypeConfig;
|
||||
|
||||
@@ -165,6 +173,54 @@ struct FmhaFwdTypeConfig<FmhaFwdFp8Fp32>
|
||||
using ODataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdMxFp8>
|
||||
{
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using BiasDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = float;
|
||||
|
||||
using QScaleDataType = ck_tile::e8m0_t;
|
||||
using KScaleDataType = ck_tile::e8m0_t;
|
||||
using VScaleDataType = ck_tile::e8m0_t;
|
||||
using PScaleDataType = ck_tile::e8m0_t;
|
||||
|
||||
static constexpr ck_tile::index_t kQKScaleGranularity = 32;
|
||||
static constexpr ck_tile::index_t kVScaleGranularity = 32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdMxFp4>
|
||||
{
|
||||
using QDataType = ck_tile::pk_fp4_t;
|
||||
using KDataType = ck_tile::pk_fp4_t;
|
||||
using VDataType = ck_tile::pk_fp4_t;
|
||||
using BiasDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::pk_fp4_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = float;
|
||||
|
||||
using QScaleDataType = ck_tile::e8m0_t;
|
||||
using KScaleDataType = ck_tile::e8m0_t;
|
||||
using VScaleDataType = ck_tile::e8m0_t;
|
||||
using PScaleDataType = ck_tile::e8m0_t;
|
||||
|
||||
static constexpr ck_tile::index_t kQKScaleGranularity = 32;
|
||||
static constexpr ck_tile::index_t kVScaleGranularity = 32;
|
||||
};
|
||||
|
||||
struct FmhaMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
@@ -232,6 +288,7 @@ struct fmha_fwd_args
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* block_scale_seqstart_q_ptr;
|
||||
const void* block_scale_seqstart_k_ptr;
|
||||
const void* seqstart_v_scale_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
@@ -252,6 +309,9 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t stride_q_descale;
|
||||
ck_tile::index_t stride_k_descale;
|
||||
ck_tile::index_t stride_v_descale;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
@@ -635,6 +695,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.seqlen_k_ptr,
|
||||
args.block_scale_seqstart_q_ptr,
|
||||
args.block_scale_seqstart_k_ptr,
|
||||
args.seqstart_v_scale_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
@@ -647,6 +708,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.stride_q_descale,
|
||||
args.stride_k_descale,
|
||||
args.stride_v_descale,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
@@ -697,6 +761,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.stride_q_descale,
|
||||
args.stride_k_descale,
|
||||
args.stride_v_descale,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
|
||||
@@ -84,6 +84,22 @@ auto get_elimit<FmhaFwdFp8Fp32>(std::string /*init_method*/)
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<FmhaFwdMxFp8>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1.8e-1;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<FmhaFwdMxFp4>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-1;
|
||||
double atol = 2.6e-1;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits)
|
||||
{
|
||||
// If we have enough to almost fill the SMs, then just use 1 split
|
||||
@@ -171,6 +187,28 @@ void copy_attention_scores_with_sink(const ck_tile::HostTensor<SMPLComputeDataTy
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TypeConfig, bool IsMx>
|
||||
struct ScalesConfig
|
||||
{
|
||||
using QScaleDataType = float;
|
||||
using KScaleDataType = float;
|
||||
using VScaleDataType = float;
|
||||
|
||||
static constexpr ck_tile::index_t kQKScaleGranularity = 1;
|
||||
static constexpr ck_tile::index_t kVScaleGranularity = 1;
|
||||
};
|
||||
|
||||
template <typename TypeConfig>
|
||||
struct ScalesConfig<TypeConfig, true>
|
||||
{
|
||||
using QScaleDataType = typename TypeConfig::QScaleDataType;
|
||||
using KScaleDataType = typename TypeConfig::KScaleDataType;
|
||||
using VScaleDataType = typename TypeConfig::VScaleDataType;
|
||||
|
||||
static constexpr ck_tile::index_t kQKScaleGranularity = TypeConfig::kQKScaleGranularity;
|
||||
static constexpr ck_tile::index_t kVScaleGranularity = TypeConfig::kVScaleGranularity;
|
||||
};
|
||||
|
||||
template <typename DataTypeConfig>
|
||||
fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
@@ -210,6 +248,31 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::stream_config& stream_config,
|
||||
std::optional<std::string> json = std::nullopt)
|
||||
{
|
||||
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
|
||||
|
||||
constexpr bool is_mx = ck_tile::is_any_of<DataTypeConfig, FmhaFwdMxFp8, FmhaFwdMxFp4>::value;
|
||||
|
||||
using QDataType = typename TypeConfig::QDataType;
|
||||
using KDataType = typename TypeConfig::KDataType;
|
||||
using VDataType = typename TypeConfig::VDataType;
|
||||
using BiasDataType = typename TypeConfig::BiasDataType;
|
||||
using RandValOutputDataType = typename TypeConfig::RandValOutputDataType;
|
||||
using LSEDataType = typename TypeConfig::LSEDataType;
|
||||
using SaccDataType = typename TypeConfig::SaccDataType;
|
||||
using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType;
|
||||
using PDataType = std::conditional_t<is_mx, float, typename TypeConfig::PDataType>;
|
||||
using OaccDataType = typename TypeConfig::OaccDataType;
|
||||
using ODataType = typename TypeConfig::ODataType;
|
||||
|
||||
using QScaleDataType = typename ScalesConfig<TypeConfig, is_mx>::QScaleDataType;
|
||||
using KScaleDataType = typename ScalesConfig<TypeConfig, is_mx>::KScaleDataType;
|
||||
using VScaleDataType = typename ScalesConfig<TypeConfig, is_mx>::VScaleDataType;
|
||||
|
||||
constexpr ck_tile::index_t kQKScaleGranularity =
|
||||
ScalesConfig<TypeConfig, is_mx>::kQKScaleGranularity;
|
||||
constexpr ck_tile::index_t kVScaleGranularity =
|
||||
ScalesConfig<TypeConfig, is_mx>::kVScaleGranularity;
|
||||
|
||||
// Note: block_scale_size_q_ and block_scale_size_kv_ should be greater than or equal to the
|
||||
// compute block size
|
||||
constexpr ck_tile::index_t block_scale_size_q_ = 128;
|
||||
@@ -230,6 +293,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
return "fp8bf16";
|
||||
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Fp32>)
|
||||
return "fp8fp32";
|
||||
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdMxFp8>)
|
||||
return "mxfp8";
|
||||
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdMxFp4>)
|
||||
return "mxfp4";
|
||||
else
|
||||
static_assert(false);
|
||||
}();
|
||||
@@ -242,6 +309,26 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
return fwd_result::invalid_args;
|
||||
}
|
||||
|
||||
if(hdim_q % ck_tile::numeric_traits<QDataType>::PackedSize != 0)
|
||||
{
|
||||
std::cerr << "hdim_q is made even for fp4 Q data type" << std::endl;
|
||||
hdim_q =
|
||||
ck_tile::integer_least_multiple(hdim_q, ck_tile::numeric_traits<QDataType>::PackedSize);
|
||||
}
|
||||
if(hdim_q % ck_tile::numeric_traits<KDataType>::PackedSize != 0)
|
||||
{
|
||||
std::cerr << "hdim_q is made even for fp4 K data type" << std::endl;
|
||||
hdim_q =
|
||||
ck_tile::integer_least_multiple(hdim_q, ck_tile::numeric_traits<KDataType>::PackedSize);
|
||||
}
|
||||
if(is_mx && !seqlen_kpads.empty() && seqlen_kpads[0] > 0)
|
||||
{
|
||||
std::cerr
|
||||
<< "seqlen_kpads is not supported with MX types. ignoring the 'seqlen_kpads' option"
|
||||
<< std::endl;
|
||||
seqlen_kpads = {-1};
|
||||
}
|
||||
|
||||
std::mt19937 random_engine(seed != 0 ? seed : std::random_device{}());
|
||||
auto next_seed = [&random_engine]() { return static_cast<unsigned int>(random_engine()); };
|
||||
|
||||
@@ -371,6 +458,20 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
/*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
|
||||
need_append_kvcache,
|
||||
random_engine);
|
||||
|
||||
if(ck_tile::numeric_traits<VDataType>::PackedSize != 0)
|
||||
{
|
||||
// Ensure that all seqlens are even if V has packed data type
|
||||
for(auto& s : seqlen_ks)
|
||||
{
|
||||
s = ck_tile::integer_least_multiple(s, ck_tile::numeric_traits<VDataType>::PackedSize);
|
||||
}
|
||||
for(auto& s : kv_eff_lens_per_batch)
|
||||
{
|
||||
s = ck_tile::integer_least_multiple(s, ck_tile::numeric_traits<VDataType>::PackedSize);
|
||||
}
|
||||
}
|
||||
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
if(seqlen_kpads[wb] > 0 && seqlen_kpads[wb] < seqlen_ks[wb])
|
||||
@@ -410,6 +511,22 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
quant_scale_info qscale = quant_scale_info::decode(qscale_str);
|
||||
|
||||
if(is_mx && qscale.type != quant_scale_enum::mx)
|
||||
{
|
||||
std::cerr << "The value of qscale_str must be 'mx' for MX data types" << std::endl;
|
||||
return fwd_result::invalid_args;
|
||||
}
|
||||
else if(!is_mx && qscale.type == quant_scale_enum::mx)
|
||||
{
|
||||
std::cerr << "The value of qscale_str cannot be 'mx' for non-MX data types" << std::endl;
|
||||
return fwd_result::invalid_args;
|
||||
}
|
||||
if(is_mx && is_v_rowmajor)
|
||||
{
|
||||
std::cerr << "The value of is_v_rowmajor must be 'false' for MX data types" << std::endl;
|
||||
return fwd_result::invalid_args;
|
||||
}
|
||||
|
||||
if(p_drop < 0.0f || p_drop > 1.0f)
|
||||
{
|
||||
std::cerr << "The value of p_drop should be 0~1" << std::endl;
|
||||
@@ -458,28 +575,16 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
calculate_cumulative(kv_eff_lens_per_batch, cukv_cum);
|
||||
}
|
||||
|
||||
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
|
||||
|
||||
using QDataType = typename TypeConfig::QDataType;
|
||||
using KDataType = typename TypeConfig::KDataType;
|
||||
using VDataType = typename TypeConfig::VDataType;
|
||||
using BiasDataType = typename TypeConfig::BiasDataType;
|
||||
using RandValOutputDataType = typename TypeConfig::RandValOutputDataType;
|
||||
using LSEDataType = typename TypeConfig::LSEDataType;
|
||||
using SaccDataType = typename TypeConfig::SaccDataType;
|
||||
using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType;
|
||||
using PDataType = typename TypeConfig::PDataType;
|
||||
using OaccDataType = typename TypeConfig::OaccDataType;
|
||||
using ODataType = typename TypeConfig::ODataType;
|
||||
|
||||
// accumulation numbers for performance evaluation
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
auto max_seqlen_q =
|
||||
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
|
||||
size_t i_block_scale_q = 0;
|
||||
size_t i_block_scale_k = 0;
|
||||
int32_t i_block_scale_q = 0;
|
||||
int32_t i_block_scale_k = 0;
|
||||
int32_t i_seqstart_v_scale = 0;
|
||||
std::vector<int32_t> block_scale_seqstart_q_host = {0};
|
||||
std::vector<int32_t> block_scale_seqstart_k_host = {0};
|
||||
std::vector<int32_t> seqstart_v_scale_host = {0};
|
||||
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
|
||||
{
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
@@ -496,18 +601,31 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
{
|
||||
max_seqlen_k = real_seqlen_k;
|
||||
}
|
||||
i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_);
|
||||
i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_);
|
||||
block_scale_seqstart_q_host.push_back(i_block_scale_q);
|
||||
block_scale_seqstart_k_host.push_back(i_block_scale_k);
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_);
|
||||
i_block_scale_k +=
|
||||
ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_);
|
||||
block_scale_seqstart_q_host.push_back(i_block_scale_q);
|
||||
block_scale_seqstart_k_host.push_back(i_block_scale_k);
|
||||
}
|
||||
else if(qscale.type == quant_scale_enum::mx)
|
||||
{
|
||||
i_seqstart_v_scale +=
|
||||
ck_tile::integer_divide_ceil(real_seqlen_k, kVScaleGranularity);
|
||||
seqstart_v_scale_host.push_back(i_seqstart_v_scale);
|
||||
}
|
||||
|
||||
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
|
||||
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
|
||||
|
||||
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
|
||||
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q /
|
||||
ck_tile::numeric_traits<QDataType>::PackedSize +
|
||||
sizeof(ODataType) * real_seqlen_q * hdim_v);
|
||||
num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q +
|
||||
sizeof(VDataType) * hdim_v * real_seqlen_k);
|
||||
num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q /
|
||||
ck_tile::numeric_traits<KDataType>::PackedSize +
|
||||
sizeof(VDataType) * hdim_v * real_seqlen_k /
|
||||
ck_tile::numeric_traits<VDataType>::PackedSize);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -620,19 +738,30 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
hdim_v}
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
// TODO - change the tensor length for different quant scale
|
||||
ck_tile::HostTensor<float> q_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead, num_block_scale_q}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
ck_tile::HostTensor<float> k_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
ck_tile::HostTensor<float> v_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
const ck_tile::index_t hdim_q_scale = ck_tile::integer_divide_ceil(hdim_q, kQKScaleGranularity);
|
||||
const ck_tile::index_t shape_seqlen_v_scale = seqstart_v_scale_host.back();
|
||||
|
||||
ck_tile::HostTensor<QScaleDataType> q_descale_host({1});
|
||||
ck_tile::HostTensor<KScaleDataType> k_descale_host({1});
|
||||
ck_tile::HostTensor<VScaleDataType> v_descale_host({1});
|
||||
if constexpr(is_mx)
|
||||
{
|
||||
q_descale_host = ck_tile::HostTensor<QScaleDataType>(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q_scale));
|
||||
k_descale_host = ck_tile::HostTensor<KScaleDataType>(
|
||||
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q_scale));
|
||||
v_descale_host = ck_tile::HostTensor<VScaleDataType>(
|
||||
get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_v_scale));
|
||||
}
|
||||
else if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
q_descale_host = ck_tile::HostTensor<QScaleDataType>(
|
||||
std::array<ck_tile::index_t, 3>{shape_batch, nhead, num_block_scale_q});
|
||||
k_descale_host = ck_tile::HostTensor<KScaleDataType>(
|
||||
std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv});
|
||||
v_descale_host = ck_tile::HostTensor<VScaleDataType>(
|
||||
std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv});
|
||||
}
|
||||
|
||||
// batch mode of lse data layout is [batch, nhead, seqlen_q]
|
||||
// group mode of lse data layout is [nhead, total_seqlen_q]
|
||||
@@ -704,10 +833,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
else if(init_method == "3")
|
||||
{
|
||||
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
||||
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
||||
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
|
||||
float bias_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<BiasDataType>::max());
|
||||
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
||||
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
||||
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
|
||||
|
||||
ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, next_seed()}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, next_seed()}(k_host);
|
||||
@@ -716,8 +844,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, next_seed()}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, next_seed()}(
|
||||
vnew_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{
|
||||
-bias_dtype_max, bias_dtype_max, next_seed()}(bias_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, next_seed()}(bias_host);
|
||||
}
|
||||
if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
@@ -737,7 +864,41 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
if(qscale.type == quant_scale_enum::pertensor)
|
||||
if constexpr(is_mx)
|
||||
{
|
||||
auto gen_scales = [&](auto& scales, auto data, float range) {
|
||||
using DataType = decltype(data);
|
||||
using ScaleType = ck_tile::remove_cvref_t<decltype(*scales.begin())>;
|
||||
if constexpr(std::is_same_v<ScaleType, ck_tile::e8m0_t>)
|
||||
{
|
||||
const float base =
|
||||
-std::log2(ck_tile::type_convert<float>(ck_tile::numeric<DataType>::max()));
|
||||
// e8m0_t is basically an exponent of float32
|
||||
// When scales are applied to tensor values, value * exp2(base - range) is around
|
||||
// 0.125 and value * exp2(base + range) is around 8 for all types (fp8/bf8/fp4)
|
||||
ck_tile::HostTensor<float> pow2(scales.get_lengths());
|
||||
ck_tile::FillUniformDistributionIntegerValue<float>{
|
||||
base - range, base + range, next_seed()}(pow2);
|
||||
scales.ForEach([&](auto& self, const auto& i) {
|
||||
self(i) = ck_tile::type_convert<ScaleType>(std::exp2(pow2(i)));
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false);
|
||||
}
|
||||
};
|
||||
gen_scales(q_descale_host, QDataType{}, 3);
|
||||
gen_scales(k_descale_host, KDataType{}, 3);
|
||||
// When P is fp4, only 8 values (0, 0.5, 1, 1.5, 2, 3, 4, 6) are used to quantize P.
|
||||
// Too large V values can create rare error outliers between host (no quantization) and
|
||||
// device ("running" FA softmax + quantization), here we reduce max value by using smaller
|
||||
// range of V scales.
|
||||
gen_scales(v_descale_host,
|
||||
VDataType{},
|
||||
std::is_same_v<typename TypeConfig::PDataType, ck_tile::pk_fp4_t> ? 1 : 3);
|
||||
}
|
||||
else if(qscale.type == quant_scale_enum::pertensor)
|
||||
{
|
||||
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
||||
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
||||
@@ -790,12 +951,13 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
sizeof(int32_t));
|
||||
ck_tile::DeviceMem block_scale_seqstart_k_buf(block_scale_seqstart_k_host.size() *
|
||||
sizeof(int32_t));
|
||||
ck_tile::DeviceMem scale_seqstart_v_buf(seqstart_v_scale_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_q_buf(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k_buf(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty()
|
||||
? 0
|
||||
: seqstart_q_with_padding_host.size() *
|
||||
@@ -837,9 +999,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
v_descale_buf.ToDevice(v_descale_host.data());
|
||||
block_scale_seqstart_q_buf.ToDevice(block_scale_seqstart_q_host.data());
|
||||
block_scale_seqstart_k_buf.ToDevice(block_scale_seqstart_k_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
// Keep logical starts in seqstart_k; pass padded K via separate pointer
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
scale_seqstart_v_buf.ToDevice(seqstart_v_scale_host.data());
|
||||
seqstart_q_buf.ToDevice(seqstart_q_host.data());
|
||||
// Keep logical starts in seqstart_k_buf; pass padded K via separate pointer
|
||||
seqstart_k_buf.ToDevice(seqstart_k_host.data());
|
||||
seqstart_q_padded_buf.ToDevice(
|
||||
seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data());
|
||||
seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr
|
||||
@@ -1145,7 +1308,13 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
if(qscale.type == quant_scale_enum::pertensor)
|
||||
{
|
||||
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
|
||||
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
|
||||
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
|
||||
}
|
||||
else if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
args.q_descale_ptr =
|
||||
reinterpret_cast<const float*>(q_descale_buf.GetDeviceBuffer());
|
||||
@@ -1172,11 +1341,32 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
args.block_scale_size_q = block_scale_size_q_;
|
||||
args.block_scale_size_kv = block_scale_size_kv_;
|
||||
}
|
||||
else
|
||||
else if(qscale.type == quant_scale_enum::mx)
|
||||
{
|
||||
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
|
||||
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
|
||||
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
|
||||
|
||||
args.stride_q_descale = (i_perm ? hdim_q_scale : nhead * hdim_q_scale);
|
||||
args.stride_k_descale = (i_perm ? hdim_q_scale : nhead_k * hdim_q_scale);
|
||||
args.stride_v_descale =
|
||||
(i_perm ? shape_seqlen_v_scale : nhead_k * shape_seqlen_v_scale);
|
||||
args.nhead_stride_q_descale =
|
||||
(i_perm ? shape_seqlen_q * hdim_q_scale : hdim_q_scale);
|
||||
args.nhead_stride_k_descale =
|
||||
(i_perm ? shape_seqlen_k * hdim_q_scale : hdim_q_scale);
|
||||
args.nhead_stride_v_descale =
|
||||
(i_perm ? hdim_v * shape_seqlen_v_scale : shape_seqlen_v_scale);
|
||||
if(mode == mode_enum::group)
|
||||
{
|
||||
args.seqstart_v_scale_ptr = scale_seqstart_v_buf.GetDeviceBuffer();
|
||||
}
|
||||
else
|
||||
{
|
||||
args.batch_stride_q_descale = (nhead * shape_seqlen_q * hdim_q_scale);
|
||||
args.batch_stride_k_descale = (nhead_k * shape_seqlen_k * hdim_q_scale);
|
||||
args.batch_stride_v_descale = (nhead_k * hdim_v * shape_seqlen_v_scale);
|
||||
}
|
||||
}
|
||||
|
||||
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
|
||||
@@ -1207,11 +1397,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
args.seqstart_q_ptr =
|
||||
has_group_q_padding && !seqstart_q_with_padding_host.empty()
|
||||
? seqstart_q_padded_buf.GetDeviceBuffer()
|
||||
: seqstart_q.GetDeviceBuffer();
|
||||
: seqstart_q_buf.GetDeviceBuffer();
|
||||
args.seqstart_k_ptr =
|
||||
has_group_k_padding && !seqstart_k_with_padding_host.empty()
|
||||
? seqstart_k_padded_buf.GetDeviceBuffer()
|
||||
: seqstart_k.GetDeviceBuffer();
|
||||
: seqstart_k_buf.GetDeviceBuffer();
|
||||
|
||||
// Logical (unpadded) per-sequence lengths, used when padding is enabled
|
||||
args.seqlen_q_ptr =
|
||||
@@ -1272,9 +1462,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
args.split_stride_o_acc = split_stride_o_acc;
|
||||
|
||||
args.seqstart_q_ptr =
|
||||
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
|
||||
(mode == mode_enum::group ? seqstart_q_buf.GetDeviceBuffer() : nullptr);
|
||||
args.seqstart_k_ptr =
|
||||
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
||||
(mode == mode_enum::group ? seqstart_k_buf.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr =
|
||||
((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
@@ -1292,9 +1482,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
|
||||
|
||||
args.seqstart_q_ptr =
|
||||
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
|
||||
(mode == mode_enum::group ? seqstart_q_buf.GetDeviceBuffer() : nullptr);
|
||||
args.seqstart_k_ptr =
|
||||
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
||||
(mode == mode_enum::group ? seqstart_k_buf.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr =
|
||||
((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
@@ -1462,11 +1652,14 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
float scale_p_host = 1.0f;
|
||||
float scale_o_host = 1.0f;
|
||||
|
||||
if(qscale.type == quant_scale_enum::pertensor)
|
||||
if constexpr(!is_mx)
|
||||
{
|
||||
scale_s_host = scale_s * q_descale_host(0) * k_descale_host(0);
|
||||
scale_p_host = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
|
||||
scale_o_host = v_descale_host(0) / scale_p_host;
|
||||
if(qscale.type == quant_scale_enum::pertensor)
|
||||
{
|
||||
scale_s_host = scale_s * q_descale_host(0) * k_descale_host(0);
|
||||
scale_p_host = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
|
||||
scale_o_host = v_descale_host(0) / scale_p_host;
|
||||
}
|
||||
}
|
||||
|
||||
auto p_compute_element_func = [&]() {
|
||||
@@ -1680,7 +1873,45 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
#endif
|
||||
|
||||
// reference
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
if constexpr(is_mx)
|
||||
{
|
||||
ck_tile::HostTensor<QScaleDataType> q_descale_host_ref(
|
||||
{nhead, real_seqlen_q, hdim_q_scale});
|
||||
ck_tile::HostTensor<KScaleDataType> k_descale_host_ref(
|
||||
{nhead, real_seqlen_k, hdim_q_scale});
|
||||
|
||||
// clang-format off
|
||||
if(i_perm) q_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_descale_host(b_idx, i[0], i[1] + query_offset, i[2]); });
|
||||
else q_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_descale_host(b_idx, i[1] + query_offset, i[0], i[2]); });
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
if(i_perm) k_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_descale_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
|
||||
else k_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_descale_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
|
||||
// clang-format on
|
||||
|
||||
auto q_host_ref2 = ck_tile::reference_batched_mx_descale<QDataType,
|
||||
QScaleDataType,
|
||||
SaccDataType,
|
||||
SaccDataType>(
|
||||
q_host_ref, q_descale_host_ref, kQKScaleGranularity);
|
||||
auto k_host_ref2 = ck_tile::reference_batched_mx_descale<KDataType,
|
||||
KScaleDataType,
|
||||
SaccDataType,
|
||||
SaccDataType>(
|
||||
k_host_ref, k_descale_host_ref, kQKScaleGranularity);
|
||||
|
||||
ck_tile::reference_batched_gemm<SaccDataType,
|
||||
SaccDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType>(q_host_ref2,
|
||||
k_host_ref2,
|
||||
s_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale_s_host));
|
||||
}
|
||||
else if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
const ck_tile::index_t q_offset =
|
||||
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb];
|
||||
@@ -1881,6 +2112,32 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
s_host_ref, p_host_ref, p_compute_element_func);
|
||||
}
|
||||
}
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
lse_host_result.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
|
||||
});
|
||||
|
||||
// Use smaller rtol/atol as LSE is computed and stored in fp32, so there is no
|
||||
// precision loss due to conversion
|
||||
bool cur_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
1e-4,
|
||||
1e-4,
|
||||
/* allow_infinity_ref = */ true);
|
||||
|
||||
pass &= cur_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
|
||||
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
||||
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
||||
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
||||
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
|
||||
}
|
||||
}
|
||||
if(p_drop > 0)
|
||||
{
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
|
||||
@@ -1907,13 +2164,43 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
randval_host_ref,
|
||||
"DROPOUT RANDVAL Error: Incorrect results!");
|
||||
pass &= cur_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
if constexpr(is_mx)
|
||||
{
|
||||
const ck_tile::index_t real_seqlen_v_scale =
|
||||
seqstart_v_scale_host[wb + 1] - seqstart_v_scale_host[wb];
|
||||
const ck_tile::index_t v_scale_offset =
|
||||
mode == mode_enum::batch ? 0 : seqstart_v_scale_host[wb];
|
||||
|
||||
ck_tile::HostTensor<VScaleDataType> v_descale_host_ref(
|
||||
{nhead, hdim_v, real_seqlen_v_scale});
|
||||
|
||||
// clang-format off
|
||||
if(i_perm) v_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_descale_host(cache_b_idx, i[0] / nr, i[1], i[2] + v_scale_offset); });
|
||||
else v_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_descale_host(cache_b_idx, i[1], i[0] / nr, i[2] + v_scale_offset); });
|
||||
// clang-format on
|
||||
|
||||
auto v_host_ref2 = ck_tile::reference_batched_mx_descale<VDataType,
|
||||
VScaleDataType,
|
||||
OaccDataType,
|
||||
OaccDataType>(
|
||||
v_host_ref, v_descale_host_ref, kVScaleGranularity);
|
||||
|
||||
// P is not quantized and then dequantized here (PDataType = float).
|
||||
// On host softmax is computed for the whole row of S, while on device FA computes
|
||||
// softmax and quantizes it in blocks of N0 values. Quantization on host would make
|
||||
// reference results **less** precise than the device results for large seqlen_k!
|
||||
|
||||
ck_tile::reference_batched_gemm<PDataType, OaccDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
v_host_ref2,
|
||||
o_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
oacc_element_func);
|
||||
}
|
||||
else if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
const ck_tile::index_t v_offset =
|
||||
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb];
|
||||
@@ -1969,35 +2256,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
<< std::endl
|
||||
<< "\tquery_offset used: " << query_offset << std::endl
|
||||
<< "\tkey_offset used: " << key_offset << std::endl;
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
if(lse)
|
||||
if(!pass)
|
||||
{
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
lse_host_result.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
|
||||
});
|
||||
|
||||
cur_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
rtol,
|
||||
atol,
|
||||
/* allow_infinity_ref = */ true);
|
||||
|
||||
pass &= cur_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
|
||||
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
||||
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
||||
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
||||
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
|
||||
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2006,33 +2269,37 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
if(json)
|
||||
{
|
||||
dump_fmha_fwd_json_results(
|
||||
*json,
|
||||
data_type,
|
||||
mode == mode_enum::batch ? "batch" : "group",
|
||||
io_layout(i_perm, o_perm),
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_qs[0],
|
||||
seqlen_ks[0],
|
||||
seqlen_kpads[0],
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
scale_s,
|
||||
p_drop,
|
||||
lse,
|
||||
qscale.type == quant_scale_enum::no_scale
|
||||
? "no_scale"
|
||||
: (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"),
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? "elementwise_bias"
|
||||
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
|
||||
is_v_rowmajor ? "r" : "c",
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
const std::string qscale_name =
|
||||
(qscale.type == quant_scale_enum::no_scale ? "no_scale"
|
||||
: qscale.type == quant_scale_enum::pertensor ? "pertensor"
|
||||
: qscale.type == quant_scale_enum::blockscale ? "blockscale"
|
||||
: qscale.type == quant_scale_enum::kv_blockscale ? "kv_blockscale"
|
||||
: qscale.type == quant_scale_enum::mx ? "mx"
|
||||
: "unknown");
|
||||
dump_fmha_fwd_json_results(*json,
|
||||
data_type,
|
||||
mode == mode_enum::batch ? "batch" : "group",
|
||||
io_layout(i_perm, o_perm),
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_qs[0],
|
||||
seqlen_ks[0],
|
||||
seqlen_kpads[0],
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
scale_s,
|
||||
p_drop,
|
||||
lse,
|
||||
qscale_name,
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? "elementwise_bias"
|
||||
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
|
||||
is_v_rowmajor ? "r" : "c",
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
|
||||
return pass ? fwd_result::success : fwd_result::failure;
|
||||
|
||||
@@ -18,6 +18,7 @@ enum class quant_scale_enum
|
||||
pertensor = 1,
|
||||
blockscale = 2,
|
||||
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
|
||||
mx = 4, // Microscaling (MX)
|
||||
};
|
||||
|
||||
struct quant_scale_info
|
||||
@@ -34,6 +35,8 @@ struct quant_scale_info
|
||||
os << "bs";
|
||||
else if(type == quant_scale_enum::kv_blockscale)
|
||||
os << "kvbs";
|
||||
else if(type == quant_scale_enum::mx)
|
||||
os << "mx";
|
||||
}
|
||||
|
||||
static quant_scale_info decode(std::string str)
|
||||
@@ -55,6 +58,10 @@ struct quant_scale_info
|
||||
{
|
||||
info.type = quant_scale_enum::kv_blockscale;
|
||||
}
|
||||
else if(str == "mx" || str == "4")
|
||||
{
|
||||
info.type = quant_scale_enum::mx;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid quant scale value: " + str);
|
||||
|
||||
Reference in New Issue
Block a user