diff --git a/example/66_gemm_bias_activation/gemm_bias_add.hpp b/example/66_gemm_bias_activation/gemm_bias_add.hpp index f3183ccb66..467b388c3f 100644 --- a/example/66_gemm_bias_activation/gemm_bias_add.hpp +++ b/example/66_gemm_bias_activation/gemm_bias_add.hpp @@ -6,6 +6,18 @@ #include "ck/ck.hpp" #include "ck/stream_config.hpp" +enum class ActivationType +{ + Gelu = 0, + Relu, + Silu, + Swiglu, + Geglu, + Identity, + GeluNoneApproximate, + GeGluNoneApproximate, + InvalidType +}; struct GemmBiasAddArgs { const void* mat_a; diff --git a/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp b/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp index 3475c6c6c4..040067f80e 100644 --- a/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp +++ b/example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp @@ -38,6 +38,49 @@ using S = ck::Sequence; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +namespace ck { +namespace impl { +template +struct AddActivation +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + Activation{}.template operator()(y, x0 + x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + float x = x0 + type_convert(x1); + Activation{}.template operator()(y, x); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const float& x1) const + { + float result = 0; + Activation{}.template operator()(result, x0 + x1); + y = type_convert(result); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + float result = 0; + Activation{}.template operator()(result, x0 + x1); + y = type_convert(result); + }; +}; +} // namespace impl +} // namespace ck // clang-format off template using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< diff --git a/example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp b/example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp index b5c3bb2a02..704e448319 100644 --- a/example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp +++ b/example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp @@ -13,6 +13,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -38,11 +39,11 @@ using DsLayout = ck::Tuple; using ELayout = Row; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -// using Add = ck::tensor_operation::element_wise::Add; +using Add = ck::tensor_operation::element_wise::Add; using AElementOp = PassThrough; using BElementOp = PassThrough; -using CElementOp = PassThrough; +using CElementOp = Add; using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + PassThrough>; +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} int main(int argc, char* argv[]) { bool do_verification = true; @@ -63,11 +144,6 @@ int main(int argc, char* argv[]) ck::index_t N = 16; ck::index_t K = 64; - ck::index_t StrideA = K; - ck::index_t StrideB = N; - ck::index_t StrideD = 0; - ck::index_t StrideE = N; - if(argc == 1) { // use default case @@ -78,7 +154,7 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); } - else if(argc == 11) + else if(argc == 7) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); @@ -87,21 +163,21 @@ int main(int argc, char* argv[]) M = std::stoi(argv[4]); N = std::stoi(argv[5]); K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideD = std::stoi(argv[9]); - StrideE = std::stoi(argv[10]); } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x)m\n"); exit(0); } + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; @@ -132,12 +208,12 @@ int main(int argc, char* argv[]) case 1: a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d0_m_n.GenerateTensorValue(GeneratorTensor_1{0}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d0_m_n.GenerateTensorValue(GeneratorTensor_1{0}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); @@ -183,13 +259,28 @@ int main(int argc, char* argv[]) auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_argument = ref_gemm.MakeArgument( - a0_m_k, b0_k_n, e_m_n_host_result, AElementOp{}, BElementOp{}, CElementOp{}); + a0_m_k, b0_k_n, e_m_n_host_result, AElementOp{}, BElementOp{}, PassThrough{}); ref_invoker.Run(ref_argument); + CElementOp cde_element_op; + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), e_m_n_host_result(m, n), d0_m_n(m, n)); + } + } + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + return ck::utils::check_err(e_m_n_device_result, + e_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()) + ? 0 + : 1; } return 0; diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index c87c90a91d..8f09fb6350 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -33,7 +33,7 @@ struct Add __host__ __device__ constexpr void operator()(float& y, const float& x0, const half_t& x1) const { - y = x0 + type_convert(x1); + y = x0 + type_convert(x1); }; template <> diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 8079b04b84..071947e4b0 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1077,6 +1077,7 @@ struct ConvScaleRelu float scale_out_; }; + // support fastconvert of int8 to fp16 template