From bb043a3202eed30ba6bf2eecdfe07019a94ba0ec Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 9 May 2025 07:54:28 +0000 Subject: [PATCH] remove some unnecessary hacky; enable 256x256x256 tilesize --- example/67_gemm_microscaling/gemm_mx_fp4.cpp | 31 ++++++----- .../grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 51 +++++++++---------- 2 files changed, 42 insertions(+), 40 deletions(-) diff --git a/example/67_gemm_microscaling/gemm_mx_fp4.cpp b/example/67_gemm_microscaling/gemm_mx_fp4.cpp index a3a6cd2e9c..b923290f54 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp4.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp4.cpp @@ -23,12 +23,15 @@ using BElementOp = PassThrough; // elementwise transformation for B matrix using CElementOp = PassThrough; // elementwise transformation for C matrix constexpr ck::index_t ScaleBlockSize = 32; // scaling block size -constexpr ck::index_t KPerBlock = 128; +constexpr ck::index_t KPerBlock = 256; constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; + +// AB DataType: f4x2_pk_t +// Mathematically, all numbers are represented as f4. 
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< ALayout, // ALayout BLayout, // BLayout @@ -45,29 +48,29 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle CElementOp, // CElementwiseOperation GemmSpec, // GemmSpec ScaleBlockSize, // ScaleBlockSize: Scaling block size - 256, // BlockSize: Thread block size - 128, // MPerBlock - 128, // NPerBlock + 256, // BlockSize: Thread block size + 256, // MPerBlock + 256, // NPerBlock KPerBlock, // KPerBlock - 16, // AK1 - 16, // BK1 + 32, // AK1 + 32, // BK1 16, // MPerXDL 16, // NPerXDL - 4, // MXdlPerWave - 4, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + 8, // MXdlPerWave + 8, // NXdlPerWave + S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim - 16, // ABlockTransferSrcScalarPerVector - 16, // ABlockTransferDstScalarPerVector_AK1 + 32, // ABlockTransferSrcScalarPerVector + 32, // ABlockTransferDstScalarPerVector_AK1 false, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim - 16, // BBlockTransferSrcScalarPerVector - 16, // BBlockTransferDstScalarPerVector_BK1 + 32, // BBlockTransferSrcScalarPerVector + 32, // BBlockTransferDstScalarPerVector_BK1 false, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index c544180545..55f2b72cd3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -154,10 +154,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 static constexpr auto I7 = Number<7>{}; // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = false; @@ -175,8 +175,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 NPerXdl, ComputeTypeB, is_single_rate_mfma, - is_scale_mfma>::selected_mfma.k_per_blk / - 2); + is_scale_mfma>::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; @@ -295,7 +294,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)), + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), make_pass_through_transform(MPad)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -308,7 +307,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // pad M, but not K const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)), + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), make_right_pad_transform(M, MPad - M)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -327,7 +326,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)), + 
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), make_pass_through_transform(M)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -339,7 +338,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // not pad M or K const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)), + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), make_pass_through_transform(M)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -384,7 +383,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)), + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), make_pass_through_transform(NPad)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -397,7 +396,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // pad N, but not K const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)), + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), make_right_pad_transform(N, NPad - N)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -416,7 +415,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)), + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), make_pass_through_transform(N)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -430,7 +429,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // not pad N or K const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( b_grid_desc_nraw_kraw, - 
make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)), + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), make_pass_through_transform(N)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -765,7 +764,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // in some cases. else if constexpr(is_same::value) { - constexpr index_t LdsSize = 32 * 4 / (KPerBlock / APackedSize) / sizeof(ADataType); + constexpr index_t LdsSize = 32 * 4 / (KPerBlock * sizeof(ADataType) / APackedSize); constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( make_tuple( @@ -901,7 +900,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 else if constexpr(is_same::value) { // NLdsLayer * K0 as logical Bank - constexpr index_t LdsSize = 32 * 4 / (KPerBlock / BPackedSize) / sizeof(BDataType); + constexpr index_t LdsSize = 32 * 4 / (KPerBlock * sizeof(BDataType)/ BPackedSize) ; constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( make_tuple( @@ -1416,8 +1415,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 Sequence<0, 1, 2>, ABlockTransferSrcVectorDim, 2, - ABlockTransferSrcScalarPerVector * 2, - ABlockTransferDstScalarPerVector_AK1 * 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, 1, 1, AThreadTransferSrcResetCoordinateAfterRun, @@ -1447,8 +1446,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 Sequence<0, 1, 2>, BBlockTransferSrcVectorDim, 2, - BBlockTransferSrcScalarPerVector * 2, - BBlockTransferDstScalarPerVector_BK1 * 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, 1, 1, BThreadTransferSrcResetCoordinateAfterRun, @@ -1467,14 +1466,14 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // Cast after lds auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()/APackedSize); CK_PRINT>(); auto b_block_buf = make_dynamic_buffer( reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * sizeof(ADataType) / APackedSize), - b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + b_block_desc_bk0_n_bk1.GetElementSpaceSize()/BPackedSize); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); @@ -1912,8 +1911,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 Sequence<0, 1, 2>, ABlockTransferSrcVectorDim, 2, - ABlockTransferSrcScalarPerVector * 2, - ABlockTransferDstScalarPerVector_AK1 * 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, 1, 1, AThreadTransferSrcResetCoordinateAfterRun, @@ -1943,8 +1942,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 Sequence<0, 1, 2>, BBlockTransferSrcVectorDim, 2, - BBlockTransferSrcScalarPerVector * 2, - BBlockTransferDstScalarPerVector_BK1 * 2, + BBlockTransferSrcScalarPerVector, + 
BBlockTransferDstScalarPerVector_BK1, 1, 1, BThreadTransferSrcResetCoordinateAfterRun,