Optimize bf16 conversion (#664)

* Add TypeConvert class and start refactoring

* Refactor TypeConvert as a struct

* Get back to template functions type_convert

* Add a type_convert_bf16_rtn, set rtz as default

* Clean up

* Add UnaryConvertPrecision struct for high-precision workloads

* Format

* Update type_convert to UnaryConvert on threadwise level

* Update UnaryConvertPrecision

* Format

* Fix chmod

* Add a flag to pick converion method

* Format

* Remove the added flag

* Merge elementwise op with type conversion

* Move type_convert to elemwise op, update the op

* Update type_convert_precision -> bf16_convert_rtn

* Clean up

* Update comments

* Update the CK_WORKAROUND_DENORM_FIX flag handling

* Update the unneeded op to work but warn user

* Remove the message

* Use a PassThrough instead of ConvertBF16RTN to calcaulate reference

* Format

* Add missing include
This commit is contained in:
Rostyslav Geyyer
2023-05-04 10:25:47 -05:00
committed by GitHub
parent b8635a25b2
commit b076a02ad2
8 changed files with 116 additions and 56 deletions

View File

@@ -6,6 +6,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
@@ -207,15 +208,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
// apply SrcElementwiseOperation on src_vector_container
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
SrcData src_v;
src_element_op_(src_v, src_vector_container.template AsType<SrcData>()[i]);
src_vector_container.template AsType<SrcData>()(i) = src_v;
});
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<src_vector_t>(
@@ -318,7 +310,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
// TODO type_convert is not used yet!!!!!
using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
@@ -342,19 +333,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
Number<num_dst_vector>{});
// do data transpose
// TODO type_convert is not used yet!!!!!
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
});
}
else
{
static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here
dst_thread_scratch_(idx) =
type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
});
}
static_ford<SliceLengths>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if needed
DstData dst_v;
src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
dst_thread_scratch_(idx) = dst_v;
});
#endif
}