Use __builtin_memcpy to implement bit_cast and to access a vector from a pointer to scalars (#53)

* Rework vector_type
* Use __builtin_memcpy for bit_cast and for vector access through a pointer to scalars
* Clean up
[ROCm/composable_kernel commit: 64350affc5]
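In outline, the commit replaces the union-based `as_type` helper with a `bit_cast` built on `__builtin_memcpy`. A standalone sketch of the two strategies (illustrative names; `std::memcpy` stands in for the builtin):

#include <cstdint>
#include <cstring>

// Union-based punning (the old as_type approach). Reading the inactive
// member is undefined behavior in ISO C++, though compilers tolerate it.
template <typename Y, typename X>
Y pun_via_union(X x)
{
    static_assert(sizeof(X) == sizeof(Y), "sizes must match");
    union U { X x; Y y; };
    return U{x}.y;
}

// memcpy-based bit cast (the new approach). Well-defined, and the
// fixed-size copy typically folds into a plain register move.
template <typename Y, typename X>
Y bit_cast_sketch(const X& x)
{
    static_assert(sizeof(X) == sizeof(Y), "sizes must match");
    Y y;
    std::memcpy(&y, &x, sizeof(X));
    return y;
}

int main()
{
    // 1.0f has the IEEE-754 single-precision bit pattern 0x3F800000.
    return bit_cast_sketch<std::uint32_t>(1.0f) == 0x3F800000u ? 0 : 1;
}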
@@ -268,14 +268,14 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-        return as_type<double>(tmp);
+        return bit_cast<double>(tmp);
     }
     else if constexpr(N == 2)
     {
         const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-        return as_type<double2_t>(tmp);
+        return bit_cast<double2_t>(tmp);
     }
     else if constexpr(N == 4)
     {
@@ -289,8 +289,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
             0);
         vector_type<double, 4> tmp;

-        tmp.AsType<double2_t>()(Number<0>{}) = as_type<double2_t>(f32_0);
-        tmp.AsType<double2_t>()(Number<1>{}) = as_type<double2_t>(f32_1);
+        tmp.AsType<double2_t>()(Number<0>{}) = bit_cast<double2_t>(f32_0);
+        tmp.AsType<double2_t>()(Number<1>{}) = bit_cast<double2_t>(f32_1);

         return tmp.AsType<double4_t>()(Number<0>{});
     }
@@ -351,7 +351,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
             float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
                 src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-            return as_type<half8_t>(tmp);
+            return bit_cast<half8_t>(tmp);
         }
     }
     else if constexpr(is_same<T, ushort>::value)
@@ -376,7 +376,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
             int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
                 src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-            return as_type<ushort8_t>(tmp);
+            return bit_cast<ushort8_t>(tmp);
         }
     }
     else if constexpr(is_same<T, int32_t>::value)
@@ -427,7 +427,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
             int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(
                 src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-            return as_type<int8x2_t>(tmp);
+            return bit_cast<int8x2_t>(tmp);
 #endif
         }
         else if constexpr(N == 4)
@@ -439,7 +439,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
             int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(
                 src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-            return as_type<int8x4_t>(tmp);
+            return bit_cast<int8x4_t>(tmp);
 #endif
         }
         else if constexpr(N == 8)
@@ -461,7 +461,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
             int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(
                 src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-            return as_type<int8x8_t>(tmp);
+            return bit_cast<int8x8_t>(tmp);
 #endif
         }
         else if constexpr(N == 16)
@@ -495,7 +495,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
             int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
                 src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

-            return as_type<int8x16_t>(tmp);
+            return bit_cast<int8x16_t>(tmp);
 #endif
         }
     }
@@ -521,7 +521,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
     // use fp32 store to mimic fp64 store
     if constexpr(N == 1)
     {
-        llvm_amdgcn_raw_buffer_store_fp32x2(as_type<float2_t>(src_thread_data),
+        llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast<float2_t>(src_thread_data),
                                             dst_wave_buffer_resource,
                                             dst_thread_addr_offset,
                                             dst_wave_addr_offset,
@@ -529,7 +529,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
     }
     else if constexpr(N == 2)
     {
-        llvm_amdgcn_raw_buffer_store_fp32x4(as_type<float4_t>(src_thread_data),
+        llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
                                             dst_wave_buffer_resource,
                                             dst_thread_addr_offset,
                                             dst_wave_addr_offset,
@@ -606,7 +606,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                             dst_wave_addr_offset + 4 * sizeof(half_t),
                                             0);
 #else
-        llvm_amdgcn_raw_buffer_store_fp32x4(as_type<float4_t>(src_thread_data),
+        llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
                                             dst_wave_buffer_resource,
                                             dst_thread_addr_offset,
                                             dst_wave_addr_offset,
@@ -703,7 +703,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                          dst_wave_addr_offset,
                                          0);
 #else
-        llvm_amdgcn_raw_buffer_store_i16(as_type<int16_t>(src_thread_data),
+        llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
                                          dst_wave_buffer_resource,
                                          dst_thread_addr_offset,
                                          dst_wave_addr_offset,
@@ -719,7 +719,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                          dst_wave_addr_offset,
                                          0);
 #else
-        llvm_amdgcn_raw_buffer_store_i32(as_type<int32_t>(src_thread_data),
+        llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
                                          dst_wave_buffer_resource,
                                          dst_thread_addr_offset,
                                          dst_wave_addr_offset,
@@ -728,7 +728,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
     }
     else if constexpr(N == 8)
     {
-        llvm_amdgcn_raw_buffer_store_i32x2(as_type<int32x2_t>(src_thread_data),
+        llvm_amdgcn_raw_buffer_store_i32x2(bit_cast<int32x2_t>(src_thread_data),
                                            dst_wave_buffer_resource,
                                            dst_thread_addr_offset,
                                            dst_wave_addr_offset,
@@ -736,7 +736,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
     }
     else if constexpr(N == 16)
     {
-        llvm_amdgcn_raw_buffer_store_i32x4(as_type<int32x4_t>(src_thread_data),
+        llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
                                            dst_wave_buffer_resource,
                                            dst_thread_addr_offset,
                                            dst_wave_addr_offset,
@@ -211,14 +211,14 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
                  v_dot4_i32_i8 %1, %2, %4, %1\n \
                  "
                  : "=v"(c0), "=v"(c1)
-                 : "v"(as_type<int32_t>(a)),
-                   "v"(as_type<int32_t>(b0)),
-                   "v"(as_type<int32_t>(b1)),
+                 : "v"(bit_cast<int32_t>(a)),
+                   "v"(bit_cast<int32_t>(b0)),
+                   "v"(bit_cast<int32_t>(b1)),
                    "0"(c0),
                    "1"(c1));
 #else
-    c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
-    c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
+    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
+    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
 #endif
 }

@@ -244,20 +244,20 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
                  v_dot4_i32_i8 %3, %4, %8, %3\n \
                  "
                  : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
-                 : "v"(as_type<int32_t>(a)),
-                   "v"(as_type<int32_t>(b0)),
-                   "v"(as_type<int32_t>(b1)),
-                   "v"(as_type<int32_t>(b2)),
-                   "v"(as_type<int32_t>(b3)),
+                 : "v"(bit_cast<int32_t>(a)),
+                   "v"(bit_cast<int32_t>(b0)),
+                   "v"(bit_cast<int32_t>(b1)),
+                   "v"(bit_cast<int32_t>(b2)),
+                   "v"(bit_cast<int32_t>(b3)),
                    "0"(c0),
                    "1"(c1),
                    "2"(c2),
                    "3"(c3));
 #else
-    c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
-    c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
-    c2 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b2), c2, false);
-    c3 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b3), c3, false);
+    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
+    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
+    c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
+    c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
 #endif
 }
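The hunks above reinterpret int8x4_t operands as int32_t because v_dot4_i32_i8 (and the __builtin_amdgcn_sdot4 builtin on the non-asm path) takes two 32-bit registers that each pack four signed 8-bit lanes. A host-side scalar reference for what the instruction computes with clamping disabled, as a sketch:

#include <cstdint>
#include <cstring>

// c += a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3], where the four
// signed 8-bit lanes are packed into each 32-bit operand.
std::int32_t sdot4_reference(std::int32_t a_packed, std::int32_t b_packed, std::int32_t c)
{
    std::int8_t a[4], b[4];
    std::memcpy(a, &a_packed, sizeof(a)); // unpack the lanes
    std::memcpy(b, &b_packed, sizeof(b));
    for(int i = 0; i < 4; ++i)
        c += std::int32_t(a[i]) * std::int32_t(b[i]);
    return c;
}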
@@ -340,8 +340,8 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
     __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
     {
         reg_c.template AsType<int32x16_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_i32_32x32x8i8(as_type<int>(reg_a),
-                                                  as_type<int>(reg_b),
+            llvm_intrin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int>(reg_a),
+                                                  bit_cast<int>(reg_b),
                                                   reg_c.template AsType<int32x16_t>()[Number<0>{}],
                                                   0,
                                                   0,
@@ -359,8 +359,8 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
     __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
         reg_c.template AsType<int32x4_t>()(Number<0>{}) =
-            llvm_intrin_amdgcn_mfma_i32_16x16x16i8(as_type<int>(reg_a),
-                                                   as_type<int>(reg_b),
+            llvm_intrin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int>(reg_a),
+                                                   bit_cast<int>(reg_b),
                                                    reg_c.template AsType<int32x4_t>()[Number<0>{}],
                                                    0,
                                                    0,
@@ -99,7 +99,19 @@
 #define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0

 // merge transformation uses magic number division
 #ifndef CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
 #define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1
 #endif

+// use __builtin_memcpy instead of a pointer cast to access a vector from a pointer to scalars
+#ifndef CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
+#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
+#endif
+
+// use __builtin_memcpy instead of a union to do bit_cast
+#ifndef CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
+#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1
+#endif
+
 // hack: has underlying assumptions that need to be satisfied, otherwise it's a bug
 // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
@@ -119,7 +131,7 @@
 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
 #endif

-// workaround for compiler crash when using buffer load/store for i8
+// workaround for compiler generating inefficient ds_write instructions
 #ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
 #endif
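The first new macro gates how a vector value is read or written through a pointer to its scalar elements; the DynamicBuffer hunks further down use it. A minimal sketch of the two paths (clang's ext_vector_type stands in for the library's vector typedefs, and the function name is illustrative):

#include <cstring>

using float4_vec = float __attribute__((ext_vector_type(4)));

// Read a 4-wide vector starting at scalar element i of p.
float4_vec load_vec4(const float* p, int i)
{
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
    // memcpy path: makes no aliasing claims; the compiler still emits a
    // single vector load when size and alignment allow.
    float4_vec tmp;
    std::memcpy(&tmp, p + i, sizeof(tmp));
    return tmp;
#else
    // pointer-cast path: relies on the cast pointer being suitably
    // aligned and on type-based aliasing rules not biting.
    return *reinterpret_cast<const float4_vec*>(p + i);
#endif
}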
@@ -1081,11 +1081,11 @@ struct NumericLimits<half_t>
     static constexpr unsigned short binary_max = 0x7BFF;
     static constexpr unsigned short binary_lowest = 0xFBFF;

-    __host__ __device__ static constexpr half_t Min() { return as_type<half_t>(binary_min); }
+    __host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }

-    __host__ __device__ static constexpr half_t Max() { return as_type<half_t>(binary_max); }
+    __host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }

-    __host__ __device__ static constexpr half_t Lowest() { return as_type<half_t>(binary_lowest); }
+    __host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
 };

 } // namespace ck
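The constants above are raw IEEE binary16 encodings: 0x7BFF (sign 0, exponent 11110, all-ones mantissa) is the largest finite half, 65504, and 0xFBFF is its negation, which bit_cast<half_t> materializes directly. A quick host-side decode that checks those values, as a sketch:

#include <cassert>
#include <cmath>
#include <cstdint>

// Decode an IEEE-754 binary16 bit pattern (normal numbers only).
double decode_half(std::uint16_t bits)
{
    const int sign = (bits >> 15) & 0x1;
    const int exp  = (bits >> 10) & 0x1F; // biased exponent, bias 15
    const int mant = bits & 0x3FF;        // 10-bit fraction
    const double v = (1.0 + mant / 1024.0) * std::ldexp(1.0, exp - 15);
    return sign ? -v : v;
}

int main()
{
    assert(decode_half(0x7BFF) == 65504.0);  // binary_max
    assert(decode_half(0xFBFF) == -65504.0); // binary_lowest
}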
@@ -83,12 +83,28 @@ struct DynamicBuffer
     {
         if constexpr(InvalidElementUseNumericalZeroValue)
         {
+#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
+            X tmp;
+
+            __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
+
+            return is_valid_element ? tmp : X{0};
+#else
             return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
+#endif
         }
         else
         {
+#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
+            X tmp;
+
+            __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
+
+            return is_valid_element ? tmp : X{invalid_element_value_};
+#else
             return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
                                     : X{invalid_element_value_};
+#endif
         }
     }
 }
@@ -117,7 +133,13 @@ struct DynamicBuffer
 #else
         if(is_valid_element)
         {
+#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
+            X tmp = x;
+
+            __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
+#else
             *c_style_pointer_cast<X*>(&p_data_[i]) = x;
+#endif
         }
 #endif
     }
@@ -126,7 +148,13 @@ struct DynamicBuffer
         if(is_valid_element)
         {
 #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
+#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
+            X tmp = x;
+
+            __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
+#else
             *c_style_pointer_cast<X*>(&p_data_[i]) = x;
+#endif
 #else
             // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
             // inefficient
@@ -201,7 +229,13 @@ struct DynamicBuffer
         }
         else
         {
+#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
+            X tmp = x;
+
+            __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
+#else
             *c_style_pointer_cast<X*>(&p_data_[i]) = x;
+#endif
         }
 #endif
     }
@@ -210,7 +244,13 @@ struct DynamicBuffer
     {
         if(is_valid_element)
         {
+#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
+            X tmp = x;
+
+            __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
+#else
             *c_style_pointer_cast<X*>(&p_data_[i]) = x;
+#endif
         }
     }
 }
@@ -144,9 +144,9 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
                  v_dot4_i32_i8 %0, %1, %2, %0\n \
                  "
                  : "=v"(c)
-                 : "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
+                 : "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
 #else
-    c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
+    c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
 #endif
 #else
     const vector_type<int8_t, 4> a_vector{a};
@@ -125,7 +125,7 @@ struct MagicDivision
     __host__ __device__ static constexpr int32_t
     DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
     {
-        uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
+        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
         uint32_t tmp = __umulhi(dividend_u32, multiplier);
         return (tmp + dividend_u32) >> shift;
     }
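DoMagicDivision evaluates floor(x * (m + 2^32) / 2^(32+s)): __umulhi supplies the high half of x*m, adding x accounts for the implicit 2^32 term, and the shift finishes the division. With m and s chosen so that m + 2^32 = ceil(2^(32+s)/d), this reproduces x/d exactly for in-range dividends. The constants below for d = 3 follow that standard recipe and are an assumption for illustration, not values read out of the library:

#include <cassert>
#include <cstdint>

// Host stand-in for the device __umulhi: high 32 bits of a 32x32 multiply.
std::uint32_t umulhi(std::uint32_t a, std::uint32_t b)
{
    return static_cast<std::uint32_t>((std::uint64_t(a) * b) >> 32);
}

// The division step from the hunk above.
std::uint32_t magic_div(std::uint32_t x, std::uint32_t multiplier, std::uint32_t shift)
{
    return (umulhi(x, multiplier) + x) >> shift;
}

int main()
{
    // Divisor 3: s = 2, m = ceil(2^34 / 3) - 2^32 = 0x55555556.
    const std::uint32_t m = 0x55555556u, s = 2;
    for(std::uint32_t x = 0; x < 1000000; ++x)
        assert(magic_div(x, m, s) == x / 3);
}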
@@ -54,5 +54,49 @@ __host__ __device__ constexpr auto make_statically_indexed_array()
     return StaticallyIndexedArray<X, 0>();
 }

+template <typename T, index_t N>
+struct StaticallyIndexedArray_v2
+{
+    __host__ __device__ constexpr StaticallyIndexedArray_v2() = default;
+
+    __host__ __device__ static constexpr index_t Size() { return N; }
+
+    // read access
+    template <index_t I>
+    __host__ __device__ constexpr const auto& At(Number<I>) const
+    {
+        static_assert(I < N, "wrong! out of range");
+
+        return data_[I];
+    }
+
+    // write access
+    template <index_t I>
+    __host__ __device__ constexpr auto& At(Number<I>)
+    {
+        static_assert(I < N, "wrong! out of range");
+
+        return data_[I];
+    }
+
+    // read access
+    template <index_t I>
+    __host__ __device__ constexpr const auto& operator[](Number<I> i) const
+    {
+        return At(i);
+    }
+
+    // write access
+    template <index_t I>
+    __host__ __device__ constexpr auto& operator()(Number<I> i)
+    {
+        return At(i);
+    }
+
+    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
+
+    T data_[N];
+};
+
 } // namespace ck
 #endif
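StaticallyIndexedArray_v2 keeps a plain T[N] but only hands out elements through compile-time Number<I> indices, with operator[] reserved for reads, operator() for writes, and static_assert enforcing the bound. A self-contained host-side miniature of the same pattern (Number and index_t are stand-ins for the library's types):

using index_t = int;

template <index_t I>
struct Number
{
    static constexpr index_t value = I;
};

template <typename T, index_t N>
struct StaticArraySketch // miniature of StaticallyIndexedArray_v2
{
    template <index_t I>
    constexpr const T& At(Number<I>) const
    {
        static_assert(I < N, "wrong! out of range");
        return data_[I];
    }

    template <index_t I>
    constexpr T& At(Number<I>)
    {
        static_assert(I < N, "wrong! out of range");
        return data_[I];
    }

    template <index_t I>
    constexpr const T& operator[](Number<I> i) const { return At(i); }

    template <index_t I>
    constexpr T& operator()(Number<I> i) { return At(i); }

    T data_[N];
};

int main()
{
    StaticArraySketch<float, 4> a{};
    a(Number<0>{}) = 2.5f;          // write access goes through operator()
    const float x = a[Number<0>{}]; // read access goes through operator[]
    return x == 2.5f ? 0 : 1;
}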
@@ -32,8 +32,15 @@ template <typename T>
 inline constexpr bool is_pointer_v = std::is_pointer<T>::value;

 template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
-__host__ __device__ constexpr Y as_type(X x)
+__host__ __device__ constexpr Y bit_cast(const X& x)
 {
+#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
+    Y y;
+
+    __builtin_memcpy(&y, &x, sizeof(X));
+
+    return y;
+#else
     union AsType
     {
         X x;
@@ -41,6 +48,7 @@ __host__ __device__ constexpr Y as_type(X x)
     };

     return AsType{x}.y;
+#endif
 }

 } // namespace ck
@@ -9,7 +9,6 @@
 #include "device.hpp"
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
-#include "gemm_common.hpp"
 #include "host_gemm.hpp"
 #include "device_tensor.hpp"
 #include "device_base.hpp"
@@ -139,12 +138,12 @@ int main(int argc, char* argv[])
     {
     case 0: break;
     case 1:
-        a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5});
-        b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5});
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
         break;
     default:
-        a_m_k.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0});
-        b_k_n.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5});
+        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
     }

     DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());

@@ -258,7 +258,7 @@ int main(int argc, char* argv[])
 using in_data_t  = half_t;
 using acc_data_t = float;
 using out_data_t = half_t;
-#elif 1
+#elif 0
 using in_data_t  = ushort;
 using acc_data_t = float;
 using out_data_t = ushort;