mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on gfx950 (#4368) ## Motivation Microscaling types (mxfp8 and mxfp4) for fwd qr pipeline ## Technical Details The microscaling is used when quant scale mode is `BlockAttentionQuantScaleEnum::MX` and `Q/K/P/VDataType` are fp8/bf8/fp4. Supported features: * only "qr" pipeline is implemented * hdim 128 and 256 (smaller hdim are not possible due to restrictions of "qr" pipeline, but they can be computed using instances with padding) * both 32x32x64 and 16x16x128 scale MFMAs are supported * Q and K scales are applied in hdim, V scales - in seqlen dimension * column-major V only * batch and group mode * bias, Alibi (tested but no instances by default, just like fp8) * masking etc. Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008 ## Test Plan ``` ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8 ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4 ``` ## Test Result The tests must pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
79 lines
2.0 KiB
C++
79 lines
2.0 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include <ostream>
#include <stdexcept>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
|
|
|
|
#pragma clang diagnostic push
|
|
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
|
|
|
// Quantization-scale modes selectable from the host side.
// NOTE: the explicit numeric values must be kept in sync with
// BlockAttentionQuantScaleEnum; do not renumber or reorder them.
enum class quant_scale_enum
{
    no_scale      = 0, // no quantization scaling
    pertensor     = 1, // per-tensor scaling
    blockscale    = 2, // block scaling
    kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
    mx            = 4, // Microscaling (MX)
};
|
|
|
|
struct quant_scale_info
|
|
{
|
|
quant_scale_enum type;
|
|
|
|
void serialize(std::ostream& os) const
|
|
{
|
|
if(type == quant_scale_enum::no_scale)
|
|
os << "n";
|
|
else if(type == quant_scale_enum::pertensor)
|
|
os << "pt";
|
|
else if(type == quant_scale_enum::blockscale)
|
|
os << "bs";
|
|
else if(type == quant_scale_enum::kv_blockscale)
|
|
os << "kvbs";
|
|
else if(type == quant_scale_enum::mx)
|
|
os << "mx";
|
|
}
|
|
|
|
static quant_scale_info decode(std::string str)
|
|
{
|
|
quant_scale_info info{quant_scale_enum::no_scale};
|
|
if(str == "n" || str == "0")
|
|
{
|
|
info.type = quant_scale_enum::no_scale;
|
|
}
|
|
else if(str == "pt" || str == "1")
|
|
{
|
|
info.type = quant_scale_enum::pertensor;
|
|
}
|
|
else if(str == "bs" || str == "2")
|
|
{
|
|
info.type = quant_scale_enum::blockscale;
|
|
}
|
|
else if(str == "kvbs" || str == "3")
|
|
{
|
|
info.type = quant_scale_enum::kv_blockscale;
|
|
}
|
|
else if(str == "mx" || str == "4")
|
|
{
|
|
info.type = quant_scale_enum::mx;
|
|
}
|
|
else
|
|
{
|
|
throw std::invalid_argument("invalid quant scale value: " + str);
|
|
}
|
|
return info;
|
|
}
|
|
|
|
friend std::ostream& operator<<(std::ostream& os, const quant_scale_info& qsi)
|
|
{
|
|
qsi.serialize(os);
|
|
return os;
|
|
}
|
|
};
|
|
#pragma clang diagnostic pop
|