composable_kernel/example/ck_tile/01_fmha/quant.hpp
Anton Gorenko 2312eef6c3 [rocm-libraries] ROCm/rocm-libraries#4368 (commit 17f7dfc)
[CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on gfx950 (#4368)

## Motivation

Add microscaling types (mxfp8 and mxfp4) to the forward "qr" pipeline.

## Technical Details

Microscaling is used when the quant scale mode is
`BlockAttentionQuantScaleEnum::MX` and `Q/K/P/VDataType` are
fp8/bf8/fp4.

Supported features:
* only the "qr" pipeline is implemented
* hdim 128 and 256 (smaller hdims are not supported natively because of
  restrictions of the "qr" pipeline, but they can be computed using
  instances with padding)
* both 32x32x64 and 16x16x128 scale MFMAs are supported
* Q and K scales are applied along the hdim dimension, V scales along the
  seqlen dimension
* column-major V only
* batch and group modes
* bias and ALiBi (tested, but no instances are generated by default, same
  as for fp8)
* masking, etc.
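To make the scale placement concrete, here is a scalar reference sketch of one MX-scaled dot product. It assumes the OCP Microscaling (MX) v1.0 conventions, namely 32-element scale blocks sharing one E8M0 (power-of-two) scale; neither detail is stated in this description. The scale blocks run along the reduction dimension of each GEMM: hdim for S = Q*K^T and seqlen for O = P*V, which is our reading of the bullet about where the Q, K and V scales are applied. `decode_e8m0`, `mx_dot` and `example_qk_score` are hypothetical names, not CK Tile APIs.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumption (OCP MX v1.0): 32 consecutive elements share one scale.
constexpr std::size_t kMxBlockSize = 32;

// Decode an E8M0 shared scale: an unsigned biased exponent, value = 2^(e - 127).
// The encoding 0xFF is reserved for NaN.
float decode_e8m0(std::uint8_t e)
{
    if(e == 0xFF)
        return std::nanf("");
    return std::ldexp(1.0f, static_cast<int>(e) - 127);
}

// Hypothetical reference helper, not part of CK Tile: dot product of two
// MX-scaled vectors along the reduction dimension (hdim for Q*K^T, seqlen for
// P*V). Elements are stored as already-decoded floats for readability; the
// kernel instead feeds packed fp8/bf8/fp4 data and scales to scale MFMAs.
float mx_dot(const std::vector<float>& a,
             const std::vector<std::uint8_t>& a_scales,
             const std::vector<float>& b,
             const std::vector<std::uint8_t>& b_scales)
{
    assert(a.size() == b.size() && a.size() % kMxBlockSize == 0);
    const std::size_t num_blocks = a.size() / kMxBlockSize;
    assert(a_scales.size() == num_blocks && b_scales.size() == num_blocks);

    float acc = 0.0f;
    for(std::size_t blk = 0; blk < num_blocks; ++blk)
    {
        // Accumulate the raw element products within one scale block...
        float partial = 0.0f;
        for(std::size_t i = blk * kMxBlockSize; i < (blk + 1) * kMxBlockSize; ++i)
            partial += a[i] * b[i];
        // ...then apply both operands' block scales once per block.
        acc += partial * decode_e8m0(a_scales[blk]) * decode_e8m0(b_scales[blk]);
    }
    return acc;
}

// Example: one Q row and one K row with hdim = 64, i.e. two scale blocks each.
float example_qk_score()
{
    std::vector<float> q(64, 1.0f);
    std::vector<float> k(64, 0.5f);
    std::vector<std::uint8_t> q_scales{127, 128}; // 2^0, 2^1
    std::vector<std::uint8_t> k_scales{127, 126}; // 2^0, 2^-1
    return mx_dot(q, q_scales, k, k_scales);       // 16 * 1 + 16 * 1 = 32
}
```

Because every element in a block shares the same scale, the scale factors out of the block's partial sum, and since E8M0 scales are powers of two, applying them only shifts exponents. Under the 32-element assumption, each row with hdim 128 or 256 carries 4 or 8 scales, respectively.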

Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008

## Test Plan

```
ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8
ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4
```

## Test Result

The tests must pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-11 10:00:52 +00:00


// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <ostream>
#include <stdexcept> // std::invalid_argument, thrown by quant_scale_info::decode
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
// Keep in sync with BlockAttentionQuantScaleEnum.
enum class quant_scale_enum
{
    no_scale      = 0,
    pertensor     = 1,
    blockscale    = 2,
    kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
    mx            = 4, // Microscaling (MX)
};
struct quant_scale_info
{
    quant_scale_enum type;

    void serialize(std::ostream& os) const
    {
        if(type == quant_scale_enum::no_scale)
            os << "n";
        else if(type == quant_scale_enum::pertensor)
            os << "pt";
        else if(type == quant_scale_enum::blockscale)
            os << "bs";
        else if(type == quant_scale_enum::kv_blockscale)
            os << "kvbs";
        else if(type == quant_scale_enum::mx)
            os << "mx";
    }

    static quant_scale_info decode(std::string str)
    {
        quant_scale_info info{quant_scale_enum::no_scale};
        if(str == "n" || str == "0")
        {
            info.type = quant_scale_enum::no_scale;
        }
        else if(str == "pt" || str == "1")
        {
            info.type = quant_scale_enum::pertensor;
        }
        else if(str == "bs" || str == "2")
        {
            info.type = quant_scale_enum::blockscale;
        }
        else if(str == "kvbs" || str == "3")
        {
            info.type = quant_scale_enum::kv_blockscale;
        }
        else if(str == "mx" || str == "4")
        {
            info.type = quant_scale_enum::mx;
        }
        else
        {
            throw std::invalid_argument("invalid quant scale value: " + str);
        }
        return info;
    }

    friend std::ostream& operator<<(std::ostream& os, const quant_scale_info& qsi)
    {
        qsi.serialize(os);
        return os;
    }
};
#pragma clang diagnostic pop
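`serialize` and `decode` each spell out the same name-to-enum mapping, so adding a mode such as `mx` means editing both in lockstep. The sketch below is our own table-driven alternative, not CK Tile's code: both directions are derived from one array, and the numeric spellings ("0" to "4") come from the enum values themselves. `quant_scale_name`, `kQuantScaleNames`, `quant_scale_to_string` and `quant_scale_from_string` are hypothetical names, and the enum is redeclared so the snippet compiles without the ck_tile headers.

```cpp
#include <array>
#include <cassert>
#include <stdexcept>
#include <string>

// Standalone mirror of quant_scale_enum so this sketch compiles without
// ck_tile headers; the real enum must stay in sync with
// BlockAttentionQuantScaleEnum.
enum class quant_scale_enum
{
    no_scale      = 0,
    pertensor     = 1,
    blockscale    = 2,
    kv_blockscale = 3,
    mx            = 4,
};

// Hypothetical single source of truth for the short names used by
// quant_scale_info::serialize and quant_scale_info::decode.
struct quant_scale_name
{
    quant_scale_enum type;
    const char* short_name;
};

constexpr std::array<quant_scale_name, 5> kQuantScaleNames{{
    {quant_scale_enum::no_scale, "n"},
    {quant_scale_enum::pertensor, "pt"},
    {quant_scale_enum::blockscale, "bs"},
    {quant_scale_enum::kv_blockscale, "kvbs"},
    {quant_scale_enum::mx, "mx"},
}};

std::string quant_scale_to_string(quant_scale_enum type)
{
    for(const auto& entry : kQuantScaleNames)
        if(entry.type == type)
            return entry.short_name;
    throw std::invalid_argument("unknown quant_scale_enum value");
}

// Accepts either the short name ("mx") or the numeric enum value ("4"),
// matching the inputs accepted by quant_scale_info::decode.
quant_scale_enum quant_scale_from_string(const std::string& str)
{
    for(const auto& entry : kQuantScaleNames)
        if(str == entry.short_name || str == std::to_string(static_cast<int>(entry.type)))
            return entry.type;
    throw std::invalid_argument("invalid quant scale value: " + str);
}
```

One behavioral difference from the header: `quant_scale_to_string` throws for an out-of-range value, whereas `serialize` silently writes nothing.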