diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle_padding.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle_padding.cpp index 7dfc00ea67..b7a9f78b5d 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle_padding.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle_padding.cpp @@ -98,7 +98,7 @@ struct MultiplyMultiply } }; -template +template void preShuffleBuffer(const DataType* src, DataType* dst, int N, int K, int NXdl, int Knew) { int KPack = 16; @@ -148,7 +148,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, B0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, B0DataType>; // clang-format on int main(int argc, char* argv[]) @@ -243,7 +243,8 @@ int main(int argc, char* argv[]) return HostTensorDescriptor({row, col}, {1_uz, stride}); } }; - auto Knew = (K + 64 - 1) / 64 * 64; + auto device_op = DeviceOpInstance{}; + auto Knew = device_op.GetPreShufflePadded(K); auto StrideBnew = Knew; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); @@ -301,7 +302,6 @@ int main(int argc, char* argv[]) constexpr auto I0 = ck::Number<0>{}; // do GEMM - auto device_op = DeviceOpInstance{}; int NPerXdl = device_op.GetPreShuffleParameters(); @@ -341,7 +341,7 @@ int main(int argc, char* argv[]) } #endif - b0_preshuffled.SetZero(); //must set to zero + b0_preshuffled.SetZero(); // must set to zero preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl, Knew); b0_device_buf.ToDevice(b0_preshuffled.mData.data()); diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index 403a1cb085..9006e70040 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -138,7 +138,8 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator virtual std::unique_ptr MakeInvokerPointer() = 0; - virtual int GetPreShuffleParameters() = 0; + virtual int GetPreShuffleParameters() = 0; + virtual int GetPreShufflePadded(int K) = 0; }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp index 5974ec2a00..bce75f8f34 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp @@ -140,6 +140,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle using Argument = typename GridwiseGemm::Argument; int GetPreShuffleParameters() override { return NPerXDL; } + int GetPreShufflePadded(int K) override { return (K + 64 - 1) / 64 * 64; } // Invoker struct Invoker : public BaseInvoker diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 5784bbd3d9..0ecf70f5b6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -1161,8 +1161,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); - const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled, problem.KPadded); + const auto b_grid_desc_bpreshuffled = MakeBGridDescriptor_Preshuffled( + problem.BN0Shuffled, problem.BK0Shuffled, problem.KPadded); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);