mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
Optimize bf16 conversion (#664)
* Add TypeConvert class and start refactoring * Refactor TypeConvert as a struct * Get back to template functions type_convert * Add a type_convert_bf16_rtn, set rtz as default * Clean up * Add UnaryConvertPrecision struct for high-precision workloads * Format * Update type_convert to UnaryConvert on threadwise level * Update UnaryConvertPrecision * Format * Fix chmod * Add a flag to pick converion method * Format * Remove the added flag * Merge elementwise op with type conversion * Move type_convert to elemwise op, update the op * Update type_convert_precision -> bf16_convert_rtn * Clean up * Update comments * Update the CK_WORKAROUND_DENORM_FIX flag handling * Update the unneeded op to work but warn user * Remove the message * Use a PassThrough instead of ConvertBF16RTN to calcaulate reference * Format * Add missing include
This commit is contained in:
@@ -56,6 +56,12 @@ struct PassThrough
|
||||
y = type_convert<bhalf_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
|
||||
{
|
||||
y = type_convert<bhalf_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
|
||||
{
|
||||
@@ -86,6 +92,23 @@ struct UnaryConvert
|
||||
}
|
||||
};
|
||||
|
||||
// Elementwise operation converting its input to bf16 using
// round-to-nearest (RTN) rounding via bf16_convert_rtn.
// Supported: destination bhalf_t only; source float or half_t only
// (enforced at compile time).
struct ConvertBF16RTN
{
    template <typename DstType, typename SrcType>
    __host__ __device__ void operator()(DstType& y, const SrcType& x) const
    {
        // Destination type check: only bf16 output is meaningful here.
        static_assert(is_same<DstType, bhalf_t>::value,
                      "Data type is not supported by this operation!");

        // Source type check: RTN conversion is provided for fp32 and fp16 inputs.
        static_assert(is_same<SrcType, float>::value || is_same<SrcType, half_t>::value,
                      "Data type is not supported by this operation!");

        // Perform the conversion with round-to-nearest rounding.
        y = bf16_convert_rtn<DstType>(x);
    }
};
|
||||
|
||||
struct Scale
|
||||
{
|
||||
__host__ __device__ Scale(float scale) : scale_(scale) {}
|
||||
|
||||
@@ -96,7 +96,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
// when mfma if fixed, remove this section and update
|
||||
// ABDataTypeAdjusted -> ABDataType throughout this file
|
||||
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
|
||||
#if CK_WORKAROUND_DENORM_FIX
|
||||
using ABDataTypeAdjusted =
|
||||
conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
|
||||
#else
|
||||
|
||||
@@ -266,7 +266,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
// when mfma if fixed, remove this section and update
|
||||
// FloatABAdjusted -> FloatAB throughout this file
|
||||
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
|
||||
#if CK_WORKAROUND_DENORM_FIX
|
||||
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
|
||||
#else
|
||||
using FloatABAdjusted = FloatAB;
|
||||
|
||||
@@ -136,7 +136,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
// when mfma if fixed, remove this section and update
|
||||
// FloatABAdjusted -> FloatAB throughout this file
|
||||
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
|
||||
#if CK_WORKAROUND_DENORM_FIX
|
||||
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
|
||||
#else
|
||||
using FloatABAdjusted = FloatAB;
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor/static_tensor.hpp"
|
||||
|
||||
@@ -207,15 +208,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
auto src_vector_container = src_vector_type{
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
|
||||
|
||||
// apply SrcElementwiseOperation on src_vector_container
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
SrcData src_v;
|
||||
|
||||
src_element_op_(src_v, src_vector_container.template AsType<SrcData>()[i]);
|
||||
|
||||
src_vector_container.template AsType<SrcData>()(i) = src_v;
|
||||
});
|
||||
|
||||
// copy data from src_vector_container into src_thread_scratch_
|
||||
src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<src_vector_t>(
|
||||
@@ -318,7 +310,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
constexpr auto data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
|
||||
|
||||
// TODO type_convert is not used yet!!!!!
|
||||
using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
|
||||
|
||||
@@ -342,19 +333,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
Number<num_dst_vector>{});
|
||||
|
||||
// do data transpose
|
||||
// TODO type_convert is not used yet!!!!!
|
||||
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
src_vector_refs, dst_vector_refs);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
// convert from SrcData to DstData here
|
||||
dst_thread_scratch_(idx) =
|
||||
type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
|
||||
});
|
||||
}
|
||||
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
// apply the src elementwise op and convert to DstData under the hood if needed
|
||||
DstData dst_v;
|
||||
src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
|
||||
dst_thread_scratch_(idx) = dst_v;
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user