diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle_padding.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle_padding.cpp index d38947e7a3..676afbe201 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle_padding.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle_padding.cpp @@ -129,7 +129,16 @@ void preShuffleBuffer(const DataType* src, DataType* dst, int N, int K, int NXdl } } } -int GetPreShufflePadded(int K) { return (K + ShufflePadded - 1) / ShufflePadded * ShufflePadded; } + +int GetKPreShufflePadded(int K) +{ + return (K + KShufflePadded - 1) / KShufflePadded * KShufflePadded; +} +int GetNPreShufflePadded(int N) +{ + return (N + NShufflePadded - 1) / NShufflePadded * NShufflePadded; +} + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AElementOp = PassThrough; @@ -268,12 +277,14 @@ int main(int argc, char* argv[]) } }; - auto Knew = GetPreShufflePadded(K); + auto Knew = GetKPreShufflePadded(K); auto StrideBnew = Knew; + auto Nnew = GetNPreShufflePadded(N); + std::cout << "Knew: " << Knew << " Nnew: " << Nnew << std::endl; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); Tensor b0_preshuffled( - f_host_tensor_descriptor(Knew, N, StrideBnew, B0Layout{})); // use laout only for size + f_host_tensor_descriptor(Knew, Nnew, StrideBnew, B0Layout{})); // use layout only for size Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); @@ -289,8 +300,8 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + 
a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); d0_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); d1_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index 0875a2a640..c303c64888 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -146,8 +146,8 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator virtual int GetPreShuffleParameters() = 0; }; - -#define ShufflePadded 256 +#define KShufflePadded 256 +#define NShufflePadded 128 } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 20cf6d79da..eacf91c371 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -214,7 +214,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle __host__ __device__ static auto CalculateBKShufflePadded(index_t K) { - return (K + ShufflePadded - 1) / ShufflePadded * ShufflePadded; + return (K + KShufflePadded - 1) / KShufflePadded * KShufflePadded; + } + + __host__ __device__ static auto CalculateBNShufflePadded(index_t N) + { + return (N + NShufflePadded - 1) / NShufflePadded * NShufflePadded; } __host__ __device__ static auto CalculateKPadded(index_t K) @@ -604,7 +609,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle BK0{CalculateBK0Padded(K_, KBatch_)}, MBlock{CalculateMBlock(M_)}, NBlock{CalculateNBlock(N_)}, - BN0Shuffled{CalculateBN0Shuffled((N + 128 - 1) / 128 * 128)}, 
+ BN0Shuffled{CalculateBN0Shuffled(CalculateBNShufflePadded(N_))}, BK0Shuffled{CalculateBK0Shuffled(CalculateBKShufflePadded(K_))} { } diff --git a/profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp b/profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp index c7209e3c8d..ad09a97766 100644 --- a/profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp +++ b/profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp @@ -57,7 +57,14 @@ void preShuffleBuffer( } } -int GetPreShufflePadded(int K) { return (K + ShufflePadded - 1) / ShufflePadded * ShufflePadded; } +int GetKPreShufflePadded(int K) +{ + return (K + KShufflePadded - 1) / KShufflePadded * KShufflePadded; +} +int GetNPreShufflePadded(int N) +{ + return (N + NShufflePadded - 1) / NShufflePadded * NShufflePadded; +} template a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor b_preshuffled_mfma16( - f_host_tensor_descriptor(Knew, N, StrideBnew, BLayout{})); // use layout only for size + f_host_tensor_descriptor(Knew, Nnew, StrideBnew, BLayout{})); // use layout only for size Tensor b_preshuffled_mfma32( - f_host_tensor_descriptor(Knew, N, StrideBnew, BLayout{})); // use layout only for size + f_host_tensor_descriptor(Knew, Nnew, StrideBnew, BLayout{})); // use layout only for size Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));