[rocm-libraries] ROCm/rocm-libraries#4267 (commit 3c5d95e)

[CK_TILE] Extend support of mix precision microscaling BQuant
 (#4267)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

Supported types combinations using BQuant=e8m0:
 - A=bf16
 - B=bf16,bf8,fp4

Summary:
- remove usage of `pk_fp4_raw_t`: this is consistent with other
implementations and avoids explicitly accounting for the packed size. In
general, the raw type should not be used because CK Tile internally takes
care of the PackedSize, so using the raw type adds unnecessary complexity
to the implementation
- handle microscaling by checking for `e8m0` type for BQuant (previous
implementation was inconsistent)
 - add support for scaling instructions in `DequantPack8`
 - mx pipeline:
   - extend existing pipeline to support different B types
- add support to scale and cast before writing to LDS or after reading
from LDS (this can be defined in the `Problem` by the user)
 - block gemm:
   - mx pipeline is now using block gemm BQuant
- block gemm BQuant can now load from LDS, apply the scale, and then call
the block gemm universal operator. This adds new functionality and removes
code duplication
 - warp gemm:
- add case to support 128bit ds_read/write for both A and B when A=16bit
and B=8bit
- add examples and tests: note that some tests for bf16/fp4 already
existed but were removed during a previous test refactoring. I added them
again, along with other relevant tests for the new type combinations

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which helps the maintainers
understand the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [ ] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered
This commit is contained in:
Enrico Degregori
2026-02-24 17:57:02 +00:00
committed by assistant-librarian[bot]
parent 3af1a0aafc
commit 4c626aeaa6
44 changed files with 2061 additions and 683 deletions

View File

@@ -388,49 +388,56 @@ template <typename ADataType,
typename AccDataType,
typename CDataType,
typename QuantGroupSize,
typename BLayout,
bool aquant,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor<ADataType>& a_m_k,
const HostTensor<QDataType>& q,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
CK_TILE_HOST void reference_mx_gemm_bquant(const HostTensor<ADataType>& a_m_k,
const HostTensor<QDataType>& q,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0;
AccDataType pasual = 0;
for(std::size_t k = 0; k < (K / 2); k++)
{
using ComputeType = float;
auto b_scale = type_convert<int32_t>(q((2 * k) / QuantGroupSize::kK, n)) - 127;
ComputeType v_a_0, v_a_1;
ComputeType v_b_0, v_b_1;
AccDataType v_acc = 0;
using ComputeType = float;
ComputeType v_a;
ComputeType v_b;
v_a_0 = ck_tile::type_convert<ComputeType>((a_element_op(a_m_k(m, 2 * k))));
v_a_1 = ck_tile::type_convert<ComputeType>((a_element_op(a_m_k(m, 2 * k + 1))));
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
auto load_b = [&](std::size_t k) -> AccDataType {
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
auto b_pack = type_convert<pk_fp4_t>(b_element_op(b_k_n(k, n)));
auto b_scale_fp4 = type_convert<float>(std::pow(2.0f, b_scale));
auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
v_b_0 = type_convert<ComputeType>(b_f4_lo) * b_scale_fp4;
v_b_1 = type_convert<ComputeType>(b_f4_hi) * b_scale_fp4;
const auto b_pack = type_convert<pk_fp4_t>(b_element_op(b_k_n(k, n)));
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return (n & 1) ? type_convert<ComputeType>(b_pack.unpack(number<1>{}))
: type_convert<ComputeType>(b_pack.unpack(number<0>{}));
}
else
{
return (k & 1) ? type_convert<ComputeType>(b_pack.unpack(number<1>{}))
: type_convert<ComputeType>(b_pack.unpack(number<0>{}));
}
}
else
{
return ck_tile::type_convert<ComputeType>(b_element_op(b_k_n(k, n)));
}
};
pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1;
v_acc += pasual;
for(std::size_t k = 0; k < K; k++)
{
const auto b_scale = type_convert<float>(q(k / QuantGroupSize::kK, n));
v_a = ck_tile::type_convert<ComputeType>(a_element_op(a_m_k(m, k)));
v_b = load_b(k) * b_scale;
v_acc += v_a * v_b;
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
};