From f2a474e2e9d8adaa0a7cc73b88f6136aae66ed70 Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Fri, 9 May 2025 11:04:39 +0800 Subject: [PATCH] fix update --- .../gemm_mx_fp8_bpreshuffle.cpp | 73 ++++++++----------- ...emm_pipeline_xdlops_b_preshuflle_v1_mx.hpp | 12 +-- .../gpu/device/device_gemm_mx.hpp | 2 + ...e_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp | 54 +++++++------- ...e_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp | 8 +- 5 files changed, 68 insertions(+), 81 deletions(-) diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_bpreshuffle.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_bpreshuffle.cpp index 3ff88b6c7f..cc6300661b 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp8_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp8_bpreshuffle.cpp @@ -7,6 +7,7 @@ #include #include "ck/ck.hpp" +#include "ck/library/utility/literals.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" @@ -21,12 +22,18 @@ #include "ck/library/utility/fill.hpp" #include "ck/library/utility/host_tensor.hpp" +template +using S = ck::Sequence; + using F8 = ck::f8_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; using F32 = float; using XDataType = ck::e8m0_bexp_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + using A0DataType = F8; using A1DataType = XDataType; using B0DataType = F8; @@ -40,7 +47,7 @@ using A0Layout = Row; using B0Layout = Col; using CLayout = Row; -void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl) +void preShuffleBuffer(const F8* src, F8* dst, int N, int K, int NXdl) { int KPack = 16; int NLane = NXdl; @@ -71,6 +78,8 @@ void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl) } } +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AElementOp = PassThrough; // elementwise transformation for A matrix using BElementOp = PassThrough; // elementwise transformation for B matrix using CElementOp = PassThrough; // elementwise transformation for C matrix @@ -92,7 +101,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ADataType, BDataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, A0DataType, B0DataType>; // clang-format on int main(int argc, char* argv[]) @@ -135,7 +144,7 @@ int main(int argc, char* argv[]) StrideA = K; StrideB = K; - StrideE = N; + StrideC = N; } else { @@ -147,8 +156,8 @@ int main(int argc, char* argv[]) exit(0); } - ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; - ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { @@ -166,19 +175,19 @@ int main(int argc, char* argv[]) Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor a_m_k_scale(f_host_tensor_descriptor( - M, (K + Scale_Block_K - 1) / Scale_Block_K, Scale_Stride_AM, A0Layout{})); + M, (K + ScaleBlockSize - 1) / ScaleBlockSize, Scale_Stride_AM, A0Layout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); Tensor b_preshuffled(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); Tensor b_k_n_scale(f_host_tensor_descriptor( - (K + Scale_Block_K - 1) / Scale_Block_K, N, Scale_Stride_BN, B0Layout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, ELayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, ELayout{})); + (K + ScaleBlockSize - 1) / ScaleBlockSize, N, Scale_Stride_BN, B0Layout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl; - std::cout << "b_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; - std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + std::cout << "e_m_n: " << c_m_n_host_result.mDesc << std::endl; switch(init_method) { @@ -206,9 +215,9 @@ int main(int argc, char* argv[]) DeviceMem a_scale_device_buf(sizeof(A1DataType) * a_m_k_scale.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(sizeof(B0DataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_scale_device_buf(sizeof(B1DataType) * b_k_n_scale.mDesc.GetElementSpaceSize()); - DeviceMem c_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); - a_device_buf.ToDevice(a0_m_k.mData.data()); + a_device_buf.ToDevice(a_m_k.mData.data()); a_scale_device_buf.ToDevice(a_m_k_scale.mData.data()); b_scale_device_buf.ToDevice(b_k_n_scale.mData.data()); @@ -226,9 +235,7 @@ int main(int argc, char* argv[]) auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - constexpr ck::index_t NumDTensor = DsDataType::Size(); + auto cde_element_op = CElementOp{}; // do GEMM auto device_op = DeviceOpInstance{}; @@ -243,7 +250,7 @@ int main(int argc, char* argv[]) static_cast(a_scale_device_buf.GetDeviceBuffer()), static_cast(b_device_buf.GetDeviceBuffer()), static_cast(b_scale_device_buf.GetDeviceBuffer()), - static_cast(e_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), M, N, K, @@ -252,7 +259,7 @@ int main(int argc, char* argv[]) StrideB, Scale_Stride_BN, StrideC, - KBatch, + 1, // KBatch a_element_op, b_element_op, cde_element_op); @@ -292,30 +299,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_m_n({M, N}); - Tensor a_m_k({M, K}); - Tensor b_k_n({K, N}); - - for(int m = 0; m < M; m++) - { - for(int k = 0; k < K; k++) - { - a_m_k(m, k) = ck::type_convert(a0_m_k(m, k)) * - a1_m_k(m / Scale_Block_M, k / Scale_Block_K); - } - } - - for(int n = 0; n < N; n++) - { - for(int k = 0; k < K; k++) - { - b_k_n(k, n) = ck::type_convert(b0_k_n(k, n)) * - b1_k_n(k / Scale_Block_K, n / Scale_Block_N); - } - } - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm{}) = + a_scale_thread_bufs(I0)(Number{}) = a_scale_thread_buf_copy[Number<0>{}]; a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, @@ -318,7 +318,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx{}) = + b_scale_thread_bufs(I0)(Number{}) = b_scale_thread_buf_copy[Number<0>{}]; b_scale_thread_copy.MoveSrcSliceWindow( b_scale_grid_desc, @@ -358,7 +358,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx{}) = + a_scale_thread_bufs(I1)(Number{}) = a_scale_thread_buf_copy[Number<0>{}]; a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, @@ -388,7 +388,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx{}) = + b_scale_thread_bufs(I1)(Number{}) = b_scale_thread_buf_copy[Number<0>{}]; b_scale_thread_copy.MoveSrcSliceWindow( b_scale_grid_desc, @@ -542,7 +542,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx{}) = + a_scale_thread_bufs(mfma_reg_buf)(Number{}) = a_scale_thread_buf_copy[Number<0>{}]; a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, @@ -573,7 +573,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx{}) = + b_scale_thread_bufs(mfma_reg_buf)(Number{}) = b_scale_thread_buf_copy[Number<0>{}]; b_scale_thread_copy.MoveSrcSliceWindow( b_scale_grid_desc, diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp index 2e70838ca1..0562e452ac 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp @@ -79,6 +79,8 @@ struct DeviceGemmMX_BPreshuffle : public BaseOperator CElementwiseOperation c_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp index 3ed1ad2195..b7dd4e3f5f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp @@ -216,6 +216,8 @@ struct DeviceGemmMX_Xdl_CShuffleV3_BPreShuffle using Argument = typename GridwiseGemm::Argument; + int GetPreShuffleParameters() override { return NPerXDL; } + // Invoker struct Invoker : public BaseInvoker { @@ -313,32 +315,32 @@ struct DeviceGemmMX_Xdl_CShuffleV3_BPreShuffle if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + kernel_gemm_xdl_cshuffle_v3_b_preshuffle; Run(kernel); } else { const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + kernel_gemm_xdl_cshuffle_v3_b_preshuffle; Run(kernel); } } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_2lds; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; Run(kernel); } } @@ -350,21 +352,21 @@ struct DeviceGemmMX_Xdl_CShuffleV3_BPreShuffle if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + kernel_gemm_xdl_cshuffle_v3_b_preshuffle; Run(kernel); } else { const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + kernel_gemm_xdl_cshuffle_v3_b_preshuffle; Run(kernel); } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp index 55e04996dc..bb40386fca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp @@ -141,7 +141,7 @@ template -struct GridwiseGemmMX_xdl_cshuffle_v3 +struct GridwiseGemmMX_xdl_cshuffle_v3_b_preshuffle { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -172,7 +172,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 ComputeTypeB, is_single_rate_mfma, is_scale_mfma>; - static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma_selector::k_per_blk); + static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma_selector::selected_mfma.k_per_blk); static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; static constexpr index_t KLane = @@ -1227,7 +1227,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; + // const BElementwiseOperation b_element_op{}; const CElementwiseOperation c_element_op{}; // divide block work by [M, N] @@ -1417,7 +1417,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 "wrong!"); constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = @@ -1906,7 +1905,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 "wrong!"); constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =