diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp index 951e6d146f..ce49dea098 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp @@ -138,6 +138,48 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle LDSTypeB>; using Argument = typename GridwiseGemm::Argument; + struct DeviceArgument : public Argument + { + __host__ DeviceArgument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + index_t Nr_, + index_t Kr_) + : Argument{p_a_grid_, + p_b_grid_, + p_ds_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideDs_, + StrideC_, + k_batch_, + a_element_op_, + b_element_op_, + c_element_op_}, + Nr{Nr_}, + Kr{Kr_} + { + } + + index_t Nr; + index_t Kr; + }; int GetPreShuffleParameters() override { return NPerXDL; } @@ -540,7 +582,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle { return false; } + const auto karg = dynamic_cast(&arg); + if(NPadding && (karg->Nr != GridwiseGemm::CalculateBNShufflePadded(arg.N))) + { + return false; + } + if(KPadding && (karg->Kr != GridwiseGemm::CalculateBKShufflePadded(arg.K))) + { + return false; + } return GridwiseGemm::CheckValidity(arg); } @@ -568,23 +619,23 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle index_t Nr, index_t Kr) { - return Argument{static_cast(p_a), - static_cast(p_b), - p_ds, - static_cast(p_c), - M, - N, - K, - StrideA, - StrideB, - StrideDs, - StrideC, - KBatch, - a_element_op, - b_element_op, - c_element_op, - Nr, - Kr}; + return DeviceArgument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op, + Nr, + Kr}; } static auto MakeInvoker() { return Invoker{}; } @@ -608,23 +659,23 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle index_t Nr, index_t Kr) override { - return std::make_unique(static_cast(p_a), - static_cast(p_b), - p_ds, - static_cast(p_c), - M, - N, - K, - StrideA, - StrideB, - StrideDs, - StrideC, - KBatch, - a_element_op, - b_element_op, - c_element_op, - Nr, - Kr); + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op, + Nr, + Kr); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index c9c314b7b4..7b942eca0d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -596,9 +596,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle index_t StrideB_, std::array StrideDs_, index_t StrideC_, - index_t KBatch_, - index_t Nr_, - index_t Kr_) + index_t KBatch_) : M{M_}, N{N_}, K{K_}, @@ -616,9 +614,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle MBlock{CalculateMBlock(M_)}, NBlock{CalculateNBlock(N_)}, BN0Shuffled{CalculateBN0Shuffled(NPadding ? CalculateBNShufflePadded(N_) : N_)}, - BK0Shuffled{CalculateBK0Shuffled(KPadding ? CalculateBKShufflePadded(K_) : K_)}, - Nr{Nr_}, - Kr{Kr_} + BK0Shuffled{CalculateBK0Shuffled(KPadding ? CalculateBKShufflePadded(K_) : K_)} { } @@ -660,8 +656,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle // FOR PRESHUFFLE ONLY index_t BN0Shuffled; index_t BK0Shuffled; - index_t Nr; - index_t Kr; }; // Argument @@ -681,10 +675,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, - CElementwiseOperation c_element_op_, - index_t Nr_, - index_t Kr_) - : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_, Nr_, Kr_}, + CElementwiseOperation c_element_op_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, p_a_grid{p_a_grid_}, p_b_grid{p_b_grid_}, p_ds_grid{}, @@ -952,16 +944,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle return false; } - if(NPadding && (karg.Nr != CalculateBNShufflePadded(karg.N))) - { - return false; - } - - if(KPadding && (karg.Kr != CalculateBKShufflePadded(karg.K))) - { - return false; - } - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || diff --git a/profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp b/profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp index 5e95be1f75..605c247694 100644 --- a/profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp +++ b/profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp @@ -138,6 +138,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; std::cout << "rotating count: " << rotating_count << std::endl; + std::cout << "verification: " << do_verification << std::endl; switch(init_method) { @@ -325,7 +326,10 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, << std::endl; } } - + if(!pass) + { + continue; + } std::string op_name = op_ptr->GetTypeString(); float ave_time = invoker_ptr->Run(argument_ptr.get(),