mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
Overhaul vector_type and use real vector for int8x4_t instead of aliasing from int32_t (#29)
* overhaul vector_type, make int8x4_t real vector instead of aliasing from int32_t
This commit is contained in:
@@ -104,7 +104,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
using vector_t = typename vector_type<Data, DataPerAccess>::type;
|
||||
using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
|
||||
|
||||
static_for<0, NSliceRow, 1>{}([&](auto i) {
|
||||
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
|
||||
|
||||
@@ -172,16 +172,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
||||
}();
|
||||
|
||||
// copy data
|
||||
vector_type<DstData, DstScalarPerVector> dst_vector;
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
|
||||
|
||||
using dst_vector_t = typename vector_type<DstData, DstScalarPerVector>::type;
|
||||
using dst_vector_t =
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
|
||||
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t src_offset =
|
||||
src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
|
||||
i * dst_scalar_step_in_vector);
|
||||
|
||||
dst_vector.Scalars()(i) = type_convert<DstData>{}(p_src[Number<src_offset>{}]);
|
||||
dst_vector.template AsType<DstData>()(i) =
|
||||
type_convert<DstData>{}(p_src[Number<src_offset>{}]);
|
||||
});
|
||||
|
||||
const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
|
||||
@@ -192,7 +194,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
||||
{
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
amd_buffer_store_v2<DstData, DstScalarPerVector>(
|
||||
dst_vector.Vector(),
|
||||
dst_vector.template AsType<dst_vector_t>()(Number<0>{}),
|
||||
p_dst,
|
||||
dst_slice_origin_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
@@ -201,7 +203,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
||||
if(is_dst_valid)
|
||||
{
|
||||
*reinterpret_cast<dst_vector_t*>(
|
||||
&(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector();
|
||||
&(p_dst[dst_slice_origin_coord_.GetOffset()])) =
|
||||
dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -210,7 +213,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
||||
if(is_dst_valid)
|
||||
{
|
||||
*reinterpret_cast<dst_vector_t*>(
|
||||
&(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector();
|
||||
&(p_dst[dst_slice_origin_coord_.GetOffset()])) =
|
||||
dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -500,9 +504,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
// copy data
|
||||
static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for vgpr dst");
|
||||
|
||||
vector_type<SrcData, SrcScalarPerVector> src_vector;
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
|
||||
|
||||
using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::type;
|
||||
using src_vector_t =
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
|
||||
|
||||
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
|
||||
src_desc, src_slice_origin_coord_);
|
||||
@@ -510,24 +515,25 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
if constexpr(SrcAddressSpace == AddressSpace::Global)
|
||||
{
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
src_vector.Vector() = amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
|
||||
p_src,
|
||||
src_slice_origin_coord_.GetOffset(),
|
||||
is_src_valid,
|
||||
src_desc.GetElementSpaceSize());
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
|
||||
p_src,
|
||||
src_slice_origin_coord_.GetOffset(),
|
||||
is_src_valid,
|
||||
src_desc.GetElementSpaceSize());
|
||||
#else
|
||||
src_vector.Vector() = is_src_valid
|
||||
? *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_slice_origin_coord_.GetOffset()])
|
||||
: src_vector_t{0};
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_slice_origin_coord_.GetOffset()])
|
||||
: src_vector_t{0};
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
src_vector.Vector() = is_src_valid
|
||||
? *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_slice_origin_coord_.GetOffset()])
|
||||
: src_vector_t{0};
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_slice_origin_coord_.GetOffset()])
|
||||
: src_vector_t{0};
|
||||
}
|
||||
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
@@ -535,7 +541,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
|
||||
i * src_scalar_step_in_vector);
|
||||
|
||||
p_dst[Number<dst_offset>{}] = src_vector.Scalars()[i];
|
||||
p_dst[Number<dst_offset>{}] = src_vector.template AsType<SrcData>()[i];
|
||||
});
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
@@ -833,9 +839,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
}();
|
||||
|
||||
// copy data
|
||||
vector_type<SrcData, SrcScalarPerVector> src_vector;
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
|
||||
|
||||
using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::type;
|
||||
using src_vector_t =
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
|
||||
|
||||
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
|
||||
src_desc, src_slice_origin_coord_);
|
||||
@@ -843,31 +850,32 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
if constexpr(SrcAddressSpace == AddressSpace::Global)
|
||||
{
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
src_vector.Vector() = amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
|
||||
p_src,
|
||||
src_slice_origin_coord_.GetOffset(),
|
||||
is_src_valid,
|
||||
src_desc.GetElementSpaceSize());
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
|
||||
p_src,
|
||||
src_slice_origin_coord_.GetOffset(),
|
||||
is_src_valid,
|
||||
src_desc.GetElementSpaceSize());
|
||||
#else
|
||||
src_vector.Vector() = is_src_valid
|
||||
? *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_slice_origin_coord_.GetOffset()])
|
||||
: src_vector_t{0};
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_slice_origin_coord_.GetOffset()])
|
||||
: src_vector_t{0};
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
src_vector.Vector() = is_src_valid
|
||||
? *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_slice_origin_coord_.GetOffset()])
|
||||
: src_vector_t{0};
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_slice_origin_coord_.GetOffset()])
|
||||
: src_vector_t{0};
|
||||
}
|
||||
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t buffer_offset =
|
||||
buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
buffer_(Number<buffer_offset>{}) = src_vector.Scalars()[i];
|
||||
buffer_(Number<buffer_offset>{}) = src_vector.template AsType<SrcData>()[i];
|
||||
});
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
@@ -1018,19 +1026,20 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
DstInMemOp == InMemoryDataOperation::Set,
|
||||
"wrong! hardcoded for ds_write");
|
||||
|
||||
vector_type<DstData, DstScalarPerVector> dst_vector;
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
|
||||
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t buffer_offset =
|
||||
buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
|
||||
|
||||
dst_vector.Scalars()(i) = buffer_[Number<buffer_offset>{}];
|
||||
dst_vector.template AsType<DstData>()(i) = buffer_[Number<buffer_offset>{}];
|
||||
});
|
||||
|
||||
using DstVectorType = typename vector_type<DstData, DstScalarPerVector>::type;
|
||||
using DstVectorType =
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
|
||||
|
||||
*reinterpret_cast<DstVectorType*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
|
||||
dst_vector.Vector();
|
||||
dst_vector.template AsType<DstVectorType>()[Number<0>{}];
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
|
||||
@@ -41,7 +41,7 @@ struct ThreadwiseMatrixSliceCopy_v2
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
using vector_t = typename vector_type<Data, DataPerAccess>::type;
|
||||
using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
|
||||
|
||||
static_for<0, NSliceRow, 1>{}([&](auto i) {
|
||||
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
|
||||
|
||||
@@ -6,6 +6,17 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// BufferResource<T>: a union, so every member below is a different view of
// the same storage; writing through one member changes the bytes read through
// the others. `data` exposes that storage as a single int32x4_t value.
template <typename T>
|
||||
union BufferResource
|
||||
{
|
||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
||||
// the whole resource as one int32x4_t value
int32x4_t data;
|
||||
// the same storage viewed as two T* slots
T* address[2];
|
||||
// the same storage viewed as four int32_t slots
int32_t range[4];
|
||||
// identical type and extent to `range`
// NOTE(review): the separate name presumably documents which dword a caller
// intends to set - confirm against the users of this union
int32_t config[4];
|
||||
};
|
||||
|
||||
__device__ float __llvm_amdgcn_buffer_load_f32(int32x4_t srsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
|
||||
@@ -6,27 +6,27 @@
|
||||
namespace ck {
|
||||
|
||||
// NOTE(review): this span appears to be a rendered diff hunk, so both versions
// of the changed lines are present: the name `BufferResource` with C-array
// members, then the name `BufferResource_v2` with StaticallyIndexedArray
// members. Confirm against the repository which lines are current.
//
// Either way the type is a union: every member is a different view of the
// same storage, and `data` exposes that storage as a single int32x4_t value.
template <typename T>
|
||||
union BufferResource
|
||||
union BufferResource_v2
|
||||
{
|
||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
||||
// the whole resource as one int32x4_t value
int32x4_t data;
|
||||
// C-array views: two T* slots, four int32_t slots, four int32_t slots
T* address[2];
|
||||
int32_t range[4];
|
||||
int32_t config[4];
|
||||
// StaticallyIndexedArray views with the same element types and counts as the
// C arrays above
StaticallyIndexedArray<T*, 2> address;
|
||||
StaticallyIndexedArray<int32_t, 4> range;
|
||||
StaticallyIndexedArray<int32_t, 4> config;
|
||||
};
|
||||
|
||||
// Fills a buffer-resource union and returns its int32x4_t view (`data`).
//   p_wave          - pointer stored in address slot 0, after const/volatile
//                     qualifiers are removed from T via remove_cv_t
//   data_space_size - multiplied by sizeof(T); the product is stored in
//                     range slot 2
// Config slot 3 receives CK_BUFFER_RESOURCE_3RD_DWORD.
//
// NOTE(review): this span appears to be a rendered diff hunk. Each changed
// statement is present twice: once indexed with [] on BufferResource<T>, and
// once indexed with (Number<i>{}) on BufferResource_v2<T>.
// NOTE(review): no other slot is written here, so all four dwords of the
// result are defined only if the pointer in address slot 0 spans the first
// two dwords (the existing comment says 64 bit) - confirm for the target.
template <typename T>
|
||||
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size)
|
||||
{
|
||||
BufferResource<T> wave_buffer_resource;
|
||||
BufferResource_v2<T> wave_buffer_resource;
|
||||
|
||||
// wavewise base address (64 bit)
|
||||
wave_buffer_resource.address[0] = const_cast<remove_cv_t<T>*>(p_wave);
|
||||
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
|
||||
// wavewise range (32 bit)
|
||||
wave_buffer_resource.range[2] = data_space_size * sizeof(T);
|
||||
wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T);
|
||||
// wavewise setting (32 bit)
|
||||
wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
|
||||
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
|
||||
|
||||
// all members alias the same storage, so this returns the slots set above
return wave_buffer_resource.data;
|
||||
}
|
||||
@@ -37,6 +37,19 @@ __llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
|
||||
|
||||
// Declaration only (no body): a __device__ function that takes an int32x4_t
// resource plus three index_t arguments and returns int8x2_t. The trailing
// __asm("...") label binds the declaration to the name
// "llvm.amdgcn.raw.buffer.load.v2i8".
// NOTE(review): the argument meanings are not defined here. The visible call
// site passes the wave buffer resource, the thread address offset, the wave
// address offset and a literal 0, in that order.
__device__ int8x2_t
|
||||
__llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
|
||||
|
||||
// Declaration only (no body): a __device__ function that takes an int32x4_t
// resource plus three index_t arguments and returns int8x4_t. The trailing
// __asm("...") label binds the declaration to the name
// "llvm.amdgcn.raw.buffer.load.v4i8".
// NOTE(review): the argument meanings are not defined here. The visible call
// sites pass the wave buffer resource, the thread address offset, the wave
// address offset (plus a byte offset for later chunks) and a literal 0.
__device__ int8x4_t
|
||||
__llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
|
||||
|
||||
__device__ int16_t
|
||||
__llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
@@ -105,6 +118,20 @@ __llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
|
||||
|
||||
// Declaration only (no body): a __device__ function returning void that takes
// the int8x2_t value first, then an int32x4_t resource and three index_t
// arguments. The trailing __asm("...") label binds the declaration to the
// name "llvm.amdgcn.raw.buffer.store.v2i8".
// NOTE(review): the argument meanings are not defined here. The visible call
// site passes the thread data, the wave buffer resource, the thread address
// offset, the wave address offset and a literal 0, in that order.
__device__ void
|
||||
__llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
|
||||
|
||||
// Declaration only (no body): a __device__ function returning void that takes
// the int8x4_t value first, then an int32x4_t resource and three index_t
// arguments. The trailing __asm("...") label binds the declaration to the
// name "llvm.amdgcn.raw.buffer.store.v4i8".
// NOTE(review): the argument meanings are not defined here. The visible call
// site passes the thread data, the wave buffer resource, the thread address
// offset, the wave address offset and a literal 0, in that order.
__device__ void
|
||||
__llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
|
||||
|
||||
__device__ void
|
||||
__llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
|
||||
int32x4_t rsrc,
|
||||
@@ -182,15 +209,12 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset)
|
||||
{
|
||||
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, half2_t>::value && (N == 1)) ||
|
||||
(is_same<T, half4_t>::value && (N == 1)) ||
|
||||
(is_same<T, half8_t>::value && (N == 1)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int32x2_t>::value && (N == 1)) ||
|
||||
(is_same<T, int32x4_t>::value && (N == 1)),
|
||||
"wrong! not implemented");
|
||||
static_assert(
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(is_same<T, float>::value)
|
||||
{
|
||||
@@ -213,16 +237,16 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
|
||||
{
|
||||
vector_type<float, 8> tmp;
|
||||
|
||||
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
tmp.AsType<float4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
tmp.Vectors(Number<4>{})(Number<1>{}) =
|
||||
tmp.AsType<float4_t>()(Number<1>{}) =
|
||||
__llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
0);
|
||||
|
||||
return tmp.Vector();
|
||||
return tmp.AsType<float8_t>()(Number<0>{});
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half_t>::value)
|
||||
@@ -242,39 +266,20 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
|
||||
return __llvm_amdgcn_raw_buffer_load_fp16x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half2_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return __llvm_amdgcn_raw_buffer_load_fp16x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half4_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return __llvm_amdgcn_raw_buffer_load_fp16x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half8_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
vector_type<half_t, 8> tmp;
|
||||
|
||||
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
|
||||
tmp.AsType<half4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
tmp.Vectors(Number<4>{})(Number<1>{}) =
|
||||
tmp.AsType<half4_t>()(Number<1>{}) =
|
||||
__llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(half_t),
|
||||
0);
|
||||
|
||||
return tmp.Vector();
|
||||
return tmp.AsType<half8_t>()(Number<0>{});
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int32_t>::value)
|
||||
@@ -298,32 +303,103 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
|
||||
{
|
||||
vector_type<int32_t, 8> tmp;
|
||||
|
||||
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
tmp.AsType<int32x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
tmp.Vectors(Number<4>{})(Number<1>{}) =
|
||||
tmp.AsType<int32x4_t>()(Number<1>{}) =
|
||||
__llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int32_t),
|
||||
0);
|
||||
return tmp.AsType<int32x8_t>()(Number<0>{});
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int8_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return __llvm_amdgcn_raw_buffer_load_i8(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX
|
||||
return __llvm_amdgcn_raw_buffer_load_i8x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
#else
|
||||
int16_t tmp = __llvm_amdgcn_raw_buffer_load_i16(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
return tmp.Vector();
|
||||
return as_type<int8x2_t>(tmp);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int32x2_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return __llvm_amdgcn_raw_buffer_load_i32x2(
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX
|
||||
return __llvm_amdgcn_raw_buffer_load_i8x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
#else
|
||||
int32_t tmp = __llvm_amdgcn_raw_buffer_load_i32(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
return as_type<int8x4_t>(tmp);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int32x4_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
return __llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX
|
||||
vector_type<int8_t, 8> tmp;
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<1>{}) =
|
||||
__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int8_t),
|
||||
0);
|
||||
|
||||
return tmp.AsType<int8x8_t>()(Number<0>{});
|
||||
#else
|
||||
int32x2_t tmp = __llvm_amdgcn_raw_buffer_load_i32x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
return as_type<int8x8_t>(tmp);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX
|
||||
vector_type<int8_t, 16> tmp;
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<1>{}) =
|
||||
__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int8_t),
|
||||
0);
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<2>{}) =
|
||||
__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 8 * sizeof(int8_t),
|
||||
0);
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<3>{}) =
|
||||
__llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 12 * sizeof(int8_t),
|
||||
0);
|
||||
|
||||
return tmp.AsType<int8x16_t>()(Number<0>{});
|
||||
#else
|
||||
int32x4_t tmp = __llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
return as_type<int8x16_t>(tmp);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -407,23 +483,39 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
__llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX
|
||||
__llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
#else
|
||||
__llvm_amdgcn_raw_buffer_store_i16(as_type<int16_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
__llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX
|
||||
__llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
#else
|
||||
__llvm_amdgcn_raw_buffer_store_i32(as_type<int32_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
__llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
|
||||
__llvm_amdgcn_raw_buffer_store_i32x2(as_type<int32x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
@@ -431,7 +523,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
__llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
|
||||
__llvm_amdgcn_raw_buffer_store_i32x4(as_type<int32x4_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
@@ -468,13 +560,13 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
|
||||
{
|
||||
vector_type<half_t, 8> tmp{src_thread_data};
|
||||
|
||||
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<0>{}],
|
||||
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<1>{}],
|
||||
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(half_t),
|
||||
@@ -488,26 +580,29 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
|
||||
// 2) p_src_wave to be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ typename vector_type<T, N>::type amd_buffer_load_v2(const T* p_src_wave,
|
||||
index_t src_thread_data_offset,
|
||||
bool src_thread_data_valid,
|
||||
index_t src_element_space)
|
||||
__device__ typename vector_type_maker<T, N>::type::type
|
||||
amd_buffer_load_v2(const T* p_src_wave,
|
||||
index_t src_thread_data_offset,
|
||||
bool src_thread_data_valid,
|
||||
index_t src_element_space)
|
||||
{
|
||||
const int32x4_t src_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_src_wave, src_element_space);
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
|
||||
|
||||
return amd_buffer_load_impl_v2<T, N>(
|
||||
return amd_buffer_load_impl_v2<scalar_t, vector_size>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
#else
|
||||
using vector_t = typename vector_type<T, N>::type;
|
||||
|
||||
vector_t tmp =
|
||||
amd_buffer_load_impl_v2<T, N>(src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
vector_t tmp = amd_buffer_load_impl_v2<scalar_t, vector_size>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
|
||||
return src_thread_data_valid ? tmp : vector_t(0);
|
||||
#endif
|
||||
@@ -518,26 +613,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_v2(const T* p_src_wa
|
||||
// 2) p_dst_wave to be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ void amd_buffer_store_v2(const typename vector_type<T, N>::type src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_data_offset,
|
||||
const bool dst_thread_data_valid,
|
||||
const index_t dst_element_space)
|
||||
__device__ void
|
||||
amd_buffer_store_v2(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_data_offset,
|
||||
const bool dst_thread_data_valid,
|
||||
const index_t dst_element_space)
|
||||
{
|
||||
const int32x4_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space);
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
|
||||
|
||||
amd_buffer_store_impl_v2<T, N>(
|
||||
amd_buffer_store_impl_v2<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#else
|
||||
if(dst_thread_data_valid)
|
||||
{
|
||||
amd_buffer_store_impl_v2<T, N>(
|
||||
amd_buffer_store_impl_v2<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -72,6 +72,7 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo
|
||||
__device__ void
|
||||
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
|
||||
{
|
||||
// TODO remove pointer casting
|
||||
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
|
||||
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
|
||||
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
|
||||
@@ -132,6 +133,7 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
|
||||
float& c2,
|
||||
float& c3)
|
||||
{
|
||||
// TODO remove pointer casting
|
||||
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
|
||||
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
|
||||
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
|
||||
@@ -177,6 +179,7 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a,
|
||||
float& c3)
|
||||
{
|
||||
|
||||
// TODO remove pointer casting
|
||||
const half4_t* p_a_half4 = reinterpret_cast<const half4_t*>(&a);
|
||||
const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
|
||||
const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);
|
||||
@@ -200,6 +203,7 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a,
|
||||
float& c2,
|
||||
float& c3)
|
||||
{
|
||||
// TODO remove pointer casting
|
||||
const half8_t* p_a_half8 = reinterpret_cast<const half8_t*>(&a);
|
||||
const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
|
||||
const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
|
||||
@@ -224,10 +228,14 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
|
||||
v_dot4_i32_i8 %1, %2, %4, %1\n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1)
|
||||
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
|
||||
: "v"(as_type<int32_t>(a)),
|
||||
"v"(as_type<int32_t>(b0)),
|
||||
"v"(as_type<int32_t>(b1)),
|
||||
"0"(c0),
|
||||
"1"(c1));
|
||||
#else
|
||||
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
|
||||
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
|
||||
c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
|
||||
c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -253,12 +261,20 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
|
||||
v_dot4_i32_i8 %3, %4, %8, %3\n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
|
||||
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
|
||||
: "v"(as_type<int32_t>(a)),
|
||||
"v"(as_type<int32_t>(b0)),
|
||||
"v"(as_type<int32_t>(b1)),
|
||||
"v"(as_type<int32_t>(b2)),
|
||||
"v"(as_type<int32_t>(b3)),
|
||||
"0"(c0),
|
||||
"1"(c1),
|
||||
"2"(c2),
|
||||
"3"(c3));
|
||||
#else
|
||||
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
|
||||
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
|
||||
c2 = __builtin_amdgcn_sdot4(a, b2, c2, false);
|
||||
c3 = __builtin_amdgcn_sdot4(a, b3, c3, false);
|
||||
c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
|
||||
c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
|
||||
c2 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b2), c2, false);
|
||||
c3 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b3), c3, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -272,28 +288,24 @@ __device__ void amd_assembly_outer_product_1x4(int8x8_t a,
|
||||
int32_t& c2,
|
||||
int32_t& c3)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
const int8x4_t* p_a_int8x4_t = reinterpret_cast<const int8x4_t*>(&a);
|
||||
const int8x4_t* p_b0_int8x4_t = reinterpret_cast<const int8x4_t*>(&b0);
|
||||
const int8x4_t* p_b1_int8x4_t = reinterpret_cast<const int8x4_t*>(&b1);
|
||||
const int8x4_t* p_b2_int8x4_t = reinterpret_cast<const int8x4_t*>(&b2);
|
||||
const int8x4_t* p_b3_int8x4_t = reinterpret_cast<const int8x4_t*>(&b3);
|
||||
|
||||
amd_assembly_outer_product_1x4(p_a_int8x4_t[0],
|
||||
p_b0_int8x4_t[0],
|
||||
p_b1_int8x4_t[0],
|
||||
p_b2_int8x4_t[0],
|
||||
p_b3_int8x4_t[0],
|
||||
amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
c3);
|
||||
|
||||
amd_assembly_outer_product_1x4(p_a_int8x4_t[1],
|
||||
p_b0_int8x4_t[1],
|
||||
p_b1_int8x4_t[1],
|
||||
p_b2_int8x4_t[1],
|
||||
p_b3_int8x4_t[1],
|
||||
amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
@@ -311,28 +323,46 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
|
||||
int32_t& c3)
|
||||
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const int8x8_t* p_a_int8x8_t = reinterpret_cast<const int8x8_t*>(&a);
|
||||
const int8x8_t* p_b0_int8x8_t = reinterpret_cast<const int8x8_t*>(&b0);
|
||||
const int8x8_t* p_b1_int8x8_t = reinterpret_cast<const int8x8_t*>(&b1);
|
||||
const int8x8_t* p_b2_int8x8_t = reinterpret_cast<const int8x8_t*>(&b2);
|
||||
const int8x8_t* p_b3_int8x8_t = reinterpret_cast<const int8x8_t*>(&b3);
|
||||
|
||||
amd_assembly_outer_product_1x4(p_a_int8x8_t[0],
|
||||
p_b0_int8x8_t[0],
|
||||
p_b1_int8x8_t[0],
|
||||
p_b2_int8x8_t[0],
|
||||
p_b3_int8x8_t[0],
|
||||
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
c3);
|
||||
|
||||
amd_assembly_outer_product_1x4(p_a_int8x8_t[1],
|
||||
p_b0_int8x8_t[1],
|
||||
p_b1_int8x8_t[1],
|
||||
p_b2_int8x8_t[1],
|
||||
p_b3_int8x8_t[1],
|
||||
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
c3);
|
||||
|
||||
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
|
||||
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
|
||||
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
|
||||
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
|
||||
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
c3);
|
||||
|
||||
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
|
||||
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
|
||||
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
|
||||
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
|
||||
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
|
||||
@@ -14,11 +14,11 @@
|
||||
#define CK_DEVICE_BACKEND_AMD 1
|
||||
|
||||
// GPU ID
|
||||
#if 0
|
||||
#if 1
|
||||
#define CK_AMD_GPU_GFX906 1
|
||||
#elif 0
|
||||
#define CK_AMD_GPU_GFX908 1
|
||||
#elif 1
|
||||
#elif 0
|
||||
#define CK_AMD_GPU_GFX1030 1
|
||||
#endif
|
||||
|
||||
@@ -88,7 +88,7 @@
|
||||
|
||||
// experimental implementation
|
||||
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
@@ -142,6 +142,11 @@
|
||||
#define CK_WORKAROUND_SWDEV_275126 1
|
||||
#endif
|
||||
|
||||
// workaround for compiler crash when using buffer load/store for i8
|
||||
#ifndef CK_WORKAROUND_SWDEV_XXXXXX
|
||||
#define CK_WORKAROUND_SWDEV_XXXXXX 1
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum AddressSpace
|
||||
|
||||
@@ -3,13 +3,106 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
using half_t = _Float16;
|
||||
|
||||
// vector_type
|
||||
template <typename T, index_t N>
|
||||
struct vector_type;
|
||||
|
||||
// Caution: DO NOT REMOVE
|
||||
// intentionally have only declaration but no definition to cause compilation failure when trying to
|
||||
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
|
||||
// vectors"
|
||||
template <typename T, index_t V, index_t N>
|
||||
struct vector_type<T __attribute__((ext_vector_type(V))), N>;
|
||||
|
||||
// Caution: DO NOT REMOVE
|
||||
// intentionally have only declaration but no definition to cause compilation failure when trying to
|
||||
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
|
||||
// vectors"
|
||||
template <typename T, index_t V, index_t N>
|
||||
struct vector_type<vector_type<T, V>, N>;
|
||||
|
||||
// vector_type_maker
|
||||
// This is the right way to handle "vector of vectors": making a bigger vector instead
|
||||
template <typename T, index_t N>
|
||||
struct vector_type_maker
|
||||
{
|
||||
using type = vector_type<T, N>;
|
||||
};
|
||||
|
||||
template <typename T, index_t N0, index_t N1>
|
||||
struct vector_type_maker<T __attribute__((ext_vector_type(N1))), N0>
|
||||
{
|
||||
using type = vector_type<T, N0 * N1>;
|
||||
};
|
||||
|
||||
template <typename T, index_t N0, index_t N1>
|
||||
struct vector_type_maker<vector_type<T, N1>, N0>
|
||||
{
|
||||
using type = vector_type<T, N0 * N1>;
|
||||
};
|
||||
|
||||
// scalar_type
|
||||
template <typename TV>
|
||||
struct scalar_type;
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<T __attribute__((ext_vector_type(N)))>
|
||||
{
|
||||
using type = T;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<vector_type<T, N>>
|
||||
{
|
||||
using type = T;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
//
|
||||
template <>
|
||||
struct scalar_type<float>
|
||||
{
|
||||
using type = float;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<half_t>
|
||||
{
|
||||
using type = half_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<ushort>
|
||||
{
|
||||
using type = ushort;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<int32_t>
|
||||
{
|
||||
using type = int32_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<int8_t>
|
||||
{
|
||||
using type = int8_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
//
|
||||
template <typename T>
|
||||
struct vector_type<T, 1>
|
||||
{
|
||||
using type = T;
|
||||
using d1_t = T;
|
||||
using type = d1_t;
|
||||
|
||||
union
|
||||
{
|
||||
@@ -21,19 +114,21 @@ struct vector_type<T, 1>
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return 1; }
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value, "wrong!");
|
||||
|
||||
__host__ __device__ constexpr const auto& Vector() const { return data_.d1_; }
|
||||
return data_.d1x1_;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto& Vector() { return data_.d1_; }
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value, "wrong!");
|
||||
|
||||
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x1_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Scalars() { return data_.d1x1_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x1_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x1_; }
|
||||
return data_.d1x1_;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -55,23 +150,35 @@ struct vector_type<T, 2>
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return 2; }
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!");
|
||||
|
||||
__host__ __device__ constexpr const auto& Vector() const { return data_.d2_; }
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x1_;
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto& Vector() { return data_.d2_; }
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!");
|
||||
|
||||
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x2_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Scalars() { return data_.d1x2_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x2_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x1_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x2_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x1_; }
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x1_;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -95,27 +202,45 @@ struct vector_type<T, 4>
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return 4; }
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
|
||||
"wrong!");
|
||||
|
||||
__host__ __device__ constexpr const auto& Vector() const { return data_.d4_; }
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x1_;
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto& Vector() { return data_.d4_; }
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
|
||||
"wrong!");
|
||||
|
||||
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x4_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Scalars() { return data_.d1x4_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x4_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x2_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x1_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x4_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x2_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; }
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x1_;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -141,31 +266,55 @@ struct vector_type<T, 8>
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return 8; }
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
|
||||
"wrong!");
|
||||
|
||||
__host__ __device__ constexpr const auto& Vector() const { return data_.d8_; }
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x1_;
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto& Vector() { return data_.d8_; }
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
|
||||
"wrong!");
|
||||
|
||||
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x8_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Scalars() { return data_.d1x8_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x8_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x4_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x2_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x1_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x8_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x4_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x2_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x1_;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -193,35 +342,65 @@ struct vector_type<T, 16>
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return 16; }
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value,
|
||||
"wrong!");
|
||||
|
||||
__host__ __device__ constexpr const auto& Vector() const { return data_.d16_; }
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x16_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d16_t>::value)
|
||||
{
|
||||
return data_.d16x1_;
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto& Vector() { return data_.d16_; }
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value,
|
||||
"wrong!");
|
||||
|
||||
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Scalars() { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x8_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x4_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x2_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<16>) const { return data_.d16x1_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x8_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x4_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x2_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; }
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x16_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d16_t>::value)
|
||||
{
|
||||
return data_.d16x1_;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// fp32
|
||||
@@ -230,7 +409,6 @@ using float4_t = typename vector_type<float, 4>::type;
|
||||
using float8_t = typename vector_type<float, 8>::type;
|
||||
|
||||
// fp16
|
||||
using half_t = _Float16;
|
||||
using half2_t = typename vector_type<half_t, 2>::type;
|
||||
using half4_t = typename vector_type<half_t, 4>::type;
|
||||
using half8_t = typename vector_type<half_t, 8>::type;
|
||||
@@ -246,197 +424,8 @@ using int32x2_t = typename vector_type<int32_t, 2>::type;
|
||||
using int32x4_t = typename vector_type<int32_t, 4>::type;
|
||||
using int32x8_t = typename vector_type<int32_t, 8>::type;
|
||||
|
||||
template <>
|
||||
struct vector_type<int8_t, 2>
|
||||
{
|
||||
using d1_t = int8_t;
|
||||
typedef int16_t d2_t;
|
||||
|
||||
using type = d2_t;
|
||||
|
||||
union
|
||||
{
|
||||
d2_t d2_;
|
||||
StaticallyIndexedArray<d1_t, 2> d1x2_;
|
||||
StaticallyIndexedArray<d2_t, 1> d2x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return 2; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vector() const { return data_.d2_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vector() { return data_.d2_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x2_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Scalars() { return data_.d1x2_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x2_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x1_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x2_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x1_; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<int8_t, 4>
|
||||
{
|
||||
using d1_t = int8_t;
|
||||
typedef int16_t d2_t;
|
||||
typedef int32_t d4_t;
|
||||
|
||||
using type = d4_t;
|
||||
|
||||
union
|
||||
{
|
||||
d4_t d4_;
|
||||
StaticallyIndexedArray<d1_t, 4> d1x4_;
|
||||
StaticallyIndexedArray<d2_t, 2> d2x2_;
|
||||
StaticallyIndexedArray<d4_t, 1> d4x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return 4; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vector() const { return data_.d4_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vector() { return data_.d4_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x4_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Scalars() { return data_.d1x4_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x4_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x2_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x1_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x4_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x2_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<int8_t, 8>
|
||||
{
|
||||
using d1_t = int8_t;
|
||||
typedef int16_t d2_t;
|
||||
typedef int32_t d4_t;
|
||||
typedef int32x2_t d8_t;
|
||||
|
||||
using type = d8_t;
|
||||
|
||||
union
|
||||
{
|
||||
d8_t d8_;
|
||||
StaticallyIndexedArray<d1_t, 8> d1x8_;
|
||||
StaticallyIndexedArray<d2_t, 4> d2x4_;
|
||||
StaticallyIndexedArray<d4_t, 2> d4x2_;
|
||||
StaticallyIndexedArray<d8_t, 1> d8x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return 8; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vector() const { return data_.d8_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vector() { return data_.d8_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x8_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Scalars() { return data_.d1x8_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x8_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x4_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x2_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x1_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x8_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x4_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x2_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<int8_t, 16>
|
||||
{
|
||||
using d1_t = int8_t;
|
||||
typedef int16_t d2_t;
|
||||
typedef int32_t d4_t;
|
||||
typedef int32x2_t d8_t;
|
||||
typedef int32x4_t d16_t;
|
||||
|
||||
using type = d16_t;
|
||||
|
||||
union
|
||||
{
|
||||
d16_t d16_;
|
||||
StaticallyIndexedArray<d1_t, 16> d1x16_;
|
||||
StaticallyIndexedArray<d2_t, 8> d2x8_;
|
||||
StaticallyIndexedArray<d4_t, 4> d4x4_;
|
||||
StaticallyIndexedArray<d8_t, 2> d8x2_;
|
||||
StaticallyIndexedArray<d8_t, 1> d16x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return 16; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vector() const { return data_.d16_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vector() { return data_.d16_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Scalars() { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x8_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x4_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x2_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<16>) const { return data_.d16x1_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x8_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x4_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x2_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; }
|
||||
};
|
||||
|
||||
// i8
|
||||
// hack for int8x4_t, because compiler does not have native support for int8x4_t
|
||||
// int8x4_t is defined as int32_t
|
||||
using int8x2_t = typename vector_type<int8_t, 2>::type;
|
||||
using int8x4_t = typename vector_type<int8_t, 4>::type;
|
||||
using int8x8_t = typename vector_type<int8_t, 8>::type;
|
||||
using int8x16_t = typename vector_type<int8_t, 16>::type;
|
||||
@@ -489,8 +478,6 @@ struct inner_product_with_conversion
|
||||
|
||||
__device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); }
|
||||
|
||||
// hack for int8x4_t, because compiler does not have native support for int8x4_t
|
||||
// int8x4_t is defined as int32_t
|
||||
__device__ T operator()(int8x4_t a, int8x4_t b) const
|
||||
{
|
||||
const vector_type<int8_t, 4> a_vector{a};
|
||||
@@ -499,7 +486,7 @@ struct inner_product_with_conversion
|
||||
T acc = 0;
|
||||
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
|
||||
acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
|
||||
});
|
||||
|
||||
return acc;
|
||||
@@ -513,7 +500,7 @@ struct inner_product_with_conversion
|
||||
T acc = 0;
|
||||
|
||||
static_for<0, 8, 1>{}([&](auto i) {
|
||||
acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
|
||||
acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
|
||||
});
|
||||
|
||||
return acc;
|
||||
@@ -527,7 +514,7 @@ struct inner_product_with_conversion
|
||||
T acc = 0;
|
||||
|
||||
static_for<0, 16, 1>{}([&](auto i) {
|
||||
acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
|
||||
acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
|
||||
});
|
||||
|
||||
return acc;
|
||||
|
||||
@@ -40,7 +40,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// run-time variables
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
|
||||
@@ -167,7 +167,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize 64, 16x256x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
constexpr auto C0 = C / Number<InWeiVectorSize>{};
|
||||
constexpr auto C1 = Number<InWeiVectorSize>{};
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// run-time variables
|
||||
constexpr auto in_n_hi_wi_c0_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
|
||||
@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// cdata = 16, BlockSize = 64, 16x64x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
@@ -211,7 +211,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 64, 16x256x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
@@ -310,7 +310,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
|
||||
@@ -83,10 +83,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
|
||||
const auto out_n_k0_ho_wo_k1_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
|
||||
|
||||
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
|
||||
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
|
||||
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
|
||||
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
|
||||
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
|
||||
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
|
||||
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
|
||||
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
|
||||
#endif
|
||||
|
||||
Tensor<TInWei> in_n_c0_hi_wi_c1(make_HostTensorDescriptor(
|
||||
|
||||
@@ -48,8 +48,8 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
@@ -62,9 +62,9 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 1080;
|
||||
@@ -150,7 +150,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 3x3, 71x71
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
|
||||
print_array("ConvStrides", to_multi_index(ConvStrides{}));
|
||||
print_array("ConvDilations", to_multi_index(ConvDilations{}));
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
using in_data_t = float;
|
||||
constexpr index_t in_vector_size = 1;
|
||||
using acc_data_t = float;
|
||||
@@ -724,7 +724,7 @@ int main(int argc, char* argv[])
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
#elif 1
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
|
||||
Reference in New Issue
Block a user