[rocm-libraries] ROCm/rocm-libraries#4368 (commit 17f7dfc)

[CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on
 gfx950 (#4368)

## Motivation

Microscaling types (mxfp8 and mxfp4) for fwd qr pipeline

## Technical Details

Microscaling is used when the quant scale mode is
`BlockAttentionQuantScaleEnum::MX` and `Q/K/P/VDataType` are
fp8/bf8/fp4.

Supported features:
* only "qr" pipeline is implemented
* hdim 128 and 256 (smaller hdims are not possible due to restrictions of
the "qr" pipeline, but they can be computed using instances with padding)
 * both 32x32x64 and 16x16x128 scale MFMAs are supported
 * Q and K scales are applied along the hdim dimension; V scales are applied along the seqlen dimension
 * column-major V only
 * batch and group mode
 * bias, Alibi (tested but no instances by default, just like fp8)
 * masking etc.

Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008

## Test Plan

```
ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8
ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4
```

## Test Result

The tests must pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Anton Gorenko
2026-03-11 10:00:52 +00:00
committed by assistant-librarian[bot]
parent c85c272c39
commit 2312eef6c3
29 changed files with 2167 additions and 356 deletions

View File

@@ -7,15 +7,17 @@ set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
set(TEST_NAME "test_ck_tile_fmha")
function(add_gtest_fwd test_group)
if((GPU_TARGETS MATCHES "gfx90a" AND CK_USE_FP8_ON_UNSUPPORTED_ARCH) OR GPU_TARGETS MATCHES "gfx9[45]|gfx12")
set(V_TYPES "fp16" "bf16" "fp8bf16" "fp32")
elseif((GPU_TARGETS MATCHES "gfx90a" AND NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) OR GPU_TARGETS MATCHES "gfx11")
set(V_TYPES "fp16" "bf16" "fp32")
set(V_TYPES "fp16" "bf16" "fp8bf16" "fp32" "mxfp8" "mxfp4")
if(GPU_TARGETS MATCHES "gfx908|gfx90a" AND NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
# fp8 instances are built for all gfx9, do not test on archs without hardware support
list(REMOVE_ITEM V_TYPES "fp8bf16")
endif()
set(CPP_TYPE_fp16 "FmhaFwdFp16")
set(CPP_TYPE_bf16 "FmhaFwdBf16")
set(CPP_TYPE_fp8bf16 "FmhaFwdFp8Bf16")
set(CPP_TYPE_fp32 "FmhaFwdFp32")
set(CPP_TYPE_mxfp8 "FmhaFwdMxFp8")
set(CPP_TYPE_mxfp4 "FmhaFwdMxFp4")
set(sources)
if(TARGET ${FMHA_FWD_INSTANCES})

View File

@@ -1,6 +1,11 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <algorithm>
#include <array>
#include <cmath>
#include <vector>
#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"
@@ -42,6 +47,7 @@ struct TestConfigs
static constexpr auto qscale_str = "n";
static constexpr bool def_lse = true;
static constexpr bool def_is_v_rowmajor = true;
static constexpr auto init_method = "uf";
static int adjust_seqlen(int seqlen) { return seqlen; }
};
@@ -57,11 +63,45 @@ struct TestConfigs<FmhaFwdFp8Bf16>
static constexpr auto qscale_str = "pt";
static constexpr bool def_lse = false;
static constexpr bool def_is_v_rowmajor = true;
static constexpr auto init_method = "3";
// When there are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests:
// return ck_tile::integer_least_multiple(seqlen, 128);
static int adjust_seqlen(int seqlen) { return seqlen; }
};
template <>
struct TestConfigs<FmhaFwdMxFp8>
{
static constexpr auto HDimValues = std::array{std::tuple{128, -1}, std::tuple{256, -1}};
static constexpr auto SplitKVHDimValues = std::array<std::tuple<int, int>, 0>{};
static constexpr auto AppendKVHDimValues = std::array<std::tuple<int, int>, 0>{};
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
static constexpr auto IsVRowmajorValues = std::array{false};
static constexpr auto qscale_str = "mx";
static constexpr bool def_lse = true;
static constexpr bool def_is_v_rowmajor = false;
static constexpr auto init_method = "3";
static int adjust_seqlen(int seqlen) { return seqlen; }
};
template <>
struct TestConfigs<FmhaFwdMxFp4>
{
static constexpr auto HDimValues = std::array{std::tuple{128, -1}, std::tuple{256, -1}};
static constexpr auto SplitKVHDimValues = std::array<std::tuple<int, int>, 0>{};
static constexpr auto AppendKVHDimValues = std::array<std::tuple<int, int>, 0>{};
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
static constexpr auto IsVRowmajorValues = std::array{false};
static constexpr auto qscale_str = "mx";
static constexpr bool def_lse = true;
static constexpr bool def_is_v_rowmajor = false;
static constexpr auto init_method = "3";
static int adjust_seqlen(int seqlen)
{
return seqlen < 0 ? seqlen : ck_tile::integer_least_multiple(seqlen, 2);
}
};
template <>
struct TestConfigs<FmhaFwdFp32>
{
@@ -81,6 +121,7 @@ struct TestConfigs<FmhaFwdFp32>
static constexpr auto qscale_str = "n";
static constexpr bool def_lse = true;
static constexpr bool def_is_v_rowmajor = true;
static constexpr auto init_method = "uf";
static int adjust_seqlen(int seqlen) { return seqlen; }
};
@@ -92,8 +133,8 @@ static auto IsVRowmajorValues = ValuesIn(TestConfigs<DataTypeConfig>::IsVRowm
constexpr static auto qscale_str = TestConfigs<DataTypeConfig>::qscale_str;
constexpr bool def_lse = TestConfigs<DataTypeConfig>::def_lse;
constexpr bool def_is_v_rowmajor = TestConfigs<DataTypeConfig>::def_is_v_rowmajor;
constexpr auto init_method = TestConfigs<DataTypeConfig>::init_method;
int adjust_seqlen(int seqlen) { return TestConfigs<DataTypeConfig>::adjust_seqlen(seqlen); }
constexpr auto init_method = "uf";
// Random seed used for initializing input tensors. 0 for non-deterministic seed
CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)
@@ -901,12 +942,6 @@ using PaddingParam = std::tuple<mode_enum, // mode
bool, // o_perm
std::string>; // mask_str
// Ensure headers for containers / algorithms used in padding param builder.
#include <vector>
#include <array>
#include <cmath>
#include <algorithm>
class PaddingCases : public TestWithParam<PaddingParam>
{
};
@@ -918,6 +953,12 @@ static std::vector<PaddingParam> BuildPaddingParams()
{
std::vector<PaddingParam> params;
if constexpr(ck_tile::is_any_of<DataTypeConfig, FmhaFwdFp8Bf16, FmhaFwdMxFp8, FmhaFwdMxFp4>::
value)
{
return params;
}
// mask variants to cover
const std::vector<std::string> mask_variants{"0", "t:50,64", "b:32,40"};
const std::vector<std::string> mask_variants_reduced{"0", "t:50,64"}; // used for trimmed sets
@@ -1106,15 +1147,10 @@ static std::vector<PaddingParam> BuildPaddingParams()
static const std::vector<PaddingParam> kPaddingParams = BuildPaddingParams();
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd_Padding, PaddingCases, ValuesIn(kPaddingParams));
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, PaddingCases, ValuesIn(kPaddingParams));
TEST_P(PaddingCases, DataTypeConfig)
{
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>)
{
GTEST_SKIP() << "Skip for fp8";
}
auto [mode,
batch,
nhead,