mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-25 07:14:37 +00:00
bf16A_Int8B with fastgelu/bias (#1264)
* changed the copy function to v7r2
* adding multi_abd
* in-progress
* add post-load oob check
* debugging
* adjust instances
* add run_lds
* add elemntwise_op
* replace multi_abd_device with v3
* clean up
* clean
* clean
* Added LDSType
* profiling
* adjust oobcheck
* add missing file
* refactor
* clean
* add examples
[ROCm/composable_kernel commit: 0d0150db20]
This commit is contained in:
@@ -41,7 +41,8 @@ template <typename ThreadGroup,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector,
|
||||
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
typename ThreadTransferDstResetCoordinateAfterRunFlags>
|
||||
typename ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadGroupTensorSliceTransfer_v7r2
|
||||
{
|
||||
static constexpr index_t nDim =
|
||||
@@ -100,7 +101,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
|
||||
@@ -117,29 +118,33 @@ struct ThreadGroupTensorSliceTransfer_v7r2
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffers>
|
||||
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
|
||||
template <typename SrcBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_descs, src_bufs);
|
||||
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
template <typename DstBuffers>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
|
||||
template <typename DstBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
|
||||
threadwise_transfer_.RunWrite(dst_descs, dst_bufs);
|
||||
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
|
||||
else
|
||||
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs));
|
||||
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,7 +211,8 @@ struct ThreadGroupTensorSliceTransfer_v7r2
|
||||
SrcScalarPerVector,
|
||||
DstScalarPerVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
ThreadTransferDstResetCoordinateAfterRunFlags>;
|
||||
ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
NumThreadScratch>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -92,15 +92,6 @@ struct Add
|
||||
};
|
||||
};
|
||||
|
||||
struct Scales
|
||||
{
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
|
||||
{
|
||||
y = ck::type_convert<Y>(ck::type_convert<float>(x0) * ck::type_convert<float>(x1));
|
||||
}
|
||||
};
|
||||
|
||||
struct Max
|
||||
{
|
||||
template <typename Y, typename X0, typename X1>
|
||||
@@ -188,6 +179,16 @@ struct Multiply
|
||||
y = ck::type_convert<bhalf_t>(y_tmp);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t>(bhalf_t& y, const int8_t& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float x1_tmp = ck::type_convert<float>(x0);
|
||||
const float x2_tmp = ck::type_convert<float>(x1);
|
||||
const float y_tmp = x1_tmp * x2_tmp;
|
||||
y = ck::type_convert<bhalf_t>(y_tmp);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
|
||||
@@ -521,6 +522,71 @@ struct AddFastGelu
|
||||
}
|
||||
};
|
||||
|
||||
// E = MultiplyFastGelu(C + D)
|
||||
struct MultiplyFastGelu
|
||||
{
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float, float, float>(float& e, const float& c, const float& d) const
|
||||
{
|
||||
const float x = c * d;
|
||||
|
||||
FastGelu{}.template operator()<float, float>(e, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
|
||||
{
|
||||
const half_t x = c * d;
|
||||
|
||||
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
|
||||
{
|
||||
const float x0_f = c * d;
|
||||
|
||||
float x1_f = 0;
|
||||
|
||||
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
|
||||
x0_f);
|
||||
|
||||
e = type_convert<half_t>(x1_f);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
|
||||
{
|
||||
const float x0_f = type_convert<float>(c) * type_convert<float>(d);
|
||||
|
||||
float x1_f = 0;
|
||||
|
||||
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
|
||||
|
||||
e = type_convert<bhalf_t>(x1_f);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
|
||||
{
|
||||
const float x0_f = c * type_convert<float>(d);
|
||||
|
||||
float x1_f = 0;
|
||||
|
||||
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
|
||||
|
||||
e = type_convert<bhalf_t>(x1_f);
|
||||
}
|
||||
};
|
||||
|
||||
// E = Silu(C + D)
|
||||
struct AddSilu
|
||||
{
|
||||
|
||||
@@ -221,6 +221,15 @@ struct MultiplyAdd
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
|
||||
const float& c,
|
||||
const bhalf_t& d0,
|
||||
const bhalf_t& d1) const
|
||||
{
|
||||
const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
|
||||
const float& c,
|
||||
const half_t& d0,
|
||||
@@ -240,6 +249,26 @@ struct MultiplyAdd
|
||||
}
|
||||
};
|
||||
|
||||
struct MultiplyAddFastGelu
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::bhalf_t, float, ck::bhalf_t, ck::bhalf_t>(
|
||||
ck::bhalf_t& e, const float& c, const ck::bhalf_t& d0, const ck::bhalf_t& d1) const
|
||||
{
|
||||
const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
|
||||
|
||||
float x1_f = 0;
|
||||
|
||||
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
|
||||
|
||||
e = ck::type_convert<ck::bhalf_t>(x1_f);
|
||||
}
|
||||
};
|
||||
|
||||
// E = FastGelu(C + D0 + D1)
|
||||
struct AddAddFastGelu
|
||||
{
|
||||
|
||||
@@ -504,6 +504,16 @@ struct FastGelu
|
||||
y = type_convert<half_t>(y_f);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
float y_f;
|
||||
|
||||
this->operator()<float, float>(y_f, x);
|
||||
|
||||
y = type_convert<bhalf_t>(y_f);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
|
||||
@@ -594,11 +594,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
|
||||
Number<NumATensor>{});
|
||||
|
||||
#if 0
|
||||
static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
|
||||
"Src and Dst ScalarPerVector must be the same");
|
||||
#endif
|
||||
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
|
||||
ThisThreadBlock,
|
||||
AsDataType,
|
||||
@@ -616,7 +611,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
uniform_sequence_gen_t<NumATensor, false>,
|
||||
uniform_sequence_gen_t<NumATensor, AThreadTransferSrcResetCoordinateAfterRun>,
|
||||
Sequence<true>>{as_grid_desc_ak0_m_ak1,
|
||||
idx_as_block_begin,
|
||||
tie(a_block_desc_ak0_m_ak1),
|
||||
@@ -627,11 +622,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
|
||||
Number<NumBTensor>{});
|
||||
|
||||
#if 0
|
||||
static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
|
||||
"Src and Dst ScalarPerVector must be the same");
|
||||
#endif
|
||||
|
||||
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
|
||||
ThisThreadBlock,
|
||||
BsDataType,
|
||||
@@ -649,7 +639,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
uniform_sequence_gen_t<NumBTensor, false>,
|
||||
uniform_sequence_gen_t<NumBTensor, BThreadTransferSrcResetCoordinateAfterRun>,
|
||||
Sequence<true>>{bs_grid_desc_bk0_n_bk1,
|
||||
idx_bs_block_begin,
|
||||
tie(b_block_desc_bk0_n_bk1),
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -42,7 +42,8 @@ template <typename SrcDatas,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector,
|
||||
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
|
||||
typename DstResetCoordinateAfterRunFlags> // Sequence<bool ...>
|
||||
typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -139,14 +140,19 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
|
||||
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
|
||||
template <typename SrcBuffers,
|
||||
index_t ThreadScratchId = 0,
|
||||
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
|
||||
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// loop over space-filling curve
|
||||
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
|
||||
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
|
||||
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
|
||||
|
||||
bool oob_val = true;
|
||||
|
||||
// copy data from src_bufs into src_vectors
|
||||
static_for<0, nSrc, 1>{}([&](auto i) {
|
||||
using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
|
||||
@@ -155,9 +161,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
|
||||
src_coords_[i]);
|
||||
|
||||
oob_val = oob_val & is_src_valid;
|
||||
|
||||
src_vectors(i).template AsType<src_vector_t>()(I0) =
|
||||
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(),
|
||||
is_src_valid);
|
||||
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
|
||||
});
|
||||
|
||||
constexpr auto get_elem_op_vec_len = []() {
|
||||
@@ -218,7 +225,8 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
unpack2(element_op_, dst_data_refs, src_data_refs);
|
||||
});
|
||||
|
||||
elm_vectors_tuple_(iAccess) = elm_vectors;
|
||||
elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
|
||||
oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val;
|
||||
|
||||
// move coordinate
|
||||
if constexpr(iAccess.value != src_num_access - 1)
|
||||
@@ -245,17 +253,38 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
});
|
||||
}
|
||||
|
||||
__device__ void TransposeFromElmToDst()
|
||||
#if 1
|
||||
template <index_t ThreadScratchId = 0>
|
||||
__device__ void OOBCheck(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// loop over space-filling curve
|
||||
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
|
||||
auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
using elm_vector_t = typename remove_cvref_t<decltype(elm_vectors[i])>::type;
|
||||
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
|
||||
oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[I0] : elm_vector_t{0};
|
||||
});
|
||||
|
||||
elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
template <index_t ThreadScratchId = 0>
|
||||
__device__ void
|
||||
TransposeFromElmToDst(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
|
||||
|
||||
using SrcThreadScratch =
|
||||
using ElmThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
SrcScalarPerVector,
|
||||
decltype(GetSrcThreadScratchDescriptor()),
|
||||
true>;
|
||||
|
||||
using DstThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
@@ -263,15 +292,17 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
decltype(GetDstThreadScratchDescriptor()),
|
||||
true>;
|
||||
|
||||
SrcThreadScratch elm_thread_scratch_;
|
||||
ElmThreadScratch elm_thread_scratch_;
|
||||
DstThreadScratch dst_thread_scratch_;
|
||||
|
||||
elm_thread_scratch_.data_ =
|
||||
bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_);
|
||||
bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_[thread_scratch_id]);
|
||||
|
||||
if constexpr(SrcVectorDim != DstVectorDim &&
|
||||
((is_same<half_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
|
||||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
|
||||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
|
||||
{
|
||||
@@ -338,20 +369,24 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
[&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
|
||||
}
|
||||
|
||||
dst_vectors_tuple_ = bit_cast<decltype(dst_vectors_tuple_)>(dst_thread_scratch_.data_);
|
||||
dst_vectors_tuple_(thread_scratch_id) = bit_cast<DstVectorTuple>(dst_thread_scratch_.data_);
|
||||
}
|
||||
|
||||
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
|
||||
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
|
||||
template <typename DstBuffers,
|
||||
index_t ThreadScratchId = 0,
|
||||
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
TransposeFromElmToDst();
|
||||
OOBCheck(thread_scratch_id);
|
||||
TransposeFromElmToDst(thread_scratch_id);
|
||||
|
||||
// loop over space-filling curve
|
||||
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
|
||||
auto dst_vectors = dst_vectors_tuple_[Number<iAccess>{}];
|
||||
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
|
||||
// copy data from buf_vectors into dst_bufs
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
@@ -578,8 +613,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
|
||||
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
StaticallyIndexedArray<ElmVectorsType, src_num_access> elm_vectors_tuple_;
|
||||
StaticallyIndexedArray<DstVectorsType, dst_num_access> dst_vectors_tuple_;
|
||||
using ElmVectorTuple = StaticallyIndexedArray<ElmVectorsType, src_num_access>;
|
||||
using DstVectorTuple = StaticallyIndexedArray<DstVectorsType, dst_num_access>;
|
||||
|
||||
StaticallyIndexedArray<ElmVectorTuple, NumThreadScratch> elm_vectors_tuple_;
|
||||
StaticallyIndexedArray<DstVectorTuple, NumThreadScratch> dst_vectors_tuple_;
|
||||
|
||||
using OOBVectorTuple = StaticallyIndexedArray<bool, src_num_access>;
|
||||
StaticallyIndexedArray<OOBVectorTuple, NumThreadScratch> oob_vectors_tuple_;
|
||||
|
||||
SrcCoords src_coords_;
|
||||
DstCoords dst_coords_;
|
||||
|
||||
@@ -40,23 +40,10 @@ inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
|
||||
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
|
||||
__host__ __device__ constexpr Y bit_cast(const X& x)
|
||||
{
|
||||
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
|
||||
Y y;
|
||||
static_assert(__has_builtin(__builtin_bit_cast), "");
|
||||
static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
|
||||
|
||||
// auto t = reinterpret_cast<const Y*>(&x);
|
||||
// y = *t;
|
||||
__builtin_memcpy(&y, &x, sizeof(X));
|
||||
|
||||
return y;
|
||||
#else
|
||||
union AsType
|
||||
{
|
||||
X x;
|
||||
Y y;
|
||||
};
|
||||
|
||||
return AsType{x}.y;
|
||||
#endif
|
||||
return __builtin_bit_cast(Y, x);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user