diff --git a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp b/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp index cc107b63dc..b36bd761b3 100644 --- a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp +++ b/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp @@ -14,18 +14,22 @@ using ComputeDataType = float; struct YElementOp { - template - __host__ __device__ void operator()(T& y, const T& x) const + template + __host__ __device__ void operator()(Y& y, const X& x) const { - static_assert(ck::is_same::value || ck::is_same::value || - ck::is_same::value, + static_assert(ck::is_same::value || ck::is_same::value || + ck::is_same::value, "Data type is not supported by this operation!"); - T a; + static_assert(ck::is_same::value || ck::is_same::value || + ck::is_same::value, + "Data type is not supported by this operation!"); + + X a; ck::tensor_operation::element_wise::Sigmoid{}(a, x); - y = x * a; + y = ck::type_convert(x * a); }; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index c90c28f5a8..d98725cf9d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -144,7 +144,8 @@ template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeDataType = EDataType> struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); - using ComputeDataType = EDataType; - // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ADataType, // TODO: distinguish A/B datatype + ADataType, BDataType, ComputeDataType, AccDataType, diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 34ac08b665..69a6540a0c 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -27,6 +27,12 @@ struct PassThrough y = x; } + template <> + __host__ __device__ void operator()(float& y, const double& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(float& y, const float& x) const { @@ -81,6 +87,12 @@ struct PassThrough y = type_convert(x); } + template <> + __host__ __device__ void operator()(int8_t& y, const float& x) const + { + y = type_convert(x); + } + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> __host__ __device__ void operator()(int4_t& y, const int4_t& x) const @@ -416,14 +428,19 @@ struct Swish { Swish(float beta = 1.0f) : beta_(beta) {} - template - __host__ __device__ void operator()(T& y, const T& x) const + template + __host__ __device__ void operator()(Y& y, const X& x) const { - static_assert(is_same::value || is_same::value || - is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value, "Data type is not supported by this operation!"); - y = x / (ck::type_convert(1) + ck::math::exp(-beta_ * x)); + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + float bx = -beta_ * type_convert(x); + y = type_convert(x / (1.f + ck::math::exp(bx))); }; float beta_ = 1.0f; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 605f2569c6..2774214079 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3 constexpr index_t src_offset = src_desc.CalculateOffset( src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_buf[Number{}]); - // apply type convert - dst_vector.template AsType()(i) = type_convert(v); + dst_vector.template AsType()(i) = v; }); const bool is_dst_valid = @@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic constexpr index_t dst_offset = dst_desc.CalculateOffset( dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_buf[Number{}]); // apply type convert - dst_buf(Number{}) = type_convert(v); + dst_buf(Number{}) = v; }); }); } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 309b4afad8..95bd1e13d9 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -20,7 +20,8 @@ template + typename CElementwiseOperation, + typename ComputType = ADataType> struct ReferenceGemm : public device::BaseOperator { // Argument @@ -64,8 +65,8 @@ struct ReferenceGemm : public device::BaseOperator for(int k = 0; k < K; ++k) { - ADataType v_a; - BDataType v_b; + ComputType v_a; + ComputType v_b; // use PassThrough instead of ConvertBF16RTN for reference calculation if constexpr(is_same_v{-1, 1}); break; default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 0.2}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}); d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); }