[rocm-libraries] ROCm/rocm-libraries#4368 (commit 17f7dfc)

[CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on
 gfx950 (#4368)

## Motivation

Microscaling types (mxfp8 and mxfp4) for fwd qr pipeline

## Technical Details

The microscaling is used when quant scale mode is
`BlockAttentionQuantScaleEnum::MX` and `Q/K/P/VDataType` are
fp8/bf8/fp4.

Supported features:
* only the "qr" pipeline is implemented
* hdim 128 and 256 (smaller hdims are not possible due to restrictions of
the "qr" pipeline, but they can be computed using instances with padding)
* both 32x32x64 and 16x16x128 scale MFMAs are supported
* Q and K scales are applied in hdim, V scales - in seqlen dimension
* column-major V only
* batch and group mode
* bias, Alibi (tested but no instances by default, just like fp8)
* masking etc.

Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008

## Test Plan

```
ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8
ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4
```

## Test Result

The tests must pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Anton Gorenko
2026-03-11 10:00:52 +00:00
committed by assistant-librarian[bot]
parent c85c272c39
commit 2312eef6c3
29 changed files with 2167 additions and 356 deletions

View File

@@ -9,6 +9,8 @@ FWD_DTYPE_MAP = {
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16",
"fp8fp32": "FmhaFwdFp8Fp32",
"mxfp8": "FmhaFwdMxFp8",
"mxfp4": "FmhaFwdMxFp4",
}
BWD_DTYPE_MAP = {"fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"}
@@ -79,6 +81,7 @@ QSCALE_MAP = {
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
"mx": "ck_tile::BlockAttentionQuantScaleEnum::MX",
}
QSCALE_CHECK_MAP = {
@@ -86,6 +89,7 @@ QSCALE_CHECK_MAP = {
"pertensor": "quant_scale_enum::pertensor",
"blockscale": "quant_scale_enum::blockscale",
"kv_blockscale": "quant_scale_enum::kv_blockscale",
"mx": "quant_scale_enum::mx",
}
BIAS_MAP = {

View File

@@ -38,6 +38,8 @@ DTYPE_BITS = {
"fp8bf16": 8,
"fp8fp32": 8,
"bf8": 8,
"mxfp8": 8,
"mxfp4": 4,
}
K0_MAX_SUBMAX_MAP = {
@@ -836,7 +838,8 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
def check_hdim_tile(
problem_ctx: ProblemContext, kernel_ctx: KernelContext
) -> bool:
if problem_ctx.dtype != "fp32":
# FIX: too confusing that it has to know about mx types
if problem_ctx.dtype not in ("fp32", "mxfp8", "mxfp4"):
# TODO: update if >=gfx11 archs get qr_async and qr_async_trload support
if kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES and (
(
@@ -966,8 +969,6 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
return {
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
} # fmt: skip
else:
raise ValueError(f"unsupported dtype={dtype}")
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
@@ -1035,9 +1036,6 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
else:
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
elif dtype in ["fp8", "fp8fp16", "bf8"]:
# TODO
pass
return pipelines
@@ -1046,6 +1044,17 @@ class KernelComponentFactoryGfx950(
):
arch = ArchTrait("gfx950")
_DT_MXFP8 = ("mxfp8",)
_DT_MXFP4 = ("mxfp4",)
@classmethod
def supported_dtypes(cls) -> Tuple[str]:
    """Gfx950 supports the common gfx9 dtypes plus the microscaling types."""
    base_dtypes = KernelComponentFactoryGfx9.supported_dtypes()
    return base_dtypes + cls._DT_MXFP8 + cls._DT_MXFP4
@classmethod
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
@@ -1054,6 +1063,18 @@ class KernelComponentFactoryGfx950(
if (128, 128) in result.keys():
result[(128, 128)].append(
FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip
elif dtype in cls._DT_MXFP8:
return {
# bm0, bn0, bk0, bn1, bk1,
(128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)],
} # fmt: skip
elif dtype in cls._DT_MXFP4:
return {
# bm0, bn0, bk0, bn1, bk1,
(128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)],
} # fmt: skip
return result
@classmethod
@@ -1091,6 +1112,19 @@ class KernelComponentFactoryGfx950(
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
elif dtype in cls._DT_MXFP8 or dtype in cls._DT_MXFP4:
# no need dropout kernels
lse = "t"
dropout = "f"
for logits, qscale, mask, bias, sink in itertools.product(
["f"],
["mx"],
get_mask_map(mask_impl).keys(),
["no"],
["f", "t"],
):
pipelines.append(FmhaFwdPipeline("qr", "col", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip
return pipelines

View File

@@ -48,8 +48,12 @@ auto create_args(int argc, char* argv[])
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
.insert("qscale",
"n",
"n or 0, no scale\n"
"pt or 1, per-tensor scale\n")
"quant scale:\n"
" n or 0, no scale\n"
" pt or 1, per-tensor scale\n"
" bs or 2, block scale\n"
" kvbs or 3, Q per-tensor, K/V per-page block scale\n"
" mx or 4, microscaling (exclusively for data types like mxfp8 and mxfp4)")
.insert("logits_soft_cap", "0", "attention logits soft capping value.")
.insert("iperm",
"1",
@@ -61,7 +65,7 @@ auto create_args(int argc, char* argv[])
"n or 0, no bias\n"
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
"a(libi) or 2, alibi with 1*h. a:1, b*h")
.insert("prec", "fp16", "data type. fp32/fp16/bf16/fp8/bf8")
.insert("prec", "fp16", "data type: fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
@@ -231,6 +235,10 @@ int main(int argc, char* argv[])
{
return run<FmhaFwdBf16>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "fp8")
{
return run<FmhaFwdFp8>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "fp8bf16")
{
return run<FmhaFwdFp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
@@ -239,6 +247,14 @@ int main(int argc, char* argv[])
{
return run<FmhaFwdFp8Fp32>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "mxfp8")
{
return run<FmhaFwdMxFp8>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "mxfp4")
{
return run<FmhaFwdMxFp4>(arg_parser) == fwd_result::success ? 0 : -2;
}
std::cerr << "Unsupported precision: " << data_type << std::endl;
return -1;
}

View File

@@ -50,6 +50,14 @@ struct FmhaFwdFp8Fp32
{
};
// Tag type selecting the mxfp8 (microscaled fp8) forward configuration;
// the concrete type mapping lives in FmhaFwdTypeConfig<FmhaFwdMxFp8>.
struct FmhaFwdMxFp8
{
};
// Tag type selecting the mxfp4 (microscaled packed-fp4) forward configuration;
// the concrete type mapping lives in FmhaFwdTypeConfig<FmhaFwdMxFp4>.
struct FmhaFwdMxFp4
{
};
template <typename DataType>
struct FmhaFwdTypeConfig;
@@ -165,6 +173,54 @@ struct FmhaFwdTypeConfig<FmhaFwdFp8Fp32>
using ODataType = float;
};
// mxfp8: Q/K/V/P are fp8 tensors, each paired with e8m0 (power-of-two
// exponent) scale tensors at a granularity of 32 elements per scale.
template <>
struct FmhaFwdTypeConfig<FmhaFwdMxFp8>
{
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = float;
using QScaleDataType = ck_tile::e8m0_t; // MX scales: shared exponent per block
using KScaleDataType = ck_tile::e8m0_t;
using VScaleDataType = ck_tile::e8m0_t;
using PScaleDataType = ck_tile::e8m0_t;
static constexpr ck_tile::index_t kQKScaleGranularity = 32; // elements of hdim per Q/K scale
static constexpr ck_tile::index_t kVScaleGranularity = 32;  // elements of seqlen per V scale
};
// mxfp4: Q/K/V/P are packed fp4 (two values per byte) tensors, each paired
// with e8m0 scale tensors at a granularity of 32 elements per scale.
template <>
struct FmhaFwdTypeConfig<FmhaFwdMxFp4>
{
using QDataType = ck_tile::pk_fp4_t;
using KDataType = ck_tile::pk_fp4_t;
using VDataType = ck_tile::pk_fp4_t;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::pk_fp4_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = float;
using QScaleDataType = ck_tile::e8m0_t; // MX scales: shared exponent per block
using KScaleDataType = ck_tile::e8m0_t;
using VScaleDataType = ck_tile::e8m0_t;
using PScaleDataType = ck_tile::e8m0_t;
static constexpr ck_tile::index_t kQKScaleGranularity = 32; // elements of hdim per Q/K scale
static constexpr ck_tile::index_t kVScaleGranularity = 32;  // elements of seqlen per V scale
};
struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
@@ -232,6 +288,7 @@ struct fmha_fwd_args
// array [batch + 1]. (Used with padding)
const void* block_scale_seqstart_q_ptr;
const void* block_scale_seqstart_k_ptr;
const void* seqstart_v_scale_ptr;
const void* sink_ptr;
ck_tile::index_t seqlen_q;
@@ -252,6 +309,9 @@ struct fmha_fwd_args
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_randval;
ck_tile::index_t stride_o;
ck_tile::index_t stride_q_descale;
ck_tile::index_t stride_k_descale;
ck_tile::index_t stride_v_descale;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
@@ -635,6 +695,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.seqlen_k_ptr,
args.block_scale_seqstart_q_ptr,
args.block_scale_seqstart_k_ptr,
args.seqstart_v_scale_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
@@ -647,6 +708,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_o,
args.stride_q_descale,
args.stride_k_descale,
args.stride_v_descale,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
@@ -697,6 +761,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_o,
args.stride_q_descale,
args.stride_k_descale,
args.stride_v_descale,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,

View File

@@ -84,6 +84,22 @@ auto get_elimit<FmhaFwdFp8Fp32>(std::string /*init_method*/)
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<FmhaFwdMxFp8>(std::string /*init_method*/)
{
    // Error tolerances for the mxfp8 path; looser than fp16/bf16 because
    // both operands and P are quantized to fp8 with block scales.
    constexpr double rel_tol = 1e-2;
    constexpr double abs_tol = 1.8e-1;
    return ck_tile::make_tuple(rel_tol, abs_tol);
}
template <>
auto get_elimit<FmhaFwdMxFp4>(std::string /*init_method*/)
{
    // Error tolerances for the mxfp4 path; wider than mxfp8 since fp4
    // carries far fewer mantissa bits.
    constexpr double rel_tol = 1e-1;
    constexpr double abs_tol = 2.6e-1;
    return ck_tile::make_tuple(rel_tol, abs_tol);
}
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits)
{
// If we have enough to almost fill the SMs, then just use 1 split
@@ -171,6 +187,28 @@ void copy_attention_scores_with_sink(const ck_tile::HostTensor<SMPLComputeDataTy
}
}
// Scale-type selector for the host test driver. The primary template is the
// non-MX fallback: scales are plain fp32 applied per element (granularity 1),
// independent of what TypeConfig declares.
template <typename TypeConfig, bool IsMx>
struct ScalesConfig
{
using QScaleDataType = float;
using KScaleDataType = float;
using VScaleDataType = float;
static constexpr ck_tile::index_t kQKScaleGranularity = 1;
static constexpr ck_tile::index_t kVScaleGranularity = 1;
};
// MX specialization: forward the scale types (e8m0 in the MX configs above)
// and block granularities declared by the data-type config itself.
template <typename TypeConfig>
struct ScalesConfig<TypeConfig, true>
{
using QScaleDataType = typename TypeConfig::QScaleDataType;
using KScaleDataType = typename TypeConfig::KScaleDataType;
using VScaleDataType = typename TypeConfig::VScaleDataType;
static constexpr ck_tile::index_t kQKScaleGranularity = TypeConfig::kQKScaleGranularity;
static constexpr ck_tile::index_t kVScaleGranularity = TypeConfig::kVScaleGranularity;
};
template <typename DataTypeConfig>
fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::index_t batch,
@@ -210,6 +248,31 @@ fwd_result fmha_fwd_run(mode_enum mode,
const ck_tile::stream_config& stream_config,
std::optional<std::string> json = std::nullopt)
{
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
constexpr bool is_mx = ck_tile::is_any_of<DataTypeConfig, FmhaFwdMxFp8, FmhaFwdMxFp4>::value;
using QDataType = typename TypeConfig::QDataType;
using KDataType = typename TypeConfig::KDataType;
using VDataType = typename TypeConfig::VDataType;
using BiasDataType = typename TypeConfig::BiasDataType;
using RandValOutputDataType = typename TypeConfig::RandValOutputDataType;
using LSEDataType = typename TypeConfig::LSEDataType;
using SaccDataType = typename TypeConfig::SaccDataType;
using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType;
using PDataType = std::conditional_t<is_mx, float, typename TypeConfig::PDataType>;
using OaccDataType = typename TypeConfig::OaccDataType;
using ODataType = typename TypeConfig::ODataType;
using QScaleDataType = typename ScalesConfig<TypeConfig, is_mx>::QScaleDataType;
using KScaleDataType = typename ScalesConfig<TypeConfig, is_mx>::KScaleDataType;
using VScaleDataType = typename ScalesConfig<TypeConfig, is_mx>::VScaleDataType;
constexpr ck_tile::index_t kQKScaleGranularity =
ScalesConfig<TypeConfig, is_mx>::kQKScaleGranularity;
constexpr ck_tile::index_t kVScaleGranularity =
ScalesConfig<TypeConfig, is_mx>::kVScaleGranularity;
// Note: block_scale_size_q_ and block_scale_size_kv_ should be greater than or equal to the
// compute block size
constexpr ck_tile::index_t block_scale_size_q_ = 128;
@@ -230,6 +293,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
return "fp8bf16";
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Fp32>)
return "fp8fp32";
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdMxFp8>)
return "mxfp8";
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdMxFp4>)
return "mxfp4";
else
static_assert(false);
}();
@@ -242,6 +309,26 @@ fwd_result fmha_fwd_run(mode_enum mode,
return fwd_result::invalid_args;
}
if(hdim_q % ck_tile::numeric_traits<QDataType>::PackedSize != 0)
{
std::cerr << "hdim_q is made even for fp4 Q data type" << std::endl;
hdim_q =
ck_tile::integer_least_multiple(hdim_q, ck_tile::numeric_traits<QDataType>::PackedSize);
}
if(hdim_q % ck_tile::numeric_traits<KDataType>::PackedSize != 0)
{
std::cerr << "hdim_q is made even for fp4 K data type" << std::endl;
hdim_q =
ck_tile::integer_least_multiple(hdim_q, ck_tile::numeric_traits<KDataType>::PackedSize);
}
if(is_mx && !seqlen_kpads.empty() && seqlen_kpads[0] > 0)
{
std::cerr
<< "seqlen_kpads is not supported with MX types. ignoring the 'seqlen_kpads' option"
<< std::endl;
seqlen_kpads = {-1};
}
std::mt19937 random_engine(seed != 0 ? seed : std::random_device{}());
auto next_seed = [&random_engine]() { return static_cast<unsigned int>(random_engine()); };
@@ -371,6 +458,20 @@ fwd_result fmha_fwd_run(mode_enum mode,
/*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
need_append_kvcache,
random_engine);
if(ck_tile::numeric_traits<VDataType>::PackedSize != 0)
{
// Ensure that all seqlens are even if V has packed data type
for(auto& s : seqlen_ks)
{
s = ck_tile::integer_least_multiple(s, ck_tile::numeric_traits<VDataType>::PackedSize);
}
for(auto& s : kv_eff_lens_per_batch)
{
s = ck_tile::integer_least_multiple(s, ck_tile::numeric_traits<VDataType>::PackedSize);
}
}
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
if(seqlen_kpads[wb] > 0 && seqlen_kpads[wb] < seqlen_ks[wb])
@@ -410,6 +511,22 @@ fwd_result fmha_fwd_run(mode_enum mode,
quant_scale_info qscale = quant_scale_info::decode(qscale_str);
if(is_mx && qscale.type != quant_scale_enum::mx)
{
std::cerr << "The value of qscale_str must be 'mx' for MX data types" << std::endl;
return fwd_result::invalid_args;
}
else if(!is_mx && qscale.type == quant_scale_enum::mx)
{
std::cerr << "The value of qscale_str cannot be 'mx' for non-MX data types" << std::endl;
return fwd_result::invalid_args;
}
if(is_mx && is_v_rowmajor)
{
std::cerr << "The value of is_v_rowmajor must be 'false' for MX data types" << std::endl;
return fwd_result::invalid_args;
}
if(p_drop < 0.0f || p_drop > 1.0f)
{
std::cerr << "The value of p_drop should be 0~1" << std::endl;
@@ -458,28 +575,16 @@ fwd_result fmha_fwd_run(mode_enum mode,
calculate_cumulative(kv_eff_lens_per_batch, cukv_cum);
}
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
using QDataType = typename TypeConfig::QDataType;
using KDataType = typename TypeConfig::KDataType;
using VDataType = typename TypeConfig::VDataType;
using BiasDataType = typename TypeConfig::BiasDataType;
using RandValOutputDataType = typename TypeConfig::RandValOutputDataType;
using LSEDataType = typename TypeConfig::LSEDataType;
using SaccDataType = typename TypeConfig::SaccDataType;
using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType;
using PDataType = typename TypeConfig::PDataType;
using OaccDataType = typename TypeConfig::OaccDataType;
using ODataType = typename TypeConfig::ODataType;
// accumulation numbers for performance evaluation
std::size_t flop = 0, num_byte = 0;
auto max_seqlen_q =
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
size_t i_block_scale_q = 0;
size_t i_block_scale_k = 0;
int32_t i_block_scale_q = 0;
int32_t i_block_scale_k = 0;
int32_t i_seqstart_v_scale = 0;
std::vector<int32_t> block_scale_seqstart_q_host = {0};
std::vector<int32_t> block_scale_seqstart_k_host = {0};
std::vector<int32_t> seqstart_v_scale_host = {0};
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
{
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
@@ -496,18 +601,31 @@ fwd_result fmha_fwd_run(mode_enum mode,
{
max_seqlen_k = real_seqlen_k;
}
i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_);
i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_);
block_scale_seqstart_q_host.push_back(i_block_scale_q);
block_scale_seqstart_k_host.push_back(i_block_scale_k);
if(qscale.type == quant_scale_enum::blockscale)
{
i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_);
i_block_scale_k +=
ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_);
block_scale_seqstart_q_host.push_back(i_block_scale_q);
block_scale_seqstart_k_host.push_back(i_block_scale_k);
}
else if(qscale.type == quant_scale_enum::mx)
{
i_seqstart_v_scale +=
ck_tile::integer_divide_ceil(real_seqlen_k, kVScaleGranularity);
seqstart_v_scale_host.push_back(i_seqstart_v_scale);
}
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q /
ck_tile::numeric_traits<QDataType>::PackedSize +
sizeof(ODataType) * real_seqlen_q * hdim_v);
num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q +
sizeof(VDataType) * hdim_v * real_seqlen_k);
num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q /
ck_tile::numeric_traits<KDataType>::PackedSize +
sizeof(VDataType) * hdim_v * real_seqlen_k /
ck_tile::numeric_traits<VDataType>::PackedSize);
}
}
@@ -620,19 +738,30 @@ fwd_result fmha_fwd_run(mode_enum mode,
hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// TODO - change the tensor length for different quant scale
ck_tile::HostTensor<float> q_descale_host(
qscale.type == quant_scale_enum::blockscale
? std::array<ck_tile::index_t, 3>{shape_batch, nhead, num_block_scale_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1});
ck_tile::HostTensor<float> k_descale_host(
qscale.type == quant_scale_enum::blockscale
? std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv}
: std::array<ck_tile::index_t, 3>{1, 1, 1});
ck_tile::HostTensor<float> v_descale_host(
qscale.type == quant_scale_enum::blockscale
? std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv}
: std::array<ck_tile::index_t, 3>{1, 1, 1});
const ck_tile::index_t hdim_q_scale = ck_tile::integer_divide_ceil(hdim_q, kQKScaleGranularity);
const ck_tile::index_t shape_seqlen_v_scale = seqstart_v_scale_host.back();
ck_tile::HostTensor<QScaleDataType> q_descale_host({1});
ck_tile::HostTensor<KScaleDataType> k_descale_host({1});
ck_tile::HostTensor<VScaleDataType> v_descale_host({1});
if constexpr(is_mx)
{
q_descale_host = ck_tile::HostTensor<QScaleDataType>(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q_scale));
k_descale_host = ck_tile::HostTensor<KScaleDataType>(
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q_scale));
v_descale_host = ck_tile::HostTensor<VScaleDataType>(
get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_v_scale));
}
else if(qscale.type == quant_scale_enum::blockscale)
{
q_descale_host = ck_tile::HostTensor<QScaleDataType>(
std::array<ck_tile::index_t, 3>{shape_batch, nhead, num_block_scale_q});
k_descale_host = ck_tile::HostTensor<KScaleDataType>(
std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv});
v_descale_host = ck_tile::HostTensor<VScaleDataType>(
std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv});
}
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
@@ -704,10 +833,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
else if(init_method == "3")
{
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
float bias_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<BiasDataType>::max());
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, next_seed()}(q_host);
ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, next_seed()}(k_host);
@@ -716,8 +844,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, next_seed()}(v_host);
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, next_seed()}(
vnew_host);
ck_tile::FillUniformDistribution<BiasDataType>{
-bias_dtype_max, bias_dtype_max, next_seed()}(bias_host);
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, next_seed()}(bias_host);
}
if(bias.type == bias_enum::alibi)
{
@@ -737,7 +864,41 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
}
}
if(qscale.type == quant_scale_enum::pertensor)
if constexpr(is_mx)
{
auto gen_scales = [&](auto& scales, auto data, float range) {
using DataType = decltype(data);
using ScaleType = ck_tile::remove_cvref_t<decltype(*scales.begin())>;
if constexpr(std::is_same_v<ScaleType, ck_tile::e8m0_t>)
{
const float base =
-std::log2(ck_tile::type_convert<float>(ck_tile::numeric<DataType>::max()));
// e8m0_t is basically an exponent of float32
// When scales are applied to tensor values, value * exp2(base - range) is around
// 0.125 and value * exp2(base + range) is around 8 for all types (fp8/bf8/fp4)
ck_tile::HostTensor<float> pow2(scales.get_lengths());
ck_tile::FillUniformDistributionIntegerValue<float>{
base - range, base + range, next_seed()}(pow2);
scales.ForEach([&](auto& self, const auto& i) {
self(i) = ck_tile::type_convert<ScaleType>(std::exp2(pow2(i)));
});
}
else
{
static_assert(false);
}
};
gen_scales(q_descale_host, QDataType{}, 3);
gen_scales(k_descale_host, KDataType{}, 3);
// When P is fp4, only 8 values (0, 0.5, 1, 1.5, 2, 3, 4, 6) are used to quantize P.
// Too large V values can create rare error outliers between host (no quantization) and
// device ("running" FA softmax + quantization), here we reduce max value by using smaller
// range of V scales.
gen_scales(v_descale_host,
VDataType{},
std::is_same_v<typename TypeConfig::PDataType, ck_tile::pk_fp4_t> ? 1 : 3);
}
else if(qscale.type == quant_scale_enum::pertensor)
{
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
@@ -790,12 +951,13 @@ fwd_result fmha_fwd_run(mode_enum mode,
sizeof(int32_t));
ck_tile::DeviceMem block_scale_seqstart_k_buf(block_scale_seqstart_k_host.size() *
sizeof(int32_t));
ck_tile::DeviceMem scale_seqstart_v_buf(seqstart_v_scale_host.size() * sizeof(int32_t));
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_q_buf(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k_buf(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty()
? 0
: seqstart_q_with_padding_host.size() *
@@ -837,9 +999,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
v_descale_buf.ToDevice(v_descale_host.data());
block_scale_seqstart_q_buf.ToDevice(block_scale_seqstart_q_host.data());
block_scale_seqstart_k_buf.ToDevice(block_scale_seqstart_k_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
// Keep logical starts in seqstart_k; pass padded K via separate pointer
seqstart_k.ToDevice(seqstart_k_host.data());
scale_seqstart_v_buf.ToDevice(seqstart_v_scale_host.data());
seqstart_q_buf.ToDevice(seqstart_q_host.data());
// Keep logical starts in seqstart_k_buf; pass padded K via separate pointer
seqstart_k_buf.ToDevice(seqstart_k_host.data());
seqstart_q_padded_buf.ToDevice(
seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data());
seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr
@@ -1145,7 +1308,13 @@ fwd_result fmha_fwd_run(mode_enum mode,
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
{
if(qscale.type == quant_scale_enum::blockscale)
if(qscale.type == quant_scale_enum::pertensor)
{
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
}
else if(qscale.type == quant_scale_enum::blockscale)
{
args.q_descale_ptr =
reinterpret_cast<const float*>(q_descale_buf.GetDeviceBuffer());
@@ -1172,11 +1341,32 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.block_scale_size_q = block_scale_size_q_;
args.block_scale_size_kv = block_scale_size_kv_;
}
else
else if(qscale.type == quant_scale_enum::mx)
{
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
args.stride_q_descale = (i_perm ? hdim_q_scale : nhead * hdim_q_scale);
args.stride_k_descale = (i_perm ? hdim_q_scale : nhead_k * hdim_q_scale);
args.stride_v_descale =
(i_perm ? shape_seqlen_v_scale : nhead_k * shape_seqlen_v_scale);
args.nhead_stride_q_descale =
(i_perm ? shape_seqlen_q * hdim_q_scale : hdim_q_scale);
args.nhead_stride_k_descale =
(i_perm ? shape_seqlen_k * hdim_q_scale : hdim_q_scale);
args.nhead_stride_v_descale =
(i_perm ? hdim_v * shape_seqlen_v_scale : shape_seqlen_v_scale);
if(mode == mode_enum::group)
{
args.seqstart_v_scale_ptr = scale_seqstart_v_buf.GetDeviceBuffer();
}
else
{
args.batch_stride_q_descale = (nhead * shape_seqlen_q * hdim_q_scale);
args.batch_stride_k_descale = (nhead_k * shape_seqlen_k * hdim_q_scale);
args.batch_stride_v_descale = (nhead_k * hdim_v * shape_seqlen_v_scale);
}
}
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
@@ -1207,11 +1397,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.seqstart_q_ptr =
has_group_q_padding && !seqstart_q_with_padding_host.empty()
? seqstart_q_padded_buf.GetDeviceBuffer()
: seqstart_q.GetDeviceBuffer();
: seqstart_q_buf.GetDeviceBuffer();
args.seqstart_k_ptr =
has_group_k_padding && !seqstart_k_with_padding_host.empty()
? seqstart_k_padded_buf.GetDeviceBuffer()
: seqstart_k.GetDeviceBuffer();
: seqstart_k_buf.GetDeviceBuffer();
// Logical (unpadded) per-sequence lengths, used when padding is enabled
args.seqlen_q_ptr =
@@ -1272,9 +1462,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.split_stride_o_acc = split_stride_o_acc;
args.seqstart_q_ptr =
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
(mode == mode_enum::group ? seqstart_q_buf.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
(mode == mode_enum::group ? seqstart_k_buf.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr =
((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
@@ -1292,9 +1482,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
args.seqstart_q_ptr =
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
(mode == mode_enum::group ? seqstart_q_buf.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
(mode == mode_enum::group ? seqstart_k_buf.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr =
((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
@@ -1462,11 +1652,14 @@ fwd_result fmha_fwd_run(mode_enum mode,
float scale_p_host = 1.0f;
float scale_o_host = 1.0f;
if(qscale.type == quant_scale_enum::pertensor)
if constexpr(!is_mx)
{
scale_s_host = scale_s * q_descale_host(0) * k_descale_host(0);
scale_p_host = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
scale_o_host = v_descale_host(0) / scale_p_host;
if(qscale.type == quant_scale_enum::pertensor)
{
scale_s_host = scale_s * q_descale_host(0) * k_descale_host(0);
scale_p_host = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
scale_o_host = v_descale_host(0) / scale_p_host;
}
}
auto p_compute_element_func = [&]() {
@@ -1680,7 +1873,45 @@ fwd_result fmha_fwd_run(mode_enum mode,
#endif
// reference
if(qscale.type == quant_scale_enum::blockscale)
if constexpr(is_mx)
{
ck_tile::HostTensor<QScaleDataType> q_descale_host_ref(
{nhead, real_seqlen_q, hdim_q_scale});
ck_tile::HostTensor<KScaleDataType> k_descale_host_ref(
{nhead, real_seqlen_k, hdim_q_scale});
// clang-format off
if(i_perm) q_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_descale_host(b_idx, i[0], i[1] + query_offset, i[2]); });
else q_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_descale_host(b_idx, i[1] + query_offset, i[0], i[2]); });
// clang-format on
// clang-format off
if(i_perm) k_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_descale_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
else k_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_descale_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
// clang-format on
auto q_host_ref2 = ck_tile::reference_batched_mx_descale<QDataType,
QScaleDataType,
SaccDataType,
SaccDataType>(
q_host_ref, q_descale_host_ref, kQKScaleGranularity);
auto k_host_ref2 = ck_tile::reference_batched_mx_descale<KDataType,
KScaleDataType,
SaccDataType,
SaccDataType>(
k_host_ref, k_descale_host_ref, kQKScaleGranularity);
ck_tile::reference_batched_gemm<SaccDataType,
SaccDataType,
SaccDataType,
SMPLComputeDataType>(q_host_ref2,
k_host_ref2,
s_host_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales(scale_s_host));
}
else if(qscale.type == quant_scale_enum::blockscale)
{
const ck_tile::index_t q_offset =
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb];
@@ -1881,6 +2112,32 @@ fwd_result fmha_fwd_run(mode_enum mode,
s_host_ref, p_host_ref, p_compute_element_func);
}
}
if(lse)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
lse_host_result.ForEach([&](auto& self, auto idx) {
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
});
// Use smaller rtol/atol as LSE is computed and stored in fp32, so there is no
// precision loss due to conversion
bool cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,
"LSE Error: Incorrect results!",
1e-4,
1e-4,
/* allow_infinity_ref = */ true);
pass &= cur_pass;
if(!cur_pass)
{
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
<< "\tseqlen_q: " << real_seqlen_q << std::endl
<< "\tseqlen_k: " << real_seqlen_k << std::endl
<< "\tseqstart_q: " << seqstart_q_host << std::endl
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
}
}
if(p_drop > 0)
{
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
@@ -1907,13 +2164,43 @@ fwd_result fmha_fwd_run(mode_enum mode,
randval_host_ref,
"DROPOUT RANDVAL Error: Incorrect results!");
pass &= cur_pass;
if(!cur_pass)
{
break;
}
}
if(qscale.type == quant_scale_enum::blockscale)
if constexpr(is_mx)
{
const ck_tile::index_t real_seqlen_v_scale =
seqstart_v_scale_host[wb + 1] - seqstart_v_scale_host[wb];
const ck_tile::index_t v_scale_offset =
mode == mode_enum::batch ? 0 : seqstart_v_scale_host[wb];
ck_tile::HostTensor<VScaleDataType> v_descale_host_ref(
{nhead, hdim_v, real_seqlen_v_scale});
// clang-format off
if(i_perm) v_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_descale_host(cache_b_idx, i[0] / nr, i[1], i[2] + v_scale_offset); });
else v_descale_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_descale_host(cache_b_idx, i[1], i[0] / nr, i[2] + v_scale_offset); });
// clang-format on
auto v_host_ref2 = ck_tile::reference_batched_mx_descale<VDataType,
VScaleDataType,
OaccDataType,
OaccDataType>(
v_host_ref, v_descale_host_ref, kVScaleGranularity);
// P is not quantized and then dequantized here (PDataType = float).
// On host softmax is computed for the whole row of S, while on device FA computes
// softmax and quantizes it in blocks of N0 values. Quantization on host would make
// reference results **less** precise than the device results for large seqlen_k!
ck_tile::reference_batched_gemm<PDataType, OaccDataType, OaccDataType, ODataType>(
p_host_ref,
v_host_ref2,
o_host_ref,
ck_tile::identity{},
ck_tile::identity{},
oacc_element_func);
}
else if(qscale.type == quant_scale_enum::blockscale)
{
const ck_tile::index_t v_offset =
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb];
@@ -1969,35 +2256,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
<< std::endl
<< "\tquery_offset used: " << query_offset << std::endl
<< "\tkey_offset used: " << key_offset << std::endl;
break;
}
if(lse)
if(!pass)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
lse_host_result.ForEach([&](auto& self, auto idx) {
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
});
cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,
"LSE Error: Incorrect results!",
rtol,
atol,
/* allow_infinity_ref = */ true);
pass &= cur_pass;
if(!cur_pass)
{
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
<< "\tseqlen_q: " << real_seqlen_q << std::endl
<< "\tseqlen_k: " << real_seqlen_k << std::endl
<< "\tseqstart_q: " << seqstart_q_host << std::endl
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
break;
}
break;
}
}
@@ -2006,33 +2269,37 @@ fwd_result fmha_fwd_run(mode_enum mode,
if(json)
{
dump_fmha_fwd_json_results(
*json,
data_type,
mode == mode_enum::batch ? "batch" : "group",
io_layout(i_perm, o_perm),
batch,
nhead,
nhead_k,
seqlen_qs[0],
seqlen_ks[0],
seqlen_kpads[0],
hdim_q,
hdim_v,
scale_s,
p_drop,
lse,
qscale.type == quant_scale_enum::no_scale
? "no_scale"
: (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"),
bias.type == bias_enum::elementwise_bias
? "elementwise_bias"
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
is_v_rowmajor ? "r" : "c",
pass,
ave_time,
tflops,
gb_per_sec);
const std::string qscale_name =
(qscale.type == quant_scale_enum::no_scale ? "no_scale"
: qscale.type == quant_scale_enum::pertensor ? "pertensor"
: qscale.type == quant_scale_enum::blockscale ? "blockscale"
: qscale.type == quant_scale_enum::kv_blockscale ? "kv_blockscale"
: qscale.type == quant_scale_enum::mx ? "mx"
: "unknown");
dump_fmha_fwd_json_results(*json,
data_type,
mode == mode_enum::batch ? "batch" : "group",
io_layout(i_perm, o_perm),
batch,
nhead,
nhead_k,
seqlen_qs[0],
seqlen_ks[0],
seqlen_kpads[0],
hdim_q,
hdim_v,
scale_s,
p_drop,
lse,
qscale_name,
bias.type == bias_enum::elementwise_bias
? "elementwise_bias"
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
is_v_rowmajor ? "r" : "c",
pass,
ave_time,
tflops,
gb_per_sec);
}
return pass ? fwd_result::success : fwd_result::failure;

View File

@@ -18,6 +18,7 @@ enum class quant_scale_enum
pertensor = 1,
blockscale = 2,
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
mx = 4, // Microscaling (MX)
};
struct quant_scale_info
@@ -34,6 +35,8 @@ struct quant_scale_info
os << "bs";
else if(type == quant_scale_enum::kv_blockscale)
os << "kvbs";
else if(type == quant_scale_enum::mx)
os << "mx";
}
static quant_scale_info decode(std::string str)
@@ -55,6 +58,10 @@ struct quant_scale_info
{
info.type = quant_scale_enum::kv_blockscale;
}
else if(str == "mx" || str == "4")
{
info.type = quant_scale_enum::mx;
}
else
{
throw std::invalid_argument("invalid quant scale value: " + str);