BF16*FP4 GEMM (#2801)

* support bf16*mxfp4 gemm

* rebase bf16*fp4 example to develop branch

* Clean up commented debug code in GEMM kernel

* rename example folder

* support bf16*mxfp4 gemm

* rebase bf16*fp4 example to develop branch

* Clean up commented debug code in GEMM kernel

* rename example folder

* rebase to new develop

* fix clang format

* update code according to reviewer's comment

* Update README.md

* update code according to reviewer's comment

* update code according to reviewer's comment

* Update CMakeLists.txt

* Update README.md

* Update CMakeLists.txt

* Delete files

* Delete files

* Add unit tests

* Update test_gemm_quant_base.hpp

* merge bf16*fp4 example to develop branch

* fix clang format

* fix clang format

* Update CMakeLists.txt

* fix ci test

* fix clang format

* resolve conflicts

---------

Co-authored-by: eliotwang <charyang@smci355-ccs-aus-m10-29.cs-aus.dcgpu>
Co-authored-by: ShaoChunLee <Shao-Chun.Lee@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
eliotwang
2025-12-11 23:20:29 +08:00
committed by GitHub
parent ce99cab605
commit 715671e419
23 changed files with 1260 additions and 137 deletions

View File

@@ -96,8 +96,10 @@ struct BlockUniversalGemmAsBsCr
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_raw_t>,
ADataType,
BDataType>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;

View File

@@ -17,10 +17,12 @@ struct GemmPipelineAgBgCrImplBase
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
using BInDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
using BDataType =
std::conditional_t<std::is_same_v<BInDataType, pk_fp4_raw_t>, ADataType, BInDataType>;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
@@ -270,12 +272,17 @@ struct GemmPipelineAgBgCrImplBase
}();
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
using BLdsDataType =
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
auto b_lds_load_tile_distr = []() {
if constexpr(is_b_load_tr)
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename BLdsLoadTileDistr::DstrEncode,
typename Problem::BDataType>::TransposedDstrEncode{});
typename InputTileDistributionTraits<typename BLdsLoadTileDistr::DstrEncode,
BLdsDataType>::TransposedDstrEncode{});
else
return BLdsLoadTileDistr{};
}();

View File

@@ -303,8 +303,11 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType =
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
@@ -585,9 +588,12 @@ struct UniversalGemmBasePolicy
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
using BInDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
using BDataType = std::conditional_t<std::is_same_v<BInDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
if constexpr(Problem::FixedVectorSize)
{
@@ -729,13 +735,17 @@ struct UniversalGemmBasePolicy
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t KPerBlock = std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>
? Problem::BlockGemmShape::kK / 2
: Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>
? 4
: (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>());
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using BLayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
using BLayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
// Tile: KPerBlock X NPerBlock
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
@@ -841,10 +851,12 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t smem_size_b =
integer_least_multiple(sizeof(typename Problem::BDataType) *
Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK,
16);
using BDataType =
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr index_t smem_size_b = integer_least_multiple(
sizeof(BDataType) * Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, 16);
return smem_size_b;
}
@@ -882,8 +894,10 @@ struct UniversalGemmPipelineAgBgCrPolicy
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_raw_t>,
ADataType,
BDataType>;
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,