From 23b2fb0f3e4f40429ef46a367d1e34657f6d5dd2 Mon Sep 17 00:00:00 2001 From: qin letao Date: Sat, 22 Feb 2025 09:25:19 +0000 Subject: [PATCH] add padding info --- ...y_multiply_xdl_fp8_bpreshuffle_padding.cpp | 9 ++-- ...ultiple_d_xdl_cshuffle_v3_b_preshuffle.hpp | 7 ++- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 46 +++++++++++-------- 3 files changed, 37 insertions(+), 25 deletions(-) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle_padding.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle_padding.cpp index 9ca9fadc48..caa0a9d87c 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle_padding.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle_padding.cpp @@ -242,12 +242,12 @@ int main(int argc, char* argv[]) return HostTensorDescriptor({row, col}, {1_uz, stride}); } }; - auto Knew = (K + 64 - 1) / 64; - auto StrideBnew = Knew; + auto Knew = (K + 64 - 1) / 64 * 64; + auto StrideBnew = Knew; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); Tensor b0_preshuffled( - f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size + f_host_tensor_descriptor(Knew, N, StrideBnew, B0Layout{})); // use laout only for size Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); @@ -281,7 +281,7 @@ int main(int argc, char* argv[]) d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_preshuffled.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); @@ -304,6 +304,7 @@ int main(int argc, char* argv[]) int NPerXdl = device_op.GetPreShuffleParameters(); + b0_preshuffled.SetZero(); preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl); b0_device_buf.ToDevice(b0_preshuffled.mData.data()); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp index 1177854983..5974ec2a00 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp @@ -355,6 +355,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle throw std::runtime_error("todo: only v1 v2 and v3 support now"); } } + else + { + throw std::runtime_error("not call kernel function"); + } #if 0 else { @@ -526,7 +530,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle return false; } - if(arg.N % NPerBlock != 0 || (arg.K % KPerBlock != 0 && GemmSpec != GemmSpecialization::KPadding)) + if(arg.N % NPerBlock != 0 || + (arg.K % KPerBlock != 0 && GemmSpec != GemmSpecialization::KPadding)) { return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index a57edbf39b..8242339d38 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -349,27 +349,33 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle } } - __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) + __host__ __device__ static auto + MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0, index_t KPad) { - // using GemmSpecialization = tensor_operation::device::GemmSpecialization; - //if K padding - // if constexpr(GemmSpec == GemmSpecialization::KPadding || - // GemmSpec == GemmSpecialization::NKPadding) + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + // if K padding + if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) { // origin: [N0,K0,KLane,NLane,KPack] - // const auto b_grid_desc_raw = make_naive_tensor_descriptor( - // make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), - // make_tuple( - // NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); - // const auto b_grid_desc_n_k = - // transform_tensor_descriptor(b_grid_desc_nraw_kraw, - // make_tuple(make_pass_through_transform(N), - // make_right_pad_transform(K, KPad - K)), - // make_tuple(Sequence<0>{}, Sequence<1>{}), - // make_tuple(Sequence<0>{}, Sequence<1>{})); - // ignore = b_grid_desc_n_k; + constexpr index_t NkSwizzleNumber = Number{}; + const auto b_grid_desc_raw = make_naive_tensor_descriptor( + make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), + make_tuple( + NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); + + auto K0new = CalculateBK0Shuffled(KPad); + + return transform_tensor_descriptor( + b_grid_desc_raw, + make_tuple(make_pass_through_transform(N0 / NWave), + make_pass_through_transform(NWave), + make_right_pad_transform(K0, K0new - K0), + make_pass_through_transform(NkSwizzleNumber)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); } - // else + else { constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( @@ -591,7 +597,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle MBlock{CalculateMBlock(M_)}, NBlock{CalculateNBlock(N_)}, BN0Shuffled{CalculateBN0Shuffled(N_)}, - BK0Shuffled{CalculateBK0Shuffled(K_)} + BK0Shuffled{CalculateBK0Shuffled((K_ + 64 - 1) / 64 * 64)} { } @@ -1592,8 +1598,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); - const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto b_grid_desc_bpreshuffled = MakeBGridDescriptor_Preshuffled( + problem.BN0Shuffled, problem.BK0Shuffled, problem.KPadded); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);