[rocm-libraries] ROCm/rocm-libraries#4267 (commit 3c5d95e)

[CK_TILE] Extend support of mix precision microscaling BQuant
 (#4267)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

Supported types combinations using BQuant=e8m0:
 - A=bf16
 - B=bf16,bf8,fp4

Summary:
- remove usage of `pk_fp4_raw_t`: consistent with other implementations
and avoid taking into account of the packed size explicitly. In general,
the raw type should not be used because CK Tile internally takes care of
the PackedSize, so using the raw type adds unnecessary complexity to the
implementation
- handle microscaling by checking for `e8m0` type for BQuant (previous
implementation was inconsistent)
 - add support for scaling instructions in `DequantPack8`
 - mx pipeline:
   - extend existing pipeline to support different B types
- add support to scale and cast before writing to LDS or after reading
from LDS (this can be defined in the `Problem` by the user)
 - block gemm:
   - mx pipeline is now using block gemm BQuant
- block gemm BQuant can now load from LDS and apply scale and then call
block gemm universal operator. This adds new functionalities and remove
code duplication
 - warp gemm:
- add case to support 128bit ds_read/write for both A and B when A=16bit
and B=8bit
- add examples and tests: note that some tests for bf16/fp4 already
existed but were removed during previous tests refactoring. I added them
again and other relevant tests for new types combinations

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [ ] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered
This commit is contained in:
Enrico Degregori
2026-02-24 17:57:02 +00:00
committed by assistant-librarian[bot]
parent 3af1a0aafc
commit 4c626aeaa6
44 changed files with 2061 additions and 683 deletions

View File

@@ -20,7 +20,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
gemm_aquant_quantgrouped_preshufflequant.cpp
gemm_bquant_quantgrouped_bf8i4.cpp
gemm_bquant_quantgrouped_fp8i4.cpp
gemm_bquant_quantgrouped_bf16mxfp4.cpp
gemm_bquant_quantgrouped_mx_bf16fp4.cpp
gemm_bquant_quantgrouped_mx_bf16bf8.cpp
gemm_bquant_quantgrouped_mx_bf16bf16.cpp
gemm_bquant_quantgrouped_bf8.cpp
gemm_bquant_quantgrouped_fp8.cpp
gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp

View File

@@ -53,7 +53,7 @@ args:
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
-prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, or bf16fp4 (default for both AQuant and Bquant: fp8)
-prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, mxbf16bf16, mxbf16bf8 or mxbf16fp4 (default for both AQuant and Bquant: fp8)
-warmup Number of iterations before benchmarking the kernel (default:50)
-repeat Number of iterations to benchmark the kernel (default:1000)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)

View File

@@ -0,0 +1,35 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, \
TypeConfig, \
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf16_t,
ck_tile::bf16_t,
ck_tile::bf16_t,
ck_tile::e8m0_t>{});
lut[hash_multiple_strings(
{"mxbf16bf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings(
{"mxbf16bf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
return 0;
}();

View File

@@ -0,0 +1,34 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "run_gemm_quant_example.inc"
using GemmConfig = GemmConfigMixedPrecision;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig, \
TypeConfig, \
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf16_t,
ck_tile::bf8_t,
ck_tile::bf16_t,
ck_tile::e8m0_t>{});
lut[hash_multiple_strings(
{"mxbf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings(
{"mxbf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
return 0;
}();

View File

@@ -6,33 +6,33 @@
template <typename T>
using GemmConfig = GemmConfigQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::pk_fp4_raw_t>, \
TypeConfig, \
QuantGroupSize, \
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::pk_fp4_t>, \
TypeConfig, \
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);
static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf16_t,
ck_tile::pk_fp4_raw_t,
ck_tile::pk_fp4_t,
ck_tile::bf16_t,
ck_tile::pk_fp4_raw_t>{});
ck_tile::e8m0_t>{});
lut[hash_multiple_strings(
{"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] =
{"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings(
{"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
{"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings(
{"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
{"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;

View File

@@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[])
.insert("prec",
"fp8",
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
"or bf8i4; for ABQuant: fp8, bf8, fp4")
" mxbf16bf16, mxbf16bf8, mxbf16fp4 or bf8i4; for ABQuant: fp8, bf8, fp4")
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")

View File

@@ -45,7 +45,7 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
const float max_accumulated_value)
{
using ComputeType = std::conditional_t<
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>,
std::is_same_v<BDataType, ck_tile::pk_fp4_t>,
ADataType,
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>>;
// Calculate thresholds
@@ -278,6 +278,24 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
static constexpr bool TransposeC = true;
};
// Used for A=16bit and B=8bit. The warp tile has KPack=16
// Matrix A: Vectorsize = 8, KPack=16 -> LDS read/write vectorsize = 8 (128bit)
// Matrix B: Vectorsize = 16, KPack=16 -> LDS read/write vectorsize = 16 (128bit)
struct GemmConfigMixedPrecision : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
};
template <typename PrecType>
struct GemmConfigEightWarps : public GemmConfigABQuantPrefill<PrecType>
{

View File

@@ -108,6 +108,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto b_cast_policy =
std::is_same_v<typename TypeConfig::ADataType, typename TypeConfig::BDataType>
? ck_tile::CastPolicy::BeforeLDSWrite
: ck_tile::CastPolicy::AfterLDSRead;
// row-col and tensor quants use the regular pipeline, A/B/AB quants use their own
using PipelineProblem = std::conditional_t<
@@ -150,7 +154,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ComputeDataType,
GemmConfig::Scheduler,
has_hot_loop_v,
tail_number_v>,
tail_number_v,
b_cast_policy>,
ck_tile::GemmABQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::QDataType, // For AQ
typename TypeConfig::BDataType,
@@ -173,10 +178,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
using BQuantPipeline = std::conditional_t<
GemmConfig::PreshuffleB,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
std::conditional_t<
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
std::conditional_t<std::is_same_v<typename TypeConfig::QDataType, ck_tile::e8m0_t>,
ck_tile::MicroscaleGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
using ABQuantPipeline = std::conditional_t<
eight_warps,
@@ -257,11 +261,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::HostTensor<typename TypeConfig::ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<typename TypeConfig::BDataType> b_n(ck_tile::host_tensor_descriptor(
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t> ? args.K / 2
: args.K,
args.N,
args.stride_B,
is_row_major(BLayout{})));
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
@@ -495,11 +495,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
int rotating_count = arg_parser.get_int("rotating_count");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(
(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>) ? (K / 2) : K,
N,
stride_B,
is_row_major(b_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
// Conditional stride calculation based on QuantMode
@@ -531,11 +527,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>) ? (K / 2) : K,
N,
stride_B,
is_row_major(b_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
@@ -575,18 +568,31 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
{
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
}
if constexpr(std::is_same_v<BQDataType, ck_tile::e8m0_t>)
{
auto gen_scales = [&](auto& scales, float range_min, float range_max) {
// e8m0_t is basically an exponent of float32
ck_tile::HostTensor<float> pow2(scales.get_lengths());
ck_tile::FillUniformDistributionIntegerValue<float>{
range_min, range_max, fill_seed(gen)}(pow2);
scales.ForEach([&](auto& self, const auto& i) {
self(i) = static_cast<BQDataType>(std::exp2(pow2(i)));
});
};
gen_scales(*bq_tensor_ptr, -2, 2);
}
else
{
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
@@ -850,18 +856,19 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
ck_tile::reference_mxfp4gemm_quant<ADataType,
BQDataType,
BDataType,
AccDataType,
CDataType,
BQuantGroupSize,
false>(
if constexpr(std::is_same_v<BQDataType, ck_tile::e8m0_t>)
ck_tile::reference_mx_gemm_bquant<ADataType,
BQDataType,
BDataType,
AccDataType,
CDataType,
BQuantGroupSize,
BLayout,
false>(
a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
else
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
BQDataType,
BDataType,
AccDataType,
CDataType,
@@ -961,7 +968,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant ||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>) &&
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_t>) &&
GemmConfig::PreshuffleB)
{
throw std::runtime_error(