Bf16*fp4 gemm (#2801)

* support bf16*mxfp4 gemm

* rebase bf16*fp4 example to develop branch

* Clean up commented debug code in GEMM kernel

* rename example folder

* support bf16*mxfp4 gemm

* rebase bf16*fp4 example to develop branch

* Clean up commented debug code in GEMM kernel

* rename example folder

* rebase to new develop

* fix clang format

* update code according to reviewer's comment

* Update README.md

* update code according to reviewer's comment

* update code according to reviewer's comment

* Update CMakeLists.txt

* Update README.md

* Update CMakeLists.txt

* Delete files

* Delete files

* Add unit tests

* Update test_gemm_quant_base.hpp

* merge bf16*fp4 example to develop branch

* fix clang format

* fix clang format

* Update CMakeLists.txt

* fix ci test

* fix clang format

* resolve conflicts

---------

Co-authored-by: eliotwang <charyang@smci355-ccs-aus-m10-29.cs-aus.dcgpu>
Co-authored-by: ShaoChunLee <Shao-Chun.Lee@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
eliotwang
2025-12-11 23:20:29 +08:00
committed by GitHub
parent ce99cab605
commit 715671e419
23 changed files with 1260 additions and 137 deletions

0
test/ck_tile/gemm_block_scale/CMakeLists.txt Normal file → Executable file
View File

View File

@@ -131,8 +131,10 @@ class TestCkTileGemmQuantBase : public ::testing::Test
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType_) < sizeof(BDataType_), ADataType_, BDataType_>;
using ComputeType = std::conditional_t<
std::is_same_v<BDataType_, ck_tile::pk_fp4_raw_t>,
ADataType_,
std::conditional_t<sizeof(ADataType_) < sizeof(BDataType_), ADataType_, BDataType_>>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType_, AccDataType_>(
ck_tile::integer_divide_ceil(K, kbatch));

View File

@@ -16,9 +16,12 @@ using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BF16 = ck_tile::bf16_t;
using UInt8 = ck_tile::pk_fp4_raw_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
using GroupSize32 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
// 2d block sizes for BQuant
using GroupSize2D8N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
@@ -42,6 +45,9 @@ using BQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, UInt8, UInt8, BF16, BQuantGrouped, GemmConfigMxFp4, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, UInt8, UInt8, BF16, BQuantGrouped, GemmConfigMxFp4, GroupSize32>,
// 2d cases with grouping also on the n axis
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D8N>,

View File

@@ -60,6 +60,13 @@ struct GemmConfigPrefill : public GemmConfigBase
static constexpr ck_tile::index_t K_Tile = 128;
};
struct GemmConfigMxFp4 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
};
struct GemmConfigPreshuffleQuant : public GemmConfigBase
{
static constexpr bool PreshuffleQuant = true;
@@ -403,7 +410,8 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_B =
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t> ? (K / 2) : K;
const ck_tile::index_t stride_C = N;
// BQuant uses block/grouped quantization for B matrix
@@ -414,15 +422,27 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
// Generate test data
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t> ? K / 2 : K,
N,
stride_B,
this->is_row_major(BLayout{})));
ck_tile::HostTensor<QDataType> bq_bqk_bqn(
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{})));
// Initialize data with random values
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
ck_tile::FillUniformDistribution<QDataType>{-1.0f, 1.0f}(bq_bqk_bqn);
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
{
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f}(b_k_n);
ck_tile::FillUniformDistribution<QDataType>{125.f, 130.f}(bq_bqk_bqn);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
ck_tile::FillUniformDistribution<QDataType>{-1.0f, 1.0f}(bq_bqk_bqn);
}
// Allocate device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
@@ -501,13 +521,22 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
c_m_n_host_ref.SetZero();
// Run reference BQuant implementation
ck_tile::reference_gemm_quant<ADataType,
QDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
ck_tile::reference_mxfp4gemm_quant<ADataType,
QDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
else
ck_tile::reference_gemm_quant<ADataType,
QDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
// Get device result
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
@@ -580,33 +609,37 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
has_hot_loop_v,
tail_number_v>;
using GemmPipeline =
std::conditional_t<PreshuffleB == false,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>>;
using GemmPipeline = std::conditional_t<
PreshuffleB == false,
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>,
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
false, // transpose_c
ck_tile::memory_operation_enum::set,
1,
false,
1,
TiledMMAPermuteN>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
ADataType,
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>,
ADataType,
BDataType>,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
false, // transpose_c
ck_tile::memory_operation_enum::set,
1,
false,
1,
TiledMMAPermuteN>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,