mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4368 (commit 17f7dfc)
[CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on gfx950 (#4368) ## Motivation Microscaling types (mxfp8 and mxfp4) for fwd qr pipeline ## Technical Details The microscaling is used when quant scale mode is `BlockAttentionQuantScaleEnum::MX` and `Q/K/P/VDataType` are fp8/bf8/fp4. Supported features: * only "qr" pipeline is implemented * hdim 128 and 256 (smaller hdim are not possible due to restrictions of "qr" pipeline, but they can be computed using instances with padding) * both 32x32x64 and 16x16x128 scale MFMAs are supported * Q and K scales are applied in hdim, V scales - in seqlen dimension * column-major V only * batch and group mode * bias, Alibi (tested but no instances by default, just like fp8) * masking etc. Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008 ## Test Plan ``` ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8 ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4 ``` ## Test Result The tests must pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c85c272c39
commit
2312eef6c3
@@ -7,15 +7,17 @@ set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
|
||||
set(TEST_NAME "test_ck_tile_fmha")
|
||||
|
||||
function(add_gtest_fwd test_group)
|
||||
if((GPU_TARGETS MATCHES "gfx90a" AND CK_USE_FP8_ON_UNSUPPORTED_ARCH) OR GPU_TARGETS MATCHES "gfx9[45]|gfx12")
|
||||
set(V_TYPES "fp16" "bf16" "fp8bf16" "fp32")
|
||||
elseif((GPU_TARGETS MATCHES "gfx90a" AND NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) OR GPU_TARGETS MATCHES "gfx11")
|
||||
set(V_TYPES "fp16" "bf16" "fp32")
|
||||
set(V_TYPES "fp16" "bf16" "fp8bf16" "fp32" "mxfp8" "mxfp4")
|
||||
if(GPU_TARGETS MATCHES "gfx908|gfx90a" AND NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
# fp8 instances are built for all gfx9, do not test on archs without hardware support
|
||||
list(REMOVE_ITEM V_TYPES "fp8bf16")
|
||||
endif()
|
||||
set(CPP_TYPE_fp16 "FmhaFwdFp16")
|
||||
set(CPP_TYPE_bf16 "FmhaFwdBf16")
|
||||
set(CPP_TYPE_fp8bf16 "FmhaFwdFp8Bf16")
|
||||
set(CPP_TYPE_fp32 "FmhaFwdFp32")
|
||||
set(CPP_TYPE_mxfp8 "FmhaFwdMxFp8")
|
||||
set(CPP_TYPE_mxfp4 "FmhaFwdMxFp4")
|
||||
|
||||
set(sources)
|
||||
if(TARGET ${FMHA_FWD_INSTANCES})
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
|
||||
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"
|
||||
|
||||
@@ -42,6 +47,7 @@ struct TestConfigs
|
||||
static constexpr auto qscale_str = "n";
|
||||
static constexpr bool def_lse = true;
|
||||
static constexpr bool def_is_v_rowmajor = true;
|
||||
static constexpr auto init_method = "uf";
|
||||
static int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
};
|
||||
|
||||
@@ -57,11 +63,45 @@ struct TestConfigs<FmhaFwdFp8Bf16>
|
||||
static constexpr auto qscale_str = "pt";
|
||||
static constexpr bool def_lse = false;
|
||||
static constexpr bool def_is_v_rowmajor = true;
|
||||
static constexpr auto init_method = "3";
|
||||
// When there are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests:
|
||||
// return ck_tile::integer_least_multiple(seqlen, 128);
|
||||
static int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TestConfigs<FmhaFwdMxFp8>
|
||||
{
|
||||
static constexpr auto HDimValues = std::array{std::tuple{128, -1}, std::tuple{256, -1}};
|
||||
static constexpr auto SplitKVHDimValues = std::array<std::tuple<int, int>, 0>{};
|
||||
static constexpr auto AppendKVHDimValues = std::array<std::tuple<int, int>, 0>{};
|
||||
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
|
||||
static constexpr auto IsVRowmajorValues = std::array{false};
|
||||
static constexpr auto qscale_str = "mx";
|
||||
static constexpr bool def_lse = true;
|
||||
static constexpr bool def_is_v_rowmajor = false;
|
||||
static constexpr auto init_method = "3";
|
||||
static int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TestConfigs<FmhaFwdMxFp4>
|
||||
{
|
||||
static constexpr auto HDimValues = std::array{std::tuple{128, -1}, std::tuple{256, -1}};
|
||||
static constexpr auto SplitKVHDimValues = std::array<std::tuple<int, int>, 0>{};
|
||||
static constexpr auto AppendKVHDimValues = std::array<std::tuple<int, int>, 0>{};
|
||||
static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group};
|
||||
static constexpr auto IsVRowmajorValues = std::array{false};
|
||||
static constexpr auto qscale_str = "mx";
|
||||
static constexpr bool def_lse = true;
|
||||
static constexpr bool def_is_v_rowmajor = false;
|
||||
static constexpr auto init_method = "3";
|
||||
static int adjust_seqlen(int seqlen)
|
||||
{
|
||||
return seqlen < 0 ? seqlen : ck_tile::integer_least_multiple(seqlen, 2);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TestConfigs<FmhaFwdFp32>
|
||||
{
|
||||
@@ -81,6 +121,7 @@ struct TestConfigs<FmhaFwdFp32>
|
||||
static constexpr auto qscale_str = "n";
|
||||
static constexpr bool def_lse = true;
|
||||
static constexpr bool def_is_v_rowmajor = true;
|
||||
static constexpr auto init_method = "uf";
|
||||
static int adjust_seqlen(int seqlen) { return seqlen; }
|
||||
};
|
||||
|
||||
@@ -92,8 +133,8 @@ static auto IsVRowmajorValues = ValuesIn(TestConfigs<DataTypeConfig>::IsVRowm
|
||||
constexpr static auto qscale_str = TestConfigs<DataTypeConfig>::qscale_str;
|
||||
constexpr bool def_lse = TestConfigs<DataTypeConfig>::def_lse;
|
||||
constexpr bool def_is_v_rowmajor = TestConfigs<DataTypeConfig>::def_is_v_rowmajor;
|
||||
constexpr auto init_method = TestConfigs<DataTypeConfig>::init_method;
|
||||
int adjust_seqlen(int seqlen) { return TestConfigs<DataTypeConfig>::adjust_seqlen(seqlen); }
|
||||
constexpr auto init_method = "uf";
|
||||
|
||||
// Random seed used for initializing input tensors. 0 for non-deterministic seed
|
||||
CK_TILE_DECLARE_ENV_VAR(CK_TILE_TEST_SEED, uint64_t, 123456)
|
||||
@@ -901,12 +942,6 @@ using PaddingParam = std::tuple<mode_enum, // mode
|
||||
bool, // o_perm
|
||||
std::string>; // mask_str
|
||||
|
||||
// Ensure headers for containers / algorithms used in padding param builder.
|
||||
#include <vector>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
class PaddingCases : public TestWithParam<PaddingParam>
|
||||
{
|
||||
};
|
||||
@@ -918,6 +953,12 @@ static std::vector<PaddingParam> BuildPaddingParams()
|
||||
{
|
||||
std::vector<PaddingParam> params;
|
||||
|
||||
if constexpr(ck_tile::is_any_of<DataTypeConfig, FmhaFwdFp8Bf16, FmhaFwdMxFp8, FmhaFwdMxFp4>::
|
||||
value)
|
||||
{
|
||||
return params;
|
||||
}
|
||||
|
||||
// mask variants to cover
|
||||
const std::vector<std::string> mask_variants{"0", "t:50,64", "b:32,40"};
|
||||
const std::vector<std::string> mask_variants_reduced{"0", "t:50,64"}; // used for trimmed sets
|
||||
@@ -1106,15 +1147,10 @@ static std::vector<PaddingParam> BuildPaddingParams()
|
||||
|
||||
static const std::vector<PaddingParam> kPaddingParams = BuildPaddingParams();
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd_Padding, PaddingCases, ValuesIn(kPaddingParams));
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, PaddingCases, ValuesIn(kPaddingParams));
|
||||
|
||||
TEST_P(PaddingCases, DataTypeConfig)
|
||||
{
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>)
|
||||
{
|
||||
GTEST_SKIP() << "Skip for fp8";
|
||||
}
|
||||
|
||||
auto [mode,
|
||||
batch,
|
||||
nhead,
|
||||
|
||||
Reference in New Issue
Block a user