From f96d5b74b2cc990584c3dfb5d85602e6c811f71a Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Wed, 6 Aug 2025 04:20:20 -0500 Subject: [PATCH] udpdate --- .../gemm_mx_fp4_basic.cpp | 6 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 35 ++-- .../ops/gemm_mx/kernel/gemm_mx_kernel.hpp | 157 +++++++++++++----- .../gemm_mx_pipeline_ag_bg_cr_policy.hpp | 15 +- .../pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp | 9 +- .../ops/gemm_mx/pipeline/gemm_mx_utils.hpp | 2 +- 6 files changed, 162 insertions(+), 62 deletions(-) diff --git a/example/ck_tile/39_gemm_microscaling/gemm_mx_fp4_basic.cpp b/example/ck_tile/39_gemm_microscaling/gemm_mx_fp4_basic.cpp index 9a1885e428..4b1472605f 100644 --- a/example/ck_tile/39_gemm_microscaling/gemm_mx_fp4_basic.cpp +++ b/example/ck_tile/39_gemm_microscaling/gemm_mx_fp4_basic.cpp @@ -25,7 +25,7 @@ template -float gemm_mx_calc(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s) +float gemm_mx_calc(const ck_tile::GemmMXHostArgs& args, const ck_tile::stream_config& s) { constexpr bool kPadM = false; constexpr bool kPadN = false; @@ -37,7 +37,7 @@ float gemm_mx_calc(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::strea constexpr ck_tile::index_t M_Tile = 64; constexpr ck_tile::index_t N_Tile = 64; - constexpr ck_tile::index_t K_Tile = 256; + constexpr ck_tile::index_t K_Tile = 128; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; @@ -45,7 +45,7 @@ float gemm_mx_calc(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::strea constexpr ck_tile::index_t M_Warp_Tile = 16; constexpr ck_tile::index_t N_Warp_Tile = 16; - constexpr ck_tile::index_t K_Warp_Tile = 128; + constexpr ck_tile::index_t K_Warp_Tile = 256; using CodegenGemmShape = ck_tile::TileGemmShape, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 15f3358aad..7b56410d38 100644 --- 
a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -169,10 +169,10 @@ struct UniversalGemmBasePolicy constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t VecLoadSize = GetVectorSizeB(); using TileEncodingPattern = TileDistributionEncodingPattern2D; + KPerBlock, + NPerBlock, + VecLoadSize, + BTileAccessPattern>; constexpr auto BK0 = number{}; constexpr auto BK1 = number{}; @@ -324,7 +324,14 @@ struct UniversalGemmBasePolicy else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 && elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0) { - return (PackedSize * 16 / sizeof(DataType)); + if constexpr(std::is_same_v(remove_cvref_t, ck_tile::pk_fp4_t >)) + { + return 16; // special processing for packed fp4 to avoid re-packing + } + else + { + return (PackedSize * 16 / sizeof(DataType)); + } } else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 && elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0) @@ -636,15 +643,15 @@ struct UniversalGemmPipelineAgBgCrPolicy : WGAttrNumAccessEnum::Invalid; using WarpGemm = WarpGemmMfmaDispatcher; + typename Problem::ComputeDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC, + false, + Problem::UseStructuredSparsity, + wg_attr_num_access>; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; using CDataType = remove_cvref_t; + using APackedSize = remove_cvref_t; + using BPackedSize = remove_cvref_t; + using BlockScaleSize = remove_cvref_t; + + static constexpr auto MXdlPack = 2; + static constexpr auto NXdlPack = 2; + static constexpr auto KXdlPack = 2; + static constexpr auto I0 = number<0>(); static constexpr auto I1 = number<1>(); static constexpr auto I2 = number<2>(); static constexpr auto I3 = number<3>(); + static constexpr auto I4 = number<4>(); [[nodiscard]] 
CK_TILE_HOST static const std::string GetName() { @@ -136,22 +145,23 @@ struct GemmMXKernel CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } - CK_TILE_HOST static constexpr AQuantGemmKernelArgs - MakeKernelArgs(const AQuantGemmHostArgs& hostArgs) + CK_TILE_HOST static constexpr GemmMXKernelArgs MakeKernelArgs(const GemmMXHostArgs& hostArgs) { - return AQuantGemmKernelArgs{hostArgs.a_ptr, - hostArgs.b_ptr, - hostArgs.aq_ptr, - hostArgs.c_ptr, - hostArgs.M, - hostArgs.N, - hostArgs.K, - hostArgs.QK, - hostArgs.stride_A, - hostArgs.stride_B, - hostArgs.stride_C, - hostArgs.stride_AQ, - hostArgs.k_batch}; + return GemmMXKernelArgs{hostArgs.a_ptr, + hostArs.a_scale_ptr_, + hostArgs.b_ptr, + hostArgs.b_scale_ptr_, + hostArgs.c_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K / APackedSize, + hostArgs.stride_A / APackedSize, + hostArgs.stride_scale_A, + hostArgs.stride_B / BPackedSize, + hostArgs.stride_scale_B, + hostArgs.stride_C, + hostArgs.stride_AQ, + hostArgs.k_batch}; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -194,10 +204,20 @@ struct GemmMXKernel { splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); } + + // Calculate A scale offset + a_scale_k_split_offset = __builtin_amdgcn_readfirstlane( + k_id * kargs.KRead / (BlockScaleSize / APackedSize) * MXdlPack * NPerXdl) + + // Calculate B scale offset + b_scale_k_split_offset = __builtin_amdgcn_readfirstlane( + k_id * kargs.KRead / (BlockScaleSize / BPackedSize) * NXdlPack * NPerXdl); } index_t a_k_split_offset; index_t b_k_split_offset; + index_t a_scale_k_split_offset; + index_t b_scale_k_split_offset; index_t splitted_k; }; @@ -351,8 +371,9 @@ struct GemmMXKernel template CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, + const AScaleDataType* a_scale_ptr, const BDataType* b_ptr, - const AQDataType* aq_ptr, + const BScaleDataType* b_scale_ptr, CDataType* c_ptr, const AQuantGemmKernelArgs& kargs, const 
SplitKBatchOffset& splitk_batch_offset) @@ -379,16 +400,27 @@ struct GemmMXKernel } }(); - const auto& aq_tensor_view = [&]() { - static_assert(std::is_same_v); + // A scale tensor view + const auto& a_scale_tensor_view = [&]() { + static_asssert(std::is_same_v); return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.M, kargs.QK), - make_tuple(kargs.stride_AQ, 1), + a_scale_ptr, + make_tuple(kargs.M, kargs.K / BlockScaleSize), + make_tuple(kargs.stride_scale_A, 1), number{}, number<1>{}); }(); + // const auto& aq_tensor_view = [&]() { + // static_assert(std::is_same_v); + // return make_naive_tensor_view( + // aq_ptr, + // make_tuple(kargs.M, kargs.QK), + // make_tuple(kargs.stride_AQ, 1), + // number{}, + // number<1>{}); + // }(); + const auto& b_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -452,6 +484,17 @@ struct GemmMXKernel } }(); + // B scale tensor view + const auto& b_scale_tensor_view = [&]() { + static_assert(std::is_same_v); + return make_naive_tensor_view( + b_scale_ptr, + make_tuple(kargs.N, kargs.K / BlockScaleSize), + make_tuple(kargs.stride_scale_B, 1), + number{}, + number<1>{}); + }(); + // TODO: enable vector write for C in ColMajor const auto& c_tensor_view = [&]() { if constexpr(std::is_same_v) @@ -474,7 +517,8 @@ struct GemmMXKernel } }(); - return make_tuple(a_tensor_view, aq_tensor_view, b_tensor_view, c_tensor_view); + return make_tuple( + a_tensor_view, a_scale_tensor_view, b_tensor_view, b_scale_tensor_view, c_tensor_view); } template @@ -498,13 +542,13 @@ struct GemmMXKernel } }(); - const auto& aq_pad_view = [&]() { - const auto& aq_tensor_view = views.at(I1); + const auto& a_scale_pad_view = [&]() { + const auto& a_scale_tensor_view = views.at(I1); static_assert(std::is_same_v); return pad_tensor_view( - aq_tensor_view, + a_scale_tensor_view, make_tuple(number{}, - number{}), + number{}), // TODO: Add support for padding. 
sequence{}); }(); @@ -527,9 +571,19 @@ struct GemmMXKernel } }(); + const auto& b_scale_pad_view = [&]() { + const auto& b_scale_tensor_view = views.at(I3); + static_assert(std::is_same_v); + return pad_tensor_view( + b_scale_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + // TODO vector write in for C in ColMajor const auto& c_pad_view = [&]() { - const auto& c_tensor_view = views.at(I3); + const auto& c_tensor_view = views.at(I4); if constexpr(std::is_same_v) { return pad_tensor_view(c_tensor_view, @@ -546,17 +600,18 @@ struct GemmMXKernel } }(); - return make_tuple(a_pad_view, aq_pad_view, b_pad_view, c_pad_view); + return make_tuple(a_pad_view, a_scale_pad_view, b_pad_view, b_scale_pad_view, c_pad_view); } template CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) { - const auto& a_pad_view = views.at(I0); - const auto& aq_pad_view = views.at(I1); - const auto& b_pad_view = views.at(I2); - const auto& c_pad_view = views.at(I3); + const auto& a_pad_view = views.at(I0); + const auto& a_scale_pad_view = views.at(I1); + const auto& b_pad_view = views.at(I2); + const auto& b_scale_pad_view = views.at(I3); + const auto& c_pad_view = views.at(I4); const auto& a_block_window = [&]() { if constexpr(std::is_same_v) @@ -575,12 +630,12 @@ struct GemmMXKernel } }(); - const auto& aq_block_window = [&]() { - static_assert(std::is_same_v); + const auto& a_scale_block_window = [&]() { + static_assert(std::is_same_v); return make_tile_window( aq_pad_view, make_tuple(number{}, - number{}), + number{}), {i_m, 0}); }(); @@ -601,6 +656,15 @@ struct GemmMXKernel } }(); + const auto& b_scale_block_window = [&]() { + static_assert(std::is_same_v); + return make_tile_window( + b_scale_pad_view, + make_tuple(number{}, + number{}), + {0, i_n}); + }(); + auto c_block_window = make_tile_window( c_pad_view, make_tuple(number{}, number{}), @@ -626,8 +690,9 @@ struct GemmMXKernel */ template CK_TILE_DEVICE 
static void RunGemm(const ADataType* a_ptr, + const AScaleDataType* a_scale_ptr, const BDataType* b_ptr, - const AQDataType* aq_ptr, + const BScaleDataType* b_scale_ptr, CDataType* c_ptr, void* smem_ptr_0, const AQuantGemmKernelArgs& kargs, @@ -670,16 +735,26 @@ struct GemmMXKernel const SplitKBatchOffset splitk_batch_offset(kargs); // options - const ADataType* a_ptr = static_cast(kargs.a_ptr); - const BDataType* b_ptr = static_cast(kargs.b_ptr); - const AQDataType* aq_ptr = static_cast(kargs.aq_ptr); - CDataType* c_ptr = static_cast(kargs.c_ptr); + const ADataType* a_ptr = static_cast(kargs.a_ptr); + const AScaleDataType* a_scale_ptr = static_cast(kargs.a_scale_ptr); + const BDataType* b_ptr = static_cast(kargs.b_ptr); + const BScaleDataType* b_scale_ptr = static_cast(kargs.b_scale_ptr); + CDataType* c_ptr = static_cast(kargs.c_ptr); // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; assert(kargs.k_batch == 1); - RunGemm(a_ptr, b_ptr, aq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + RunGemm(a_ptr, + a_scale_ptr, + b_ptr, + b_scale_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); } }; diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp index 3c165e2e91..0c868be652 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp @@ -28,7 +28,20 @@ struct GemmMXPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPol constexpr index_t KPerBlockScale = KPerBlock / Problem::kBlockScaleSize; static_assert(std::is_same_v); - return GetAScaleGlobalVectorLoadSize(); + return GetScaleGlobalVectorLoadSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBScale() + { + using BScaleLayout = remove_cvref_t; + using BScaleDataType = remove_cvref_t; + constexpr index_t NPerBlock 
= Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockScale = KPerBlock / Problem::kBlockScaleSize; + + static_assert(std::is_same_v); + return GetScaleGlobalVectorLoadSize(); } template diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp index e2fdd7d443..fa5c763d9c 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp @@ -125,9 +125,14 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3(); } static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } - static constexpr index_t GetVectorSizeScale() + + static constexpr index_t GetVectorSizeAScale() { - return Policy::template GetVectorSizeAQ(); + return Policy::template GetVectorSizeAScale(); + } + static constexpr index_t GetVectorSizeBScale() + { + return Policy::template GetVectorSizeBScale(); } static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_utils.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_utils.hpp index 9d7a8abaa2..b2c46ea46b 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_utils.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_utils.hpp @@ -8,7 +8,7 @@ namespace ck_tile { template -CK_TILE_HOST_DEVICE static constexpr auto GetAScaleGlobalVectorLoadSize() +CK_TILE_HOST_DEVICE static constexpr auto GetScaleGlobalVectorLoadSize() { using I1 = number<1>; constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{});