diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp
index 81d58b509b..79aeb03e91 100644
--- a/example/1_gemm_xdl/gemm_xdl.cpp
+++ b/example/1_gemm_xdl/gemm_xdl.cpp
@@ -34,11 +34,11 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;
 // Compilation parameters for NT problem
 // clang-format off
 using DeviceGemmInstance =
-    //#########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
-    //#########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
-    //#########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
-    //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-    ck::tensor_operation::device::DeviceGemmXdl< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
+    //#########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //#########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+    //#########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+    //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    ck::tensor_operation::device::DeviceGemmXdl< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
 // clang-format on
 
 template <...
diff --git a/example/3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp b/example/3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp
@@ ... @@ struct BiasReluAdd
 ...
         float b = a > 0 ? a : 0;
         float c = b + v2;
 
         return c;
@@ -52,70 +52,13 @@ struct BiasReluAdd
     }
 };
 
-// v0 is from A * B
-// v1 is from C0
-// v2 is from C1
-struct BiasLeakyReluAdd
-{
-    template <typename T1, typename T2>
-    __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
-    {
-        float a = v0 + v1;
-        float b = 0.1 * a;
-        float c = b > 0 ? b : 0;
-        float d = c + v2;
-
-        return d;
-    }
-
-    template <typename T1, typename T2>
-    __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
-    {
-        constexpr float alpha     = 0.1;
-        constexpr float alpha_inv = 1.0 / alpha;
-
-        float a = v2 * alpha_inv;
-        float b = v1 + v0;
-        float c = max(b, float(0));
-        float d = alpha * (a + c);
-
-        return d;
-    }
-};
-
-struct BiasLeakyRelu
-{
-    template <typename T1, typename T2>
-    __host__ constexpr float operator()(float v0, T1 v1, T2) const
-    {
-        float a = v0 + v1;
-        float b = 0.1 * a;
-        float c = b > 0 ? b : 0;
-
-        return c;
-    }
-
-    template <typename T1, typename T2>
-    __device__ constexpr float operator()(float v0, T1 v1, T2) const
-    {
-        constexpr float alpha = 0.1;
-
-        float b = v1 + v0;
-        float c = max(b, float(0));
-        float d = alpha * c;
-
-        return d;
-    }
-};
-
-struct BiasAdd
+struct DoSomething
 {
 #if 1
     // correct result
     // no scratch memory, good VGPR allocation (59)
-    // good perf (101Tflops)
-    template <typename T1, typename T2>
-    __host__ __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
+    // good perf (101Tflops @ 1089MHz)
+    __host__ __device__ constexpr float operator()(float v0, ck::half_t v1, ck::half_t v2) const
     {
         constexpr float alpha = 0.1;
         constexpr float beta  = 0.2;
@@ -124,7 +67,7 @@ struct BiasAdd
         // compiler seems very volatile to the order of these calculation:
         // compiler is very eager to read AccVgpr (v0) out prematurely, resulting in register
         // over-allocation. Therefore, move v0 calculation to the very end
-        float a = T1(beta) * v1 + T2(gamma) * v2;
+        float a = ck::half_t(beta) * v1 + ck::half_t(gamma) * v2;
         float b = a + float(alpha) * v0;
 
         return b;
@@ -137,15 +80,14 @@ struct BiasAdd
     // wrong result
     // lots of scratch memory
    // huge perf drop
-    template <typename T1, typename T2>
-    __host__ __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
+    __host__ __device__ constexpr float operator()(float v0, ck::half_t v1, ck::half_t v2) const
     {
         return alpha * v0 + beta * v1 + gamma * v2;
     }
 #elif 0
     // correct result
     // some scratch memory (68 dword)
-    // some perf drop (94Tflops)
+    // some perf drop (94Tflops @ 1089MHz)
     // fp64 instructions are used
     __host__ __device__ constexpr auto operator()(float v0, ck::half_t v1, ck::half_t v2) const
     {
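
The #if / #elif branches of DoSomething above all compute the same epilogue, D = alpha * v0 + beta * v1 + gamma * v2; only the association order differs, and that order is what drives the scratch/VGPR behavior the comments record. A standalone sketch of the two orderings follows (host-only C++; the plain float arguments and gamma = 0.3 are assumptions for illustration, the patch itself uses ck::half_t):

#include <cassert>
#include <cmath>

// Shape of the "#if 1" branch: fold the C0/C1 terms first, consume the
// accumulator last.
struct EpilogueAccLast
{
    float operator()(float v0, float v1, float v2) const
    {
        constexpr float alpha = 0.1f, beta = 0.2f, gamma = 0.3f; // gamma assumed
        float a = beta * v1 + gamma * v2; // C0/C1 contributions only
        return a + alpha * v0;            // v0 (AccVGPR on the device) read at the end
    }
};

// Shape of the single-expression branch: v0 appears in the very first term.
struct EpilogueAccFirst
{
    float operator()(float v0, float v1, float v2) const
    {
        constexpr float alpha = 0.1f, beta = 0.2f, gamma = 0.3f;
        return alpha * v0 + beta * v1 + gamma * v2;
    }
};

int main()
{
    // The two agree up to rounding from reassociation; the register-pressure
    // difference the comments describe only exists in the generated GPU ISA.
    float d0 = EpilogueAccLast{}(1.0f, 2.0f, 3.0f);
    float d1 = EpilogueAccFirst{}(1.0f, 2.0f, 3.0f);
    assert(std::fabs(d0 - d1) < 1e-6f);
    return 0;
}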
@@ -185,16 +127,20 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
 
 using AOp = PassThrough;
 using BOp = PassThrough;
+#if 1
 using COp = BiasReluAdd;
+#else
+using COp = DoSomething;
+#endif
 
 // Compilation parameters for NT problem
 // clang-format off
 using DeviceGemmInstance =
-    //#################################################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
-    //#################################################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
-    //#################################################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
-    //#################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-    ck::tensor_operation::device::DeviceGemmXdl_two_extra_source_reduce< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
+    //#################################################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+    //#################################################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+    //#################################################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+    //#################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    ck::tensor_operation::device::DeviceGemmXdl_two_extra_source_reduce< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
 // clang-format on
 
 template <...
@@ ... @@
 ...& a_m_k,
     auto f_mk_kn_mn = [&](auto m, auto n) {
         const int K = a_m_k.mDesc.GetLengths()[1];
 
-        double v = 0;
+        float acc = 0;
 
         for(int k = 0; k < K; ++k)
         {
-            v += static_cast<double>(a_element_op(a_m_k(m, k))) *
-                 static_cast<double>(b_element_op(b_k_n(k, n)));
+            acc += static_cast<float>(a_element_op(a_m_k(m, k))) *
+                   static_cast<float>(b_element_op(b_k_n(k, n)));
         }
 
-        c_m_n(m, n) = c_element_op(
-            v, static_cast<float>(c0_m_n(m, n)), static_cast<float>(c1_m_n(m, n)));
+        c_m_n(m, n) = c_element_op(acc, c0_m_n(m, n), c1_m_n(m, n));
     };
 
     make_ParallelTensorFunctor(f_mk_kn_mn,
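
With the double accumulator gone, the reference path above is just a triple loop that accumulates in float (matching the device's AccDataType) and applies the same c_element_op once per output element. A CK-free sketch of that check (std::vector in place of Tensor, row-major layouts, and the function name are assumptions):

#include <vector>

template <typename COp>
void reference_gemm_two_source(const std::vector<float>& a,  // M x K, row-major
                               const std::vector<float>& b,  // K x N, row-major
                               const std::vector<float>& c0, // M x N, bias
                               const std::vector<float>& c1, // M x N, residual
                               std::vector<float>& c,        // M x N, output
                               int M, int N, int K, COp c_element_op)
{
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
        {
            float acc = 0; // same accumulator type as the device path
            for(int k = 0; k < K; ++k)
            {
                acc += a[m * K + k] * b[k * N + n];
            }
            // fused epilogue applied exactly once per output element
            c[m * N + n] = c_element_op(acc, c0[m * N + n], c1[m * N + n]);
        }
    }
}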
@@ -249,9 +194,9 @@ int main(int argc, char* argv[])
 
     if(argc == 4)
     {
-        M = std::stoi(argv[4]);
-        N = std::stoi(argv[5]);
-        K = std::stoi(argv[6]);
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        nrepeat         = std::stoi(argv[3]);
     }
     else if(argc == 10)
     {
@@ -337,7 +282,9 @@ int main(int argc, char* argv[])
     c0_m_n_device_buf.ToDevice(c0_m_n.mData.data());
     c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());
 
-    auto c_element_op = BiasReluAdd{};
+    auto a_element_op = AOp{};
+    auto b_element_op = BOp{};
+    auto c_element_op = COp{};
 
     // do GEMM
     auto gemm = DeviceGemmInstance{};
@@ -354,8 +301,8 @@ int main(int argc, char* argv[])
                                       StrideA,
                                       StrideB,
                                       StrideC,
-                                      PassThrough{},
-                                      PassThrough{},
+                                      a_element_op,
+                                      b_element_op,
                                       c_element_op);
 
     if(!gemm.IsSupportedArgument(argument))
diff --git a/example/3_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp b/example/3_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp
index 1948d80584..ce8ea79bd6 100644
--- a/example/3_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp
+++ b/example/3_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp
@@ -35,24 +35,22 @@ template
 ...
-          ck::index_t CThreadTransferDstScalarPerVector,
-          bool ABlockLdsAddExtraM,
-          bool BBlockLdsAddExtraN>
+          ck::index_t CThreadTransferDstScalarPerVector>
 struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
 {
     static constexpr auto I0 = Number<0>{};
@@ -137,45 +135,6 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
     using C1GridDesc_M_N =
         decltype(make_naive_tensor_descriptor(make_tuple(1, 1), make_tuple(I1, I0)));
 
-    // TODO remove these hacks
-    static constexpr auto a_k0_m_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
-                              Sequence<0, 0, 0>{},   // 1+: M
-                              Sequence<0, 0, 0>{}),  // 2+: K1
-                   make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
-                              Sequence<0, 0, 0>{},   // 1-: M
-                              Sequence<0, 0, 0>{})); // 2-: K1
-
-    static constexpr auto b_k0_n_k1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{},   // 0+: K0
-                              Sequence<0, 0, 0>{},   // 1+: N
-                              Sequence<0, 0, 0>{}),  // 2+: K1
-                   make_tuple(Sequence<0, 0, 0>{},   // 0-: K0
-                              Sequence<0, 0, 0>{},   // 1-: N
-                              Sequence<0, 0, 0>{})); // 2-: K1
-
-    static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
-
-    static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
-
-    static constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
-
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5<
         BlockSize,
@@ -199,7 +158,6 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
         K1,
         MXdlPerWave,
         NXdlPerWave,
-        ABlockTransferThreadSliceLengths_K0_M_K1,
         ABlockTransferThreadClusterLengths_K0_M_K1,
         ABlockTransferThreadClusterArrangeOrder,
         ABlockTransferSrcAccessOrder,
@@ -207,25 +165,18 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
         ABlockTransferSrcScalarPerVector,
         ABlockTransferDstScalarPerVector_K1,
         false, // AThreadTransferSrcResetCoordinateAfterRun,
-        BBlockTransferThreadSliceLengths_K0_N_K1,
+        ABlockLdsAddExtraM,
         BBlockTransferThreadClusterLengths_K0_N_K1,
         BBlockTransferThreadClusterArrangeOrder,
         BBlockTransferSrcAccessOrder,
         BBlockTransferSrcVectorDim,
         BBlockTransferSrcScalarPerVector,
         BBlockTransferDstScalarPerVector_K1,
-        false, // BThreadTransferSrcResetCoordinateAfterRun,
+        false, // BThreadTransferSrcResetCoordinateAfterRun,
+        BBlockLdsAddExtraN,
         Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
         CThreadTransferSrcDstVectorDim,
-        CThreadTransferDstScalarPerVector,
-        decltype(a_k0_m_k1_grid_step_hacks), // AGridStepHacks,
-        decltype(b_k0_n_k1_grid_step_hacks), // BGridStepHacks,
-        decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), // CGridStepHacks,
-        decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks,
-        decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks,
-        false, // CAccessOrderMRepeatNRepeat,
-        ABlockLdsAddExtraM,
-        BBlockLdsAddExtraN>;
+        CThreadTransferDstScalarPerVector>;
 
     using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
         decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
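
Taken together, the cleaned-up device op still computes, per output element, C(m, n) = c_element_op(acc, C0(m, n), C1(m, n)), where acc is the float partial sum of A*B; with COp = BiasReluAdd that is a fused bias + ReLU + residual add. A scalar model of that contract (the names and values here are illustrative, not CK's API):

#include <algorithm>
#include <cstdio>

// Scalar stand-in for the fused epilogue: relu(acc + bias) + residual.
struct BiasReluAddModel
{
    float operator()(float acc, float bias, float residual) const
    {
        float a = acc + bias;        // v0 + v1 (A*B plus C0)
        float b = std::max(a, 0.0f); // ReLU
        return b + residual;         // + v2 (C1)
    }
};

int main()
{
    // acc = -1.5 from A*B, bias = 0.5 from C0, residual = 2.0 from C1:
    // relu(-1.0) + 2.0 == 2.0
    std::printf("%g\n", BiasReluAddModel{}(-1.5f, 0.5f, 2.0f));
    return 0;
}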