mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Bf16*fp4 gemm (#2801)
* support bf16*mxfp4 gemm * rebase bf16*fp4 example to develop branch * Clean up commented debug code in GEMM kernel * rename example folder * support bf16*mxfp4 gemm * rebase bf16*fp4 example to develop branch * Clean up commented debug code in GEMM kernel * rename example folder * rebase to new develop * fix clang format * update code according to reviewer's comment * Update README.md * update code according to reviewer's comment * update code according to reviewer's comment * Update CMakeLists.txt * Update README.md * Update CMakeLists.txt * Delete files * Delete files * Add unit tests * Update test_gemm_quant_base.hpp * merge bf16*fp4 example to develop branch * fix clang format * fix clang format * Update CMakeLists.txt * fix ci test * fix clang format * resolve conflicts --------- Co-authored-by: eliotwang <charyang@smci355-ccs-aus-m10-29.cs-aus.dcgpu> Co-authored-by: ShaoChunLee <Shao-Chun.Lee@amd.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -96,8 +96,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
using ATypeToUse =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
|
||||
std::is_same_v<BDataType, pk_fp4_raw_t>,
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
|
||||
@@ -17,10 +17,12 @@ struct GemmPipelineAgBgCrImplBase
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
|
||||
using BInDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
using BDataType =
|
||||
std::conditional_t<std::is_same_v<BInDataType, pk_fp4_raw_t>, ADataType, BInDataType>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
@@ -270,12 +272,17 @@ struct GemmPipelineAgBgCrImplBase
|
||||
}();
|
||||
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
|
||||
|
||||
using BLdsDataType =
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
auto b_lds_load_tile_distr = []() {
|
||||
if constexpr(is_b_load_tr)
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename BLdsLoadTileDistr::DstrEncode,
|
||||
typename Problem::BDataType>::TransposedDstrEncode{});
|
||||
typename InputTileDistributionTraits<typename BLdsLoadTileDistr::DstrEncode,
|
||||
BLdsDataType>::TransposedDstrEncode{});
|
||||
|
||||
else
|
||||
return BLdsLoadTileDistr{};
|
||||
}();
|
||||
|
||||
@@ -303,8 +303,11 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType =
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
@@ -585,9 +588,12 @@ struct UniversalGemmBasePolicy
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
using BInDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
using BDataType = std::conditional_t<std::is_same_v<BInDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
if constexpr(Problem::FixedVectorSize)
|
||||
{
|
||||
@@ -729,13 +735,17 @@ struct UniversalGemmBasePolicy
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
constexpr index_t KPerBlock = std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>
|
||||
? Problem::BlockGemmShape::kK / 2
|
||||
: Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize =
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>
|
||||
? 4
|
||||
: (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>());
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using BLayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
using BLayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
// Tile: KPerBlock X NPerBlock
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -841,10 +851,12 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr index_t smem_size_b =
|
||||
integer_least_multiple(sizeof(typename Problem::BDataType) *
|
||||
Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK,
|
||||
16);
|
||||
using BDataType =
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
sizeof(BDataType) * Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, 16);
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
@@ -882,8 +894,10 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using ATypeToUse =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
|
||||
std::is_same_v<BDataType, pk_fp4_raw_t>,
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
|
||||
BTypeToUse,
|
||||
|
||||
Reference in New Issue
Block a user