[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)

* chore: split block scale example instances in more separate files to speed up compile times

* wip: fp4 scaffolding for abquant

* feat: add fp4 decoding-while-loading to abquant pipeline

* feat: add support for fp4 CPU verification in abquant

* chore: add time tracking to reference calculation

* feat: add a4w4 test for blockscale gemm

* feat: optimize reference calculation by preconverting values to AccType

* feat: add fp4 to fp8 look-up table

* fix: reference to wrong ComputeDataType field in QuantProblem

* feat: type utilities for determining MFMA compute types

* feat: packed fp4 for abquant weight preshuffle

* feat: add separate tests for a4w4 base case, padding and preshuffleB

* fix: fp4 conversion on gfx950 attempting to use non-supported method

* fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size

* chore: add fp4 preshuffleb mode to block scale example

* chore: sanity check for packed types being 1 byte

* chore: clarify tensor dimension indices with constants

* chore: replace traits check with specialized check for packed types

* style: some minor refactoring and cleanup

* fix: correct conversion table for FNUZ fp8

* chore: add fp4 instances to main abquant instances again

* chore: use same initialization branch for int4 and fp4

* chore: add missing initialization for fp4 in block scale gemm example

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Erwin Terpstra
2026-01-30 12:40:50 +01:00
committed by GitHub
parent 565fea2645
commit 6a6177a246
28 changed files with 642 additions and 175 deletions

View File

@@ -76,6 +76,22 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
)
target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_base
test_gemm_quant_abquant_a4w4_base.cpp
)
target_compile_options(test_tile_gemm_quant_abquant_a4w4_base PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_padding
test_gemm_quant_abquant_a4w4_padding.cpp
)
target_compile_options(test_tile_gemm_quant_abquant_a4w4_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_preshuffle
test_gemm_quant_abquant_a4w4_preshuffle.cpp
)
target_compile_options(test_tile_gemm_quant_abquant_a4w4_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_abquant_preshuffleQuant
test_gemm_quant_abquant_preshuffleQuant.cpp
)

View File

@@ -0,0 +1,44 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using Half = ck_tile::half_t;
using PkFP4 = ck_tile::pk_fp4_t;
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
// 1d block sizes for AQuant
using GroupSize1D = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// 2d block sizes for BQuant
using GroupSize2D = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantTypes = ::testing::Types<
// PreshuffleQuant = false && TransposeC = false
// RCR layout with RowMajor AQ, ColumnMajor BQ
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkFP4, PkFP4, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize1D, GroupSize2D, ColumnMajor>
>;
// clang-format on
// Test suite for ABQuant
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using Half = ck_tile::half_t;
using PkFP4 = ck_tile::pk_fp4_t;
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
// 1d block sizes for AQuant
using GroupSize1D = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// 2d block sizes for BQuant
using GroupSize2D = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantTypes = ::testing::Types<
// PreshuffleQuant = false && TransposeC = false
// RCR layout with RowMajor AQ, ColumnMajor BQ
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkFP4, PkFP4, float, Half, ABQuantGrouped, GemmConfigPadding, GroupSize1D, GroupSize2D, ColumnMajor>
>;
// clang-format on
// Test suite for ABQuant
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadK)
{
this->run_test_with_validation(1024, 1024, 832);
}
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadN)
{
this->run_test_with_validation(1024, 832, 1024);
}
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadM)
{
this->run_test_with_validation(832, 1024, 1024);
}
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadMNK)
{
this->run_test_with_validation(832, 832, 832);
}
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest_PadNK)
{
this->run_test_with_validation(1024, 832, 832);
}

View File

@@ -0,0 +1,44 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using Half = ck_tile::half_t;
using PkFP4 = ck_tile::pk_fp4_t;
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
// 1d block sizes for AQuant
using GroupSize1D = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// 2d block sizes for BQuant
using GroupSize2D = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantTypes = ::testing::Types<
// RCR layout with RowMajor AQ, ColumnMajor BQ
// PreshuffleB = true && TransposeC = false
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, PkFP4, PkFP4, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize1D, GroupSize2D, ColumnMajor>
>;
// clang-format on
// Test suite for ABQuant
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -209,7 +209,7 @@ template <>
struct QuantTypeTraits<ck_tile::QuantType::ABQuantGrouped>
{
template <typename ADataType, typename BDataType>
using ComputeDataType = BDataType; // For AQuant, compute type is BDataType
using ComputeDataType = void; // Use automatically determined compute type
static constexpr const char* name = "abquant";
};

View File

@@ -1174,8 +1174,8 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
typename PipelineProblem::ComputeDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,