[rocm-libraries] ROCm/rocm-libraries#4368 (commit 17f7dfc)

[CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on
 gfx950 (#4368)

## Motivation

Add support for the microscaling types (mxfp8 and mxfp4) in the fwd "qr" pipeline.

## Technical Details

Microscaling is used when the quant scale mode is
`BlockAttentionQuantScaleEnum::MX` and the `Q/K/P/VDataType` types are
fp8/bf8/fp4.

Supported features:
* only the "qr" pipeline is implemented
* hdim 128 and 256 (smaller hdims are not possible due to restrictions of the
"qr" pipeline, but they can be computed using instances with padding)
* both 32x32x64 and 16x16x128 scale MFMAs are supported
* Q and K scales are applied along hdim; V scales along the seqlen dimension
* column-major V only
* batch and group mode
* bias, Alibi (tested, but no instances by default, just like fp8)
* masking etc.

Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008

## Test Plan

```
ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8
ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4
```

## Test Result

The tests must pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Author: Anton Gorenko
Date: 2026-03-11 10:00:52 +00:00
Committed by: assistant-librarian[bot]
Parent: c85c272c39
Commit: 2312eef6c3
29 changed files with 2167 additions and 356 deletions


@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <thread>

namespace ck_tile {

// Reference host implementation of MX descaling: multiplies each element of a
// [B, M, K] tensor by the shared block scale covering its K position (one
// scale per `scale_granularity` contiguous elements along K).
template <typename InDataType,
          typename ScaleDataType,
          typename OutDataType,
          typename ComputeDataType>
CK_TILE_HOST HostTensor<OutDataType>
reference_batched_mx_descale(const HostTensor<InDataType>& a_b_m_k,
                             const HostTensor<ScaleDataType>& scales_b_m_ks,
                             const std::size_t scale_granularity)
{
    const std::size_t B = a_b_m_k.get_length(0);
    const std::size_t M = a_b_m_k.get_length(1);
    const std::size_t K = a_b_m_k.get_length(2);

    HostTensor<OutDataType> a_b_m_k_scaled(a_b_m_k.get_lengths());

    auto f = [&](auto batch) {
        // fp4 values are stored two per byte, so K is traversed in packs.
        constexpr index_t packed_size = ck_tile::numeric_traits<InDataType>::PackedSize;
        for(std::size_t m = 0; m < M; ++m)
        {
            for(std::size_t k = 0; k < K; k += packed_size)
            {
                const auto scale = ck_tile::type_convert<ComputeDataType>(
                    scales_b_m_ks(batch, m, k / scale_granularity));
                if constexpr(std::is_same_v<InDataType, pk_fp4_t>)
                {
                    // Unpack the two fp4 values sharing one byte, then scale each.
                    auto a_f4x2  = a_b_m_k(batch, m, k);
                    auto a_f4_lo = ck_tile::type_convert<ComputeDataType>(
                        a_f4x2.template unpack<>(number<0>{}));
                    auto a_f4_hi = ck_tile::type_convert<ComputeDataType>(
                        a_f4x2.template unpack<>(number<1>{}));
                    a_b_m_k_scaled(batch, m, k) =
                        ck_tile::type_convert<OutDataType>(a_f4_lo * scale);
                    a_b_m_k_scaled(batch, m, k + 1) =
                        ck_tile::type_convert<OutDataType>(a_f4_hi * scale);
                }
                else
                {
                    a_b_m_k_scaled(batch, m, k) = ck_tile::type_convert<OutDataType>(
                        ck_tile::type_convert<ComputeDataType>(a_b_m_k(batch, m, k)) * scale);
                }
            }
        }
    };

    // One worker per batch, spread over the available hardware threads.
    make_ParallelTensorFunctor(f, B)(std::thread::hardware_concurrency());

    return a_b_m_k_scaled;
}

} // namespace ck_tile