diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index cb20ea2492..1626597ed2 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -175,7 +175,10 @@ // denorm test fix, required to work around dissue #ifndef CK_WORKAROUND_DENORM_FIX #define CK_WORKAROUND_DENORM_FIX 0 -#endif +#elif +// enable only on MI200 +#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) +#endif // CK_WORKAROUND_DENORM_FIX namespace ck { diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 2987def02a..ef250b8bfd 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -56,6 +56,12 @@ struct PassThrough y = type_convert(x); } + template <> + __host__ __device__ void operator()(bhalf_t& y, const half_t& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { @@ -86,6 +92,23 @@ struct UnaryConvert } }; +struct ConvertBF16RTN +{ + // convert to bf16 using round to nearest (rtn) + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(is_same::value, "Data type is not supported by this operation!"); + + // check X datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = bf16_convert_rtn(x); + } +}; + struct Scale { __host__ __device__ Scale(float scale) : scale_(scale) {} diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 98a71a7c24..ec1cc53991 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -96,7 +96,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // we convert fp16->fp32->bf16 and execute bf16 mfma instruction // when mfma if fixed, remove this section and update // ABDataTypeAdjusted -> ABDataType throughout this file -#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) +#if CK_WORKAROUND_DENORM_FIX using ABDataTypeAdjusted = conditional_t, ck::bhalf_t, ABDataType>; #else diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 1979331d07..da7ad1cacf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -266,7 +266,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // we convert fp16->fp32->bf16 and execute bf16 mfma instruction // when mfma if fixed, remove this section and update // FloatABAdjusted -> FloatAB throughout this file -#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) +#if CK_WORKAROUND_DENORM_FIX using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; #else using FloatABAdjusted = FloatAB; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 775b77118c..f4504a9402 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -136,7 +136,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // we convert fp16->fp32->bf16 and execute bf16 mfma instruction // when mfma if fixed, remove this section and update // FloatABAdjusted -> FloatAB throughout this file -#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) +#if CK_WORKAROUND_DENORM_FIX using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; #else using FloatABAdjusted = FloatAB; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index cba06f8e87..6665d765f8 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -6,6 +6,7 @@ #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor/static_tensor.hpp" @@ -207,15 +208,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 auto src_vector_container = src_vector_type{ src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; - // apply SrcElementwiseOperation on src_vector_container - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - SrcData src_v; - - src_element_op_(src_v, src_vector_container.template AsType()[i]); - - src_vector_container.template AsType()(i) = src_v; - }); - // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) .template SetAsType( @@ -318,7 +310,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); - // TODO type_convert is not used yet!!!!! using src_vector_t = vector_type_maker_t; using dst_vector_t = vector_type_maker_t; @@ -342,19 +333,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1 Number{}); // do data transpose - // TODO type_convert is not used yet!!!!! transpose_vectors{}( src_vector_refs, dst_vector_refs); }); } - else - { - static_ford{}([&](auto idx) { - // convert from SrcData to DstData here - dst_thread_scratch_(idx) = - type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); - }); - } + + static_ford{}([&](auto idx) { + // apply the src elementwise op and convert to DstData under the hood if needed + DstData dst_v; + src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]); + dst_thread_scratch_(idx) = dst_v; + }); #endif } diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 101061191e..d43af8a2e3 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -976,37 +976,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert(float uint32_t int32; } u = {x}; - // When the exponent bits are not all 1s, then the value is zero, normal, - // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus - // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). - // This causes the bfloat16's mantissa to be incremented by 1 if the 16 - // least significant bits of the float mantissa are greater than 0x8000, - // or if they are equal to 0x8000 and the least significant bit of the - // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when - // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already - // has the value 0x7f, then incrementing it causes it to become 0x00 and - // the exponent is incremented by one, which is the next higher FP value - // to the unrounded bfloat16 value. When the bfloat16 value is subnormal - // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up - // to a normal value with an exponent of 0x01 and a mantissa of 0x00. - // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, - // incrementing it causes it to become an exponent of 0xFF and a mantissa - // of 0x00, which is Inf, the next higher value to the unrounded value. - bool flag0 = ~u.int32 & 0x7f800000; - - // When all of the exponent bits are 1, the value is Inf or NaN. - // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero - // mantissa bit. Quiet NaN is indicated by the most significant mantissa - // bit being 1. Signaling NaN is indicated by the most significant - // mantissa bit being 0 but some other bit(s) being 1. If any of the - // lower 16 bits of the mantissa are 1, we set the least significant bit - // of the bfloat16 mantissa, in order to preserve signaling NaN in case - // the bfloat16's mantissa bits are all 0. - bool flag1 = !flag0 && (u.int32 & 0xffff); - - u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even - u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN - return uint16_t(u.int32 >> 16); } @@ -1064,6 +1033,63 @@ inline __host__ __device__ constexpr bhalf_t type_convert(int8_ return type_convert(x_fp32); } +// Declare a template function for bf16 conversion using RTN +template +__host__ __device__ constexpr Y bf16_convert_rtn(X x); + +// Convert fp32 to bf16 with RTN if higher precision is needed +template <> +inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(float x) +{ + union + { + float fp32; + uint32_t int32; + } u = {x}; + + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + bool flag0 = ~u.int32 & 0x7f800000; + + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bfloat16's mantissa bits are all 0. + bool flag1 = !flag0 && (u.int32 & 0xffff); + + u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even + u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN + + return uint16_t(u.int32 >> 16); +} + +// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed +template <> +inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(half_t x) +{ + float x_fp32 = static_cast(x); + + return bf16_convert_rtn(x_fp32); +} + template struct NumericLimits { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 6728bb1f47..be69f297b2 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -6,6 +6,7 @@ #include #include +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -66,8 +67,26 @@ struct ReferenceGemm : public device::BaseOperator ADataType v_a; BDataType v_b; - arg.a_element_op_(v_a, arg.a_m_k_(m, k)); - arg.b_element_op_(v_b, arg.b_k_n_(k, n)); + // use PassThrough instead of ConvertBF16RTN for reference calculation + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k)); + } + else + { + arg.a_element_op_(v_a, arg.a_m_k_(m, k)); + } + // same for B matrix + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n)); + } + else + { + arg.b_element_op_(v_b, arg.b_k_n_(k, n)); + } v_acc += ck::type_convert(v_a) * ck::type_convert(v_b);