diff --git a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp index c635da57f4..9ee63312a3 100644 --- a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp @@ -95,7 +95,7 @@ struct GridwiseReduction_xy_to_x_blockwise const auto zeroVal = opReduce::GetReductionZeroVal(); const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); + p_src_global, src2dDesc.GetElementSpaceSize(), type_convert(zeroVal)); auto dst_global_buf = make_dynamic_buffer( p_dst_global, dst1dDesc.GetElementSpaceSize()); @@ -178,11 +178,11 @@ struct GridwiseReduction_xy_to_x_blockwise if(thread_local_id == 0) { if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); + accuValue_buf(I0) *= type_convert(alpha); StaticBuffer dstValue_buf; - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + dstValue_buf(I0) = type_convert(accuValue_buf[I0]); if(!float_equal_zero{}(beta)) { @@ -246,7 +246,7 @@ struct GridwiseReduction_xy_to_x_blockwise const auto zeroVal = opReduce::GetReductionZeroVal(); const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); + p_src_global, src2dDesc.GetElementSpaceSize(), type_convert(zeroVal)); auto dst_global_val_buf = make_dynamic_buffer( p_dst_global, dst1dDesc.GetElementSpaceSize()); auto dst_global_idx_buf = make_dynamic_buffer( @@ -347,11 +347,11 @@ struct GridwiseReduction_xy_to_x_blockwise if(thread_local_id == 0) { if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); + accuValue_buf(I0) *= type_convert(alpha); StaticBuffer dstValue_buf; - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + dstValue_buf(I0) = type_convert(accuValue_buf[I0]); if(!float_equal_zero{}(beta)) { @@ -433,10 +433,8 @@ struct GridwiseReduction_xy_to_x_blockwise const auto zeroVal = opReduce::GetReductionZeroVal(); - const auto src_global_val_buf = - make_dynamic_buffer(ws_values_global, - src2dDesc.GetElementSpaceSize(), - type_convert{}(zeroVal)); + const auto src_global_val_buf = make_dynamic_buffer( + ws_values_global, src2dDesc.GetElementSpaceSize(), type_convert(zeroVal)); const auto src_global_idx_buf = make_dynamic_buffer( ws_indices_global, src2dDesc.GetElementSpaceSize()); auto dst_global_val_buf = make_dynamic_buffer( @@ -553,11 +551,11 @@ struct GridwiseReduction_xy_to_x_blockwise if(thread_local_id == 0) { if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); + accuValue_buf(I0) *= type_convert(alpha); StaticBuffer dstValue_buf; - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + dstValue_buf(I0) = type_convert(accuValue_buf[I0]); if(!float_equal_zero{}(beta)) { diff --git a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp index adfeacc037..1ac24b7eac 100644 --- a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp @@ -85,7 +85,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise const auto zeroVal = opReduce::GetReductionZeroVal(); const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); + p_src_global, src2dDesc.GetElementSpaceSize(), type_convert(zeroVal)); auto dst_global_buf = make_dynamic_buffer( p_dst_global, dst1dDesc.GetElementSpaceSize()); @@ -145,11 +145,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); + accuValue_buf(I0) *= type_convert(alpha); StaticBuffer dstValue_buf; - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + dstValue_buf(I0) = type_convert(accuValue_buf[I0]); if(!float_equal_zero{}(beta)) { @@ -207,7 +207,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise const auto zeroVal = opReduce::GetReductionZeroVal(); const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); + p_src_global, src2dDesc.GetElementSpaceSize(), type_convert(zeroVal)); auto dst_global_val_buf = make_dynamic_buffer( p_dst_global, dst1dDesc.GetElementSpaceSize()); auto dst_global_idx_buf = make_dynamic_buffer( @@ -273,11 +273,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); + accuValue_buf(I0) *= type_convert(alpha); StaticBuffer dstValue_buf; - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + dstValue_buf(I0) = type_convert(accuValue_buf[I0]); if(!float_equal_zero{}(beta)) { @@ -350,10 +350,8 @@ struct GridwiseReduction_xy_to_x_direct_threadwise const auto zeroVal = opReduce::GetReductionZeroVal(); - const auto src_global_val_buf = - make_dynamic_buffer(ws_values_global, - src2dDesc.GetElementSpaceSize(), - type_convert{}(zeroVal)); + const auto src_global_val_buf = make_dynamic_buffer( + ws_values_global, src2dDesc.GetElementSpaceSize(), type_convert(zeroVal)); const auto src_global_idx_buf = make_dynamic_buffer( ws_indices_global, src2dDesc.GetElementSpaceSize()); auto dst_global_val_buf = make_dynamic_buffer( @@ -436,11 +434,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); + accuValue_buf(I0) *= type_convert(alpha); StaticBuffer dstValue_buf; - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + dstValue_buf(I0) = type_convert(accuValue_buf[I0]); if(!float_equal_zero{}(beta)) { diff --git a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp index 4136dae75f..402d4e0d02 100644 --- a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp @@ -85,7 +85,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise const auto zeroVal = opReduce::GetReductionZeroVal(); const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); + p_src_global, src2dDesc.GetElementSpaceSize(), type_convert(zeroVal)); auto dst_global_buf = make_dynamic_buffer( p_dst_global, dst1dDesc.GetElementSpaceSize()); @@ -154,11 +154,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise if(thread_inwarp_id == 0) { if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); + accuValue_buf(I0) *= type_convert(alpha); StaticBuffer dstValue_buf; - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + dstValue_buf(I0) = type_convert(accuValue_buf[I0]); if(!float_equal_zero{}(beta)) { @@ -218,7 +218,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise const auto zeroVal = opReduce::GetReductionZeroVal(); const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); + p_src_global, src2dDesc.GetElementSpaceSize(), type_convert(zeroVal)); auto dst_global_val_buf = make_dynamic_buffer( p_dst_global, dst1dDesc.GetElementSpaceSize()); auto dst_global_idx_buf = make_dynamic_buffer( @@ -293,11 +293,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise if(thread_inwarp_id == 0) { if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); + accuValue_buf(I0) *= type_convert(alpha); StaticBuffer dstValue_buf; - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + dstValue_buf(I0) = type_convert(accuValue_buf[I0]); if(!float_equal_zero{}(beta)) { @@ -375,10 +375,8 @@ struct GridwiseReduction_xy_to_x_direct_warpwise const auto zeroVal = opReduce::GetReductionZeroVal(); - const auto src_global_val_buf = - make_dynamic_buffer(ws_values_global, - src2dDesc.GetElementSpaceSize(), - type_convert{}(zeroVal)); + const auto src_global_val_buf = make_dynamic_buffer( + ws_values_global, src2dDesc.GetElementSpaceSize(), type_convert(zeroVal)); const auto src_global_idx_buf = make_dynamic_buffer( ws_indices_global, src2dDesc.GetElementSpaceSize()); auto dst_global_val_buf = make_dynamic_buffer( @@ -472,11 +470,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise if(thread_inwarp_id == 0) { if(!float_equal_one{}(alpha)) - accuValue_buf(I0) *= type_convert{}(alpha); + accuValue_buf(I0) *= type_convert(alpha); StaticBuffer dstValue_buf; - dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + dstValue_buf(I0) = type_convert(accuValue_buf[I0]); if(!float_equal_zero{}(beta)) { diff --git a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp index feee2b594a..dda2efa884 100644 --- a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp @@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_multiblock __shared__ compType p_in_block_buffer[BlockBufferSize]; const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); + p_src_global, src2dDesc.GetElementSpaceSize(), type_convert(zeroVal)); auto workspace_global_buf = make_dynamic_buffer( ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize); @@ -223,7 +223,7 @@ struct GridwiseReduction_xy_to_x_multiblock __shared__ int p_in_block_indices_buffer[BlockBufferSize]; const auto src_global_buf = make_dynamic_buffer( - p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); + p_src_global, src2dDesc.GetElementSpaceSize(), type_convert(zeroVal)); auto workspace_global_val_buf = make_dynamic_buffer( ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize); auto workspace_global_idx_buf = make_dynamic_buffer( diff --git a/composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp b/composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp index 046d3311aa..ff21118d24 100644 --- a/composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp +++ b/composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp @@ -64,7 +64,7 @@ struct BlockwiseReduction_2d_block_buffer offset = blockIsOneRow ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id)) : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd)); - compType opData = type_convert{}(block_buffer[offset]); + compType opData = type_convert(block_buffer[offset]); binop::calculate(lAccuData, opData); } @@ -89,10 +89,10 @@ struct BlockwiseReduction_2d_block_buffer ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id + indOffset)) : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0)); - compType opData1 = type_convert{}(block_buffer[offset1]); - compType opData2 = type_convert{}(block_buffer[offset2]); + compType opData1 = type_convert(block_buffer[offset1]); + compType opData2 = type_convert(block_buffer[offset2]); binop::calculate(opData1, opData2); - block_buffer(offset1) = type_convert{}(opData1); + block_buffer(offset1) = type_convert(opData1); } __syncthreads(); @@ -100,7 +100,7 @@ struct BlockwiseReduction_2d_block_buffer if(thread_local_id == 0) { - compType tmpVal = type_convert{}(block_buffer[0]); + compType tmpVal = type_convert(block_buffer[0]); binop::calculate(accuData, tmpVal); } @@ -131,13 +131,13 @@ struct BlockwiseReduction_2d_block_buffer index_t offset2 = buffer2dDesc.CalculateOffset( make_tuple(otherDimInd, thread_local_id + indOffset)); - compType currVal1 = type_convert{}(block_buffer[offset1]); - compType currVal2 = type_convert{}(block_buffer[offset2]); + compType currVal1 = type_convert(block_buffer[offset1]); + compType currVal2 = type_convert(block_buffer[offset2]); int currIndex1 = block_indices_buffer[offset1]; int currIndex2 = block_indices_buffer[offset2]; binop::calculate(currVal1, currVal2, currIndex1, currIndex2); - block_buffer(offset1) = type_convert{}(currVal1); + block_buffer(offset1) = type_convert(currVal1); block_indices_buffer(offset1) = currIndex1; } __syncthreads(); @@ -150,7 +150,7 @@ struct BlockwiseReduction_2d_block_buffer { index_t offset = buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, 0)); - compType tmpVal = type_convert{}(block_buffer[offset]); + compType tmpVal = type_convert(block_buffer[offset]); int tmpIndex = block_indices_buffer[offset]; binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex); @@ -166,7 +166,7 @@ struct BlockwiseReduction_2d_block_buffer for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++) { offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd)); - compType currVal = type_convert{}(block_buffer[offset]); + compType currVal = type_convert(block_buffer[offset]); int currIndex = block_indices_buffer[offset]; binop::calculate(lAccuData, currVal, lAccuIndex, currIndex); @@ -187,13 +187,13 @@ struct BlockwiseReduction_2d_block_buffer index_t offset2 = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0)); - compType currVal1 = type_convert{}(block_buffer[offset1]); - compType currVal2 = type_convert{}(block_buffer[offset2]); + compType currVal1 = type_convert(block_buffer[offset1]); + compType currVal2 = type_convert(block_buffer[offset2]); int currIndex1 = block_indices_buffer[offset1]; int currIndex2 = block_indices_buffer[offset2]; binop::calculate(currVal1, currVal2, currIndex1, currIndex2); - block_buffer(offset1) = type_convert{}(currVal1); + block_buffer(offset1) = type_convert(currVal1); block_indices_buffer(offset1) = currIndex1; } @@ -202,7 +202,7 @@ struct BlockwiseReduction_2d_block_buffer if(thread_local_id == 0) { - compType tmpVal = type_convert{}(block_buffer[0]); + compType tmpVal = type_convert(block_buffer[0]); int tmpIndex = block_indices_buffer[0]; binop::calculate(accuData, tmpVal, accuIndex, tmpIndex); @@ -227,9 +227,9 @@ struct BlockwiseReduction_2d_block_buffer } }; - // Initialize the block-wise indices buffer, the index for each element in the block-wise data - // buffer - // is calculated according to its position in the buffer and the global starting index + // Initialize the block-wise indices buffer, the index for each element in the block-wise + // data buffer is calculated according to its position in the buffer and the global starting + // index template __device__ static void init_buffer_indices(IdxBufferType& block_indices_buffer, int indexStart) { diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp index 7e3f6b3489..c02e959461 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp @@ -196,7 +196,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector); dst_vector.template AsType()(i) = - type_convert{}(src_buf[Number{}]); + type_convert(src_buf[Number{}]); }); const bool is_dst_valid = @@ -983,7 +983,7 @@ struct ThreadwiseTensorSliceTransfer_v3 buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector); dst_tmp_vector.template AsType()(i) = - type_convert{}(buffer_[Number{}]); + type_convert(buffer_[Number{}]); }); using dst_vector_t = typename decltype(dst_tmp_vector)::type; @@ -1403,7 +1403,7 @@ struct ThreadwiseTensorSliceTransfer_v4 // TODO: if SrcData and DstData are vetor type, then static_cast may not compile static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { dst_tmp_vector.template AsType()(i) = - type_convert{}(src_tmp_vector.template AsType()[i]); + type_convert(src_tmp_vector.template AsType()[i]); }); // copy data from dst_tmp_vector into dst_buf diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp index bbdaa5fa2b..9d996afbb0 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp @@ -351,7 +351,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 dst_vector_desc.CalculateOffset(dst_vector_idx); dst_vector.template AsType()(Number{}) = - type_convert{}(buffer_[Number{}]); + type_convert(buffer_[Number{}]); }); using dst_vector_t = typename decltype(dst_vector)::type; @@ -750,7 +750,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1 constexpr index_t dst_offset = dst_desc.CalculateOffset( dst_origin_idx + data_to_origin_disp_idx + src_vector_idx); - dst_buf(Number{}) = type_convert{}( + dst_buf(Number{}) = type_convert( src_vector.template AsType()[Number{}]); }); }); diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp index 0a8a385c85..20d0bd1144 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp @@ -248,7 +248,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 #if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE static_ford{}([&](auto idx) { // convert from SrcData to DstData here - dst_thread_scratch_(idx) = type_convert{}(src_thread_scratch_[idx]); + dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); }); #else // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ @@ -322,7 +322,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 { static_ford{}([&](auto idx) { // convert from SrcData to DstData here - dst_thread_scratch_(idx) = type_convert{}(src_thread_scratch_[idx]); + dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); }); } #endif diff --git a/composable_kernel/include/utility/data_type.hpp b/composable_kernel/include/utility/data_type.hpp index 96157bd19d..77b7191907 100644 --- a/composable_kernel/include/utility/data_type.hpp +++ b/composable_kernel/include/utility/data_type.hpp @@ -927,23 +927,36 @@ using int8x16_t = typename vector_type::type; using int8x32_t = typename vector_type::type; using int8x64_t = typename vector_type::type; -__host__ __device__ float bf16_to_f32(ushort src_val) +// Convert X to Y +template +__host__ __device__ Y type_convert(X x) +{ + return static_cast(x); +} + +// convert bfp16 to fp32 +template <> +inline __host__ __device__ float type_convert(ushort x) { union { uint32_t int32; float fp32; - } u = {uint32_t(src_val) << 16}; + } u = {uint32_t(x) << 16}; + return u.fp32; } -__host__ __device__ ushort f32_to_bf16(float src_val) +// convert fp32 to bfp16 +template <> +inline __host__ __device__ ushort type_convert(float x) { union { float fp32; uint32_t int32; - } u = {src_val}; + } u = {x}; + if(~u.int32 & 0x7f800000) { // When the exponent bits are not all 1s, then the value is zero, normal, @@ -976,40 +989,14 @@ __host__ __device__ ushort f32_to_bf16(float src_val) // the bloat16's mantissa bits are all 0. u.int32 |= 0x10000; // Preserve signaling NaN } + return uint16_t(u.int32 >> 16); } -// data type conversion -template -struct type_convert -{ - template - __device__ T operator()(X x) const - { - return static_cast(x); - } -}; - -template <> -template <> -__device__ float type_convert::operator()(ushort x) const -{ - return bf16_to_f32(x); -} - -template <> -template <> -__device__ ushort type_convert::operator()(float x) const -{ - return f32_to_bf16(x); -} - // TODO: deprecate this template struct inner_product_with_conversion { - static constexpr auto convert = type_convert(); - template __device__ T operator()(typename vector_type::type a, typename vector_type::type b) const @@ -1020,13 +1007,16 @@ struct inner_product_with_conversion T acc = 0; static_for<0, N, 1>{}([&](auto i) { - acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]); + acc += type_convert(a_vector.Scalars()[i]) * type_convert(b_vector.Scalars()[i]); }); return acc; } - __device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); } + __device__ T operator()(float_t a, float_t b) const + { + return type_convert(a) * type_convert(b); + } __device__ T operator()(int8x4_t a, int8x4_t b) const { @@ -1036,7 +1026,8 @@ struct inner_product_with_conversion T acc = 0; static_for<0, 4, 1>{}([&](auto i) { - acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + acc += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); }); return acc; @@ -1050,7 +1041,8 @@ struct inner_product_with_conversion T acc = 0; static_for<0, 8, 1>{}([&](auto i) { - acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + acc += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); }); return acc; @@ -1064,7 +1056,8 @@ struct inner_product_with_conversion T acc = 0; static_for<0, 16, 1>{}([&](auto i) { - acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + acc += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); }); return acc; diff --git a/composable_kernel/include/utility/inner_product.hpp b/composable_kernel/include/utility/inner_product.hpp index 813b559474..0b13986516 100644 --- a/composable_kernel/include/utility/inner_product.hpp +++ b/composable_kernel/include/utility/inner_product.hpp @@ -28,12 +28,6 @@ __device__ void inner_product(const float& a, const float& #endif } -template <> -__device__ void inner_product(const ushort& a, const ushort& b, float& c) -{ - c += bf16_to_f32(a) * bf16_to_f32(b); -} - template <> __device__ void inner_product(const float2_t& a, const float2_t& b, float& c) @@ -90,13 +84,12 @@ __device__ void inner_product(const half2_t& a, const h c = __builtin_amdgcn_sdot2(a, b, c, false); #endif #else - const auto convert = type_convert{}; - const vector_type a_vector{a}; const vector_type b_vector{b}; static_for<0, 2, 1>{}([&](auto i) { - c += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + c += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); }); #endif } @@ -156,13 +149,12 @@ inner_product(const int8x4_t& a, const int8x4_t& b, c = __builtin_amdgcn_sdot4(as_type(a), as_type(b), c, false); #endif #else - const auto convert = type_convert{}; - const vector_type a_vector{a}; const vector_type b_vector{b}; static_for<0, 4, 1>{}([&](auto i) { - c += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + c += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); }); #endif } diff --git a/composable_kernel/include/utility/reduction_operator.hpp b/composable_kernel/include/utility/reduction_operator.hpp index c0afbec869..15538b9920 100644 --- a/composable_kernel/include/utility/reduction_operator.hpp +++ b/composable_kernel/include/utility/reduction_operator.hpp @@ -165,7 +165,7 @@ struct unary_identic scaler = 1.0f / static_cast(divider); }; - __device__ inline constexpr T operator()(T a) const { return a * type_convert{}(scaler); }; + __device__ inline constexpr T operator()(T a) const { return a * type_convert(scaler); }; float scaler = 1.0f; }; @@ -187,7 +187,7 @@ struct unary_square { a = a * a; - return a * type_convert{}(scaler); + return a * type_convert(scaler); }; float scaler = 1.0f; @@ -210,7 +210,7 @@ struct unary_abs { a = abs(a); - return a * type_convert{}(scaler); + return a * type_convert(scaler); }; float scaler = 1.0f; @@ -249,7 +249,7 @@ struct unary_abs { a = static_cast(__habs(a)); - return a * type_convert{}(scaler); + return a * type_convert(scaler); }; float scaler = 1.0f; diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index d87195e366..f1ae9dc515 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -82,8 +82,8 @@ void host_convolution_forward(const Tensor& in, { if constexpr(is_same::value) { - v += ck::bf16_to_f32(in(n, c, hi, wi)) * - ck::bf16_to_f32(wei(k, c, y, x)); + v += ck::type_convert(in(n, c, hi, wi)) * + ck::type_convert(wei(k, c, y, x)); } else { @@ -97,7 +97,7 @@ void host_convolution_forward(const Tensor& in, if constexpr(is_same::value) { - out(n, k, ho, wo) = f32_to_bf16(v); + out(n, k, ho, wo) = type_convert(v); } else { @@ -120,8 +120,8 @@ void host_convolution_forward(const Tensor& in, { if constexpr(is_same::value) { - v += ck::bf16_to_f32(in(n, hi, wi, c)) * - ck::bf16_to_f32(wei(k, y, x, c)); + v += ck::type_convert(in(n, hi, wi, c)) * + ck::type_convert(wei(k, y, x, c)); } else { @@ -134,7 +134,7 @@ void host_convolution_forward(const Tensor& in, } if constexpr(is_same::value) { - out(n, ho, wo, k) = f32_to_bf16(v); + out(n, ho, wo, k) = ck::type_convert(v); } else { diff --git a/host/host_tensor/include/host_gemm.hpp b/host/host_tensor/include/host_gemm.hpp index b5dbedd1d0..010091fe1f 100644 --- a/host/host_tensor/include/host_gemm.hpp +++ b/host/host_tensor/include/host_gemm.hpp @@ -1,162 +1,6 @@ #pragma once #include "host_tensor.hpp" -template <> -void host_gemm(const Tensor& a, - const Tensor& b, - Tensor& c, - const GemmMatrixLayout layout) -{ - if(layout == GemmMatrixLayout::MK_KN_MN) - { - auto f_mk_kn_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(k, n)); - } - - c(m, n) = ck::f32_to_bf16(v); - }; - - make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::MK_NK_MN) - { - auto f_mk_nk_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(n, k)); - } - - c(m, n) = ck::f32_to_bf16(v); - }; - - make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_KN_MN) - { - auto f_km_kn_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(k, n)); - } - - c(m, n) = ck::f32_to_bf16(v); - }; - - make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_NK_MN) - { - auto f_km_nk_mn = [&](auto m, auto n) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(n, k)); - } - - c(m, n) = ck::f32_to_bf16(v); - }; - - make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::MK_KN_NM) - { - auto f_mk_kn_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(k, n)); - } - - c(n, m) = ck::f32_to_bf16(v); - }; - - make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::MK_NK_NM) - { - auto f_mk_nk_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(n, k)); - } - - c(n, m) = ck::f32_to_bf16(v); - }; - - make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_KN_NM) - { - auto f_km_kn_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(k, n)); - } - - c(n, m) = ck::f32_to_bf16(v); - }; - - make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else if(layout == GemmMatrixLayout::KM_NK_NM) - { - auto f_km_nk_nm = [&](auto n, auto m) { - const int K = a.mDesc.GetLengths()[0]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(n, k)); - } - - c(n, m) = ck::f32_to_bf16(v); - }; - - make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! not supported layout"); - } -} - template void host_gemm_mk_kn_mn(const Tensor& a_m_k, const Tensor& b_k_n, diff --git a/host/host_tensor/include/host_tensor.hpp b/host/host_tensor/include/host_tensor.hpp index 352ccccde0..ae30426913 100644 --- a/host/host_tensor/include/host_tensor.hpp +++ b/host/host_tensor/include/host_tensor.hpp @@ -299,53 +299,41 @@ HostTensorDescriptor::HostTensorDescriptor(std::vector lens, std::vector s void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout); +float bf16_to_f32_(ushort src_val); + template void check_error(const Tensor& ref, const Tensor& result) { float error = 0; float max_diff = -1; float ref_value = 0, result_value = 0; - for(int i = 0; i < ref.mData.size(); ++i) + + if constexpr(std::is_same::value) { - error += std::abs(double(ref.mData[i]) - double(result.mData[i])); - float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); - if(max_diff < diff) + for(int i = 0; i < ref.mData.size(); ++i) { - max_diff = diff; - ref_value = ref.mData[i]; - result_value = result.mData[i]; + error += std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i])); + float diff = std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i])); + if(max_diff < diff) + { + max_diff = diff; + ref_value = bf16_to_f32_(ref.mData[i]); + result_value = bf16_to_f32_(result.mData[i]); + } } } - - std::cout << "error: " << error << std::endl; - std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl; -} - -__host__ __device__ float bf16_to_f32(ushort src_val) -{ - union + else { - uint32_t int32; - float fp32; - } u = {uint32_t(src_val) << 16}; - return u.fp32; -} - -template <> -void check_error(const Tensor& ref, const Tensor& result) -{ - float error = 0; - float max_diff = -1; - float ref_value = 0, result_value = 0; - for(int i = 0; i < ref.mData.size(); ++i) - { - error += std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i])); - float diff = std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i])); - if(max_diff < diff) + for(int i = 0; i < ref.mData.size(); ++i) { - max_diff = diff; - ref_value = bf16_to_f32(ref.mData[i]); - result_value = bf16_to_f32(result.mData[i]); + error += std::abs(double(ref.mData[i]) - double(result.mData[i])); + float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); + if(max_diff < diff) + { + max_diff = diff; + ref_value = ref.mData[i]; + result_value = result.mData[i]; + } } } diff --git a/host/host_tensor/include/host_tensor_generator.hpp b/host/host_tensor/include/host_tensor_generator.hpp index 7734b7134b..0b979069a6 100644 --- a/host/host_tensor/include/host_tensor_generator.hpp +++ b/host/host_tensor/include/host_tensor_generator.hpp @@ -5,15 +5,25 @@ #include "config.hpp" #include "data_type.hpp" +template +struct GeneratorTensor_0 +{ + template + T operator()(Is...) + { + return T{0}; + } +}; + template struct GeneratorTensor_1 { int value = 1; template - float operator()(Is...) + T operator()(Is...) { - return value; + return ck::type_convert(value); } }; @@ -25,7 +35,7 @@ struct GeneratorTensor_1 template ushort operator()(Is...) { - return ck::f32_to_bf16(value); + return ck::type_convert(value); } }; @@ -41,17 +51,6 @@ struct GeneratorTensor_1 } }; -struct GeneratorTensor_0 -{ - int value = 0; - - template - float operator()(Is...) - { - return value; - } -}; - template struct GeneratorTensor_2 { @@ -59,7 +58,7 @@ struct GeneratorTensor_2 int max_value = 1; template - float operator()(Is...) + T operator()(Is...) { return (std::rand() % (max_value - min_value)) + min_value; } @@ -75,7 +74,7 @@ struct GeneratorTensor_2 ushort operator()(Is...) { float tmp = (std::rand() % (max_value - min_value)) + min_value; - return ck::f32_to_bf16(tmp); + return ck::type_convert(tmp); } }; @@ -99,7 +98,7 @@ struct GeneratorTensor_3 T max_value = 1; template - float operator()(Is...) + T operator()(Is...) { float tmp = float(std::rand()) / float(RAND_MAX); @@ -120,7 +119,7 @@ struct GeneratorTensor_3 float fp32_tmp = min_value + tmp * (max_value - min_value); - return ck::f32_to_bf16(fp32_tmp); + return ck::type_convert(fp32_tmp); } }; diff --git a/host/host_tensor/src/host_tensor.cpp b/host/host_tensor/src/host_tensor.cpp index bb4eb62075..4e3cdbdccd 100644 --- a/host/host_tensor/src/host_tensor.cpp +++ b/host/host_tensor/src/host_tensor.cpp @@ -61,3 +61,13 @@ void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream LogRange(os, desc.GetStrides(), ", "); os << "}" << std::endl; } + +float bf16_to_f32_(ushort src_val) +{ + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(src_val) << 16}; + return u.fp32; +} diff --git a/profiler/include/profile_conv.hpp b/profiler/include/profile_conv.hpp index 755cfddf9d..94fb6373f7 100644 --- a/profiler/include/profile_conv.hpp +++ b/profiler/include/profile_conv.hpp @@ -106,12 +106,12 @@ void profile_conv(int do_verification, { case 0: break; case 1: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } if(do_verification) diff --git a/profiler/include/profile_gemm.hpp b/profiler/include/profile_gemm.hpp index a88468f557..6237588e90 100644 --- a/profiler/include/profile_gemm.hpp +++ b/profiler/include/profile_gemm.hpp @@ -122,12 +122,12 @@ void profile_gemm(int do_verification, { case 0: break; case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } if(do_verification)