diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp index 33d770092c..7da08d6ef4 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp @@ -104,7 +104,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - using vector_t = typename vector_type::type; + using vector_t = typename vector_type_maker::type::type; static_for<0, NSliceRow, 1>{}([&](auto i) { static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) { diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp index 9e2f0b472f..4f9ecd8b54 100644 --- a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp @@ -172,16 +172,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 }(); // copy data - vector_type dst_vector; + typename vector_type_maker::type dst_vector; - using dst_vector_t = typename vector_type::type; + using dst_vector_t = + typename vector_type_maker::type::type; static_for<0, DstScalarPerVector, 1>{}([&](auto i) { constexpr index_t src_offset = src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx + i * dst_scalar_step_in_vector); - dst_vector.Scalars()(i) = type_convert{}(p_src[Number{}]); + dst_vector.template AsType()(i) = + type_convert{}(p_src[Number{}]); }); const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( @@ -192,7 +194,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 { #if CK_USE_AMD_BUFFER_ADDRESSING amd_buffer_store_v2( - dst_vector.Vector(), + dst_vector.template AsType()(Number<0>{}), p_dst, dst_slice_origin_coord_.GetOffset(), is_dst_valid, @@ -201,7 +203,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 if(is_dst_valid) { *reinterpret_cast( - &(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector(); + &(p_dst[dst_slice_origin_coord_.GetOffset()])) = + dst_vector.template AsType()[Number<0>{}]; } #endif } @@ -210,7 +213,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 if(is_dst_valid) { *reinterpret_cast( - &(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector(); + &(p_dst[dst_slice_origin_coord_.GetOffset()])) = + dst_vector.template AsType()[Number<0>{}]; } } @@ -500,9 +504,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 // copy data static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! 
hardcode for vgpr dst"); - vector_type src_vector; + typename vector_type_maker::type src_vector; - using src_vector_t = typename vector_type::type; + using src_vector_t = + typename vector_type_maker::type::type; const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( src_desc, src_slice_origin_coord_); @@ -510,24 +515,25 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 if constexpr(SrcAddressSpace == AddressSpace::Global) { #if CK_USE_AMD_BUFFER_ADDRESSING - src_vector.Vector() = amd_buffer_load_v2( - p_src, - src_slice_origin_coord_.GetOffset(), - is_src_valid, - src_desc.GetElementSpaceSize()); + src_vector.template AsType()(Number<0>{}) = + amd_buffer_load_v2( + p_src, + src_slice_origin_coord_.GetOffset(), + is_src_valid, + src_desc.GetElementSpaceSize()); #else - src_vector.Vector() = is_src_valid - ? *reinterpret_cast( - &p_src[src_slice_origin_coord_.GetOffset()]) - : src_vector_t{0}; + src_vector.template AsType()(Number<0>{}) = + is_src_valid ? *reinterpret_cast( + &p_src[src_slice_origin_coord_.GetOffset()]) + : src_vector_t{0}; #endif } else { - src_vector.Vector() = is_src_valid - ? *reinterpret_cast( - &p_src[src_slice_origin_coord_.GetOffset()]) - : src_vector_t{0}; + src_vector.template AsType()(Number<0>{}) = + is_src_valid ? *reinterpret_cast( + &p_src[src_slice_origin_coord_.GetOffset()]) + : src_vector_t{0}; } static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { @@ -535,7 +541,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + i * src_scalar_step_in_vector); - p_dst[Number{}] = src_vector.Scalars()[i]; + p_dst[Number{}] = src_vector.template AsType()[i]; }); constexpr auto move_on_dim = [&]() constexpr @@ -833,9 +839,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 }(); // copy data - vector_type src_vector; + typename vector_type_maker::type src_vector; - using src_vector_t = typename vector_type::type; + using src_vector_t = + typename vector_type_maker::type::type; const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( src_desc, src_slice_origin_coord_); @@ -843,31 +850,32 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 if constexpr(SrcAddressSpace == AddressSpace::Global) { #if CK_USE_AMD_BUFFER_ADDRESSING - src_vector.Vector() = amd_buffer_load_v2( - p_src, - src_slice_origin_coord_.GetOffset(), - is_src_valid, - src_desc.GetElementSpaceSize()); + src_vector.template AsType()(Number<0>{}) = + amd_buffer_load_v2( + p_src, + src_slice_origin_coord_.GetOffset(), + is_src_valid, + src_desc.GetElementSpaceSize()); #else - src_vector.Vector() = is_src_valid - ? *reinterpret_cast( - &p_src[src_slice_origin_coord_.GetOffset()]) - : src_vector_t{0}; + src_vector.template AsType()(Number<0>{}) = + is_src_valid ? *reinterpret_cast( + &p_src[src_slice_origin_coord_.GetOffset()]) + : src_vector_t{0}; #endif } else { - src_vector.Vector() = is_src_valid - ? *reinterpret_cast( - &p_src[src_slice_origin_coord_.GetOffset()]) - : src_vector_t{0}; + src_vector.template AsType()(Number<0>{}) = + is_src_valid ? 
*reinterpret_cast( + &p_src[src_slice_origin_coord_.GetOffset()]) + : src_vector_t{0}; } static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { constexpr index_t buffer_offset = buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector); - buffer_(Number{}) = src_vector.Scalars()[i]; + buffer_(Number{}) = src_vector.template AsType()[i]; }); constexpr auto move_on_dim = [&]() constexpr @@ -1018,19 +1026,20 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 DstInMemOp == InMemoryDataOperation::Set, "wrong! hardcoded for ds_write"); - vector_type dst_vector; + typename vector_type_maker::type dst_vector; static_for<0, DstScalarPerVector, 1>{}([&](auto i) { constexpr index_t buffer_offset = buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector); - dst_vector.Scalars()(i) = buffer_[Number{}]; + dst_vector.template AsType()(i) = buffer_[Number{}]; }); - using DstVectorType = typename vector_type::type; + using DstVectorType = + typename vector_type_maker::type::type; *reinterpret_cast(p_dst + dst_slice_origin_coord_.GetOffset()) = - dst_vector.Vector(); + dst_vector.template AsType()[Number<0>{}]; constexpr auto move_on_dim = [&]() constexpr { diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp index 1af88e5cbb..868f205630 100644 --- a/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp @@ -41,7 +41,7 @@ struct ThreadwiseMatrixSliceCopy_v2 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - using vector_t = typename vector_type::type; + using vector_t = typename vector_type_maker::type::type; static_for<0, NSliceRow, 1>{}([&](auto i) { static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) { diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/composable_kernel/include/utility/amd_buffer_addressing.hpp index b8630c464f..380a14003d 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing.hpp +++ b/composable_kernel/include/utility/amd_buffer_addressing.hpp @@ -6,6 +6,17 @@ namespace ck { +template +union BufferResource +{ + // 128 bit SGPRs to supply buffer resource in buffer instructions + // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions + int32x4_t data; + T* address[2]; + int32_t range[4]; + int32_t config[4]; +}; + __device__ float __llvm_amdgcn_buffer_load_f32(int32x4_t srsrc, index_t vindex, index_t offset, diff --git a/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp b/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp index ff8d57993f..f5a684b994 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp +++ b/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp @@ -6,27 +6,27 @@ namespace ck { template -union BufferResource +union BufferResource_v2 { // 128 bit SGPRs to supply buffer resource in buffer instructions // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions int32x4_t data; - T* address[2]; - int32_t range[4]; - int32_t config[4]; + StaticallyIndexedArray address; + StaticallyIndexedArray range; + StaticallyIndexedArray config; }; template __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size) { - BufferResource wave_buffer_resource; + 
BufferResource_v2<T> wave_buffer_resource;

     // wavewise base address (64 bit)
-    wave_buffer_resource.address[0] = const_cast<remove_cv_t<T>*>(p_wave);
+    wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
     // wavewise range (32 bit)
-    wave_buffer_resource.range[2] = data_space_size * sizeof(T);
+    wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T);
     // wavewise setting (32 bit)
-    wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD;
+    wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;

     return wave_buffer_resource.data;
 }
@@ -37,6 +37,19 @@ __llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
                                  index_t voffset,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
+
+__device__ int8x2_t
+__llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
+                                   index_t voffset,
+                                   index_t soffset,
+                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
+
+__device__ int8x4_t
+__llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
+                                   index_t voffset,
+                                   index_t soffset,
+                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
+
 __device__ int16_t
 __llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
                                   index_t voffset,
@@ -105,6 +118,20 @@ __llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
                                   index_t soffset,
                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
 
+__device__ void
+__llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
+                                    int32x4_t rsrc,
+                                    index_t voffset,
+                                    index_t soffset,
+                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
+
+__device__ void
+__llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
+                                    int32x4_t rsrc,
+                                    index_t voffset,
+                                    index_t soffset,
+                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
+
 __device__ void
 __llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
                                    int32x4_t rsrc,
@@ -182,15 +209,12 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
                         index_t src_thread_addr_offset,
                         index_t src_wave_addr_offset)
 {
-    static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
-                      (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4)) ||
-                      (is_same<T, half2_t>::value && (N == 1)) ||
-                      (is_same<T, half4_t>::value && (N == 1)) ||
-                      (is_same<T, half8_t>::value && (N == 1)) ||
-                      (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
-                      (is_same<T, int32x2_t>::value && (N == 1)) ||
-                      (is_same<T, int32x4_t>::value && (N == 1)),
-                  "wrong! not implemented");
+    static_assert(
+        (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+            (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
+        "wrong! 
not implemented"); if constexpr(is_same::value) { @@ -213,16 +237,16 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, { vector_type tmp; - tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( + tmp.AsType()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - tmp.Vectors(Number<4>{})(Number<1>{}) = + tmp.AsType()(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(float), 0); - return tmp.Vector(); + return tmp.AsType()(Number<0>{}); } } else if constexpr(is_same::value) @@ -242,39 +266,20 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, return __llvm_amdgcn_raw_buffer_load_fp16x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); } - } - else if constexpr(is_same::value) - { - if constexpr(N == 1) - { - return __llvm_amdgcn_raw_buffer_load_fp16x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - } - else if constexpr(is_same::value) - { - if constexpr(N == 1) - { - return __llvm_amdgcn_raw_buffer_load_fp16x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - } - else if constexpr(is_same::value) - { - if constexpr(N == 1) + else if constexpr(N == 8) { vector_type tmp; - tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4( + tmp.AsType()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - tmp.Vectors(Number<4>{})(Number<1>{}) = + tmp.AsType()(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(half_t), 0); - return tmp.Vector(); + return tmp.AsType()(Number<0>{}); } } else if constexpr(is_same::value) @@ -298,32 +303,103 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, { vector_type tmp; - tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4( + tmp.AsType()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - tmp.Vectors(Number<4>{})(Number<1>{}) = + tmp.AsType()(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(int32_t), 0); + return tmp.AsType()(Number<0>{}); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return __llvm_amdgcn_raw_buffer_load_i8( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX + return __llvm_amdgcn_raw_buffer_load_i8x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); +#else + int16_t tmp = __llvm_amdgcn_raw_buffer_load_i16( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return tmp.Vector(); + return as_type(tmp); +#endif } - } - else if constexpr(is_same::value) - { - if constexpr(N == 1) + else if constexpr(N == 4) { - return __llvm_amdgcn_raw_buffer_load_i32x2( +#if !CK_WORKAROUND_SWDEV_XXXXXX + return __llvm_amdgcn_raw_buffer_load_i8x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); +#else + int32_t tmp = __llvm_amdgcn_raw_buffer_load_i32( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); +#endif } - } - else if 
constexpr(is_same::value) - { - if constexpr(N == 1) + else if constexpr(N == 8) { - return __llvm_amdgcn_raw_buffer_load_i32x4( +#if !CK_WORKAROUND_SWDEV_XXXXXX + vector_type tmp; + + tmp.AsType()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int8_t), + 0); + + return tmp.AsType()(Number<0>{}); +#else + int32x2_t tmp = __llvm_amdgcn_raw_buffer_load_i32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); +#endif + } + else if constexpr(N == 16) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX + vector_type tmp; + + tmp.AsType()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int8_t), + 0); + + tmp.AsType()(Number<2>{}) = + __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(int8_t), + 0); + + tmp.AsType()(Number<3>{}) = + __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(int8_t), + 0); + + return tmp.AsType()(Number<0>{}); +#else + int32x4_t tmp = __llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); +#endif } } } @@ -407,23 +483,39 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type::type } else if constexpr(N == 2) { - __llvm_amdgcn_raw_buffer_store_i16(src_thread_data, +#if !CK_WORKAROUND_SWDEV_XXXXXX + __llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#else + __llvm_amdgcn_raw_buffer_store_i16(as_type(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); +#endif } else if constexpr(N == 4) { - __llvm_amdgcn_raw_buffer_store_i32(src_thread_data, +#if !CK_WORKAROUND_SWDEV_XXXXXX + __llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#else + __llvm_amdgcn_raw_buffer_store_i32(as_type(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); +#endif } else if constexpr(N == 8) { - __llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data, + __llvm_amdgcn_raw_buffer_store_i32x2(as_type(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -431,7 +523,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type::type } else if constexpr(N == 16) { - __llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data, + __llvm_amdgcn_raw_buffer_store_i32x4(as_type(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -468,13 +560,13 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type::type { vector_type tmp{src_thread_data}; - __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<0>{}], + __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); - 
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<1>{}],
+        __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
                                               dst_wave_buffer_resource,
                                               dst_thread_addr_offset,
                                               dst_wave_addr_offset + 4 * sizeof(half_t),
                                               0);
 }
@@ -488,26 +580,29 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
 // 2) p_src_wave to be a wavewise pointer.
 // It is user's responsibility to make sure that is true.
 template <typename T, index_t N>
-__device__ typename vector_type<T, N>::type amd_buffer_load_v2(const T* p_src_wave,
-                                                               index_t src_thread_data_offset,
-                                                               bool src_thread_data_valid,
-                                                               index_t src_element_space)
+__device__ typename vector_type_maker<T, N>::type::type
+amd_buffer_load_v2(const T* p_src_wave,
+                   index_t src_thread_data_offset,
+                   bool src_thread_data_valid,
+                   index_t src_element_space)
 {
     const int32x4_t src_wave_buffer_resource =
         make_wave_buffer_resource(p_src_wave, src_element_space);
 
     index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
 
+    using vector_t = typename vector_type_maker<T, N>::type::type;
+    using scalar_t = typename scalar_type<vector_t>::type;
+
+    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
+
 #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
     uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
 
-    return amd_buffer_load_impl_v2<T, N>(
+    return amd_buffer_load_impl_v2<scalar_t, vector_size>(
         src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
 #else
-    using vector_t = typename vector_type<T, N>::type;
-
-    vector_t tmp =
-        amd_buffer_load_impl_v2<T, N>(src_wave_buffer_resource, src_thread_addr_offset, 0);
+    vector_t tmp = amd_buffer_load_impl_v2<scalar_t, vector_size>(
+        src_wave_buffer_resource, src_thread_addr_offset, 0);
 
     return src_thread_data_valid ? tmp : vector_t(0);
 #endif
@@ -518,26 +613,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_v2(const T* p_src_wave,
 // 2) p_dst_wave to be a wavewise pointer.
 // It is user's responsibility to make sure that is true.
 template <typename T, index_t N>
-__device__ void amd_buffer_store_v2(const typename vector_type<T, N>::type src_thread_data,
-                                    T* p_dst_wave,
-                                    const index_t dst_thread_data_offset,
-                                    const bool dst_thread_data_valid,
-                                    const index_t dst_element_space)
+__device__ void
+amd_buffer_store_v2(const typename vector_type_maker<T, N>::type::type src_thread_data,
+                    T* p_dst_wave,
+                    const index_t dst_thread_data_offset,
+                    const bool dst_thread_data_valid,
+                    const index_t dst_element_space)
 {
     const int32x4_t dst_wave_buffer_resource =
         make_wave_buffer_resource(p_dst_wave, dst_element_space);
 
     index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
 
+    using vector_t = typename vector_type_maker<T, N>::type::type;
+    using scalar_t = typename scalar_type<vector_t>::type;
+
+    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
+
 #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
     uint32_t dst_addr_shift = dst_thread_data_valid ? 
0 : 0x7fffffff; - amd_buffer_store_impl_v2( + amd_buffer_store_impl_v2( src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); #else if(dst_thread_data_valid) { - amd_buffer_store_impl_v2( + amd_buffer_store_impl_v2( src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); } #endif diff --git a/composable_kernel/include/utility/amd_inline_asm.hpp b/composable_kernel/include/utility/amd_inline_asm.hpp index fa0f76e630..3c8b58193b 100644 --- a/composable_kernel/include/utility/amd_inline_asm.hpp +++ b/composable_kernel/include/utility/amd_inline_asm.hpp @@ -72,6 +72,7 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo __device__ void amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1) { + // TODO remove pointer casting const half2_t* p_a_half2 = reinterpret_cast(&a); const half2_t* p_b0_half2 = reinterpret_cast(&b0); const half2_t* p_b1_half2 = reinterpret_cast(&b1); @@ -132,6 +133,7 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a, float& c2, float& c3) { + // TODO remove pointer casting const half2_t* p_a_half2 = reinterpret_cast(&a); const half2_t* p_b0_half2 = reinterpret_cast(&b0); const half2_t* p_b1_half2 = reinterpret_cast(&b1); @@ -177,6 +179,7 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a, float& c3) { + // TODO remove pointer casting const half4_t* p_a_half4 = reinterpret_cast(&a); const half4_t* p_b0_half4 = reinterpret_cast(&b0); const half4_t* p_b1_half4 = reinterpret_cast(&b1); @@ -200,6 +203,7 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a, float& c2, float& c3) { + // TODO remove pointer casting const half8_t* p_a_half8 = reinterpret_cast(&a); const half8_t* p_b0_half8 = reinterpret_cast(&b0); const half8_t* p_b1_half8 = reinterpret_cast(&b1); @@ -224,10 +228,14 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0 v_dot4_i32_i8 %1, %2, %4, %1\n \ " : "=v"(c0), "=v"(c1) - : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); + : "v"(as_type(a)), + "v"(as_type(b0)), + "v"(as_type(b1)), + "0"(c0), + "1"(c1)); #else - c0 = __builtin_amdgcn_sdot4(a, b0, c0, false); - c1 = __builtin_amdgcn_sdot4(a, b1, c1, false); + c0 = __builtin_amdgcn_sdot4(as_type(a), as_type(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(as_type(a), as_type(b1), c1, false); #endif } @@ -253,12 +261,20 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a, v_dot4_i32_i8 %3, %4, %8, %3\n \ " : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) - : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); + : "v"(as_type(a)), + "v"(as_type(b0)), + "v"(as_type(b1)), + "v"(as_type(b2)), + "v"(as_type(b3)), + "0"(c0), + "1"(c1), + "2"(c2), + "3"(c3)); #else - c0 = __builtin_amdgcn_sdot4(a, b0, c0, false); - c1 = __builtin_amdgcn_sdot4(a, b1, c1, false); - c2 = __builtin_amdgcn_sdot4(a, b2, c2, false); - c3 = __builtin_amdgcn_sdot4(a, b3, c3, false); + c0 = __builtin_amdgcn_sdot4(as_type(a), as_type(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(as_type(a), as_type(b1), c1, false); + c2 = __builtin_amdgcn_sdot4(as_type(a), as_type(b2), c2, false); + c3 = __builtin_amdgcn_sdot4(as_type(a), as_type(b3), c3, false); #endif } @@ -272,28 +288,24 @@ __device__ void amd_assembly_outer_product_1x4(int8x8_t a, int32_t& c2, int32_t& c3) { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; - const int8x4_t* p_a_int8x4_t = reinterpret_cast(&a); - const int8x4_t* p_b0_int8x4_t = reinterpret_cast(&b0); - const 
int8x4_t* p_b1_int8x4_t = reinterpret_cast<const int8x4_t*>(&b1);
-    const int8x4_t* p_b2_int8x4_t = reinterpret_cast<const int8x4_t*>(&b2);
-    const int8x4_t* p_b3_int8x4_t = reinterpret_cast<const int8x4_t*>(&b3);
-
-    amd_assembly_outer_product_1x4(p_a_int8x4_t[0],
-                                   p_b0_int8x4_t[0],
-                                   p_b1_int8x4_t[0],
-                                   p_b2_int8x4_t[0],
-                                   p_b3_int8x4_t[0],
+    amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
+                                   vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
+                                   vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
+                                   vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
+                                   vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
                                    c0,
                                    c1,
                                    c2,
                                    c3);
 
-    amd_assembly_outer_product_1x4(p_a_int8x4_t[1],
-                                   p_b0_int8x4_t[1],
-                                   p_b1_int8x4_t[1],
-                                   p_b2_int8x4_t[1],
-                                   p_b3_int8x4_t[1],
+    amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
+                                   vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
+                                   vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
+                                   vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
+                                   vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
                                    c0,
                                    c1,
                                    c2,
                                    c3);
@@ -311,28 +323,46 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                                                int32_t& c3)
 {
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
 
-    const int8x8_t* p_a_int8x8_t  = reinterpret_cast<const int8x8_t*>(&a);
-    const int8x8_t* p_b0_int8x8_t = reinterpret_cast<const int8x8_t*>(&b0);
-    const int8x8_t* p_b1_int8x8_t = reinterpret_cast<const int8x8_t*>(&b1);
-    const int8x8_t* p_b2_int8x8_t = reinterpret_cast<const int8x8_t*>(&b2);
-    const int8x8_t* p_b3_int8x8_t = reinterpret_cast<const int8x8_t*>(&b3);
-
-    amd_assembly_outer_product_1x4(p_a_int8x8_t[0],
-                                   p_b0_int8x8_t[0],
-                                   p_b1_int8x8_t[0],
-                                   p_b2_int8x8_t[0],
-                                   p_b3_int8x8_t[0],
+    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
+                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
+                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
+                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
+                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
                                    c0,
                                    c1,
                                    c2,
                                    c3);
 
-    amd_assembly_outer_product_1x4(p_a_int8x8_t[1],
-                                   p_b0_int8x8_t[1],
-                                   p_b1_int8x8_t[1],
-                                   p_b2_int8x8_t[1],
-                                   p_b3_int8x8_t[1],
+    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
+                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
+                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
+                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
+                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
+                                   c0,
+                                   c1,
+                                   c2,
+                                   c3);
+
+    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
+                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
+                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
+                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
+                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
+                                   c0,
+                                   c1,
+                                   c2,
+                                   c3);
+
+    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
+                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
+                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
+                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
+                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
                                    c0,
                                    c1,
                                    c2,
                                    c3);
diff --git a/composable_kernel/include/utility/config.amd.hpp.in b/composable_kernel/include/utility/config.amd.hpp.in
index bca451a60a..9de35587fd 100644
--- a/composable_kernel/include/utility/config.amd.hpp.in
+++ b/composable_kernel/include/utility/config.amd.hpp.in
@@ -14,11 +14,11 @@
 #define CK_DEVICE_BACKEND_AMD 1
 
 // GPU ID
-#if 0
+#if 1
 #define CK_AMD_GPU_GFX906 1
 #elif 0
 #define CK_AMD_GPU_GFX908 1
-#elif 1
+#elif 0
 #define CK_AMD_GPU_GFX1030 1
 #endif
 
@@ -88,7 +88,7 @@
 // experimental implementation
 #ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
-#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
+#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
 #endif
 
 #ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
@@ -142,6 +142,11 @@
 #define CK_WORKAROUND_SWDEV_275126 1
 #endif
 
+// workaround for compiler crash when using buffer load/store for i8
+#ifndef CK_WORKAROUND_SWDEV_XXXXXX
+#define CK_WORKAROUND_SWDEV_XXXXXX 1
+#endif
+
 namespace ck {
 
 enum AddressSpace
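Note on the as_type<X>(...) calls introduced above: __builtin_amdgcn_sdot4 and the CK_WORKAROUND_SWDEV_XXXXXX buffer-store fallbacks operate on int16_t/int32_t values, while the operands live in int8x2_t/int8x4_t vectors of the same byte size, so a bit-preserving reinterpretation is needed rather than a value conversion. Below is a minimal sketch of such a cast, assuming only trivially copyable types of equal size; the helper name bit_cast_sketch is illustrative and is not composable_kernel's actual as_type, which is defined elsewhere in the library.

    // Illustrative sketch only; not composable_kernel's real as_type.
    template <typename Y, typename X>
    __host__ __device__ Y bit_cast_sketch(const X& x)
    {
        // same size is required for a meaningful bit-for-bit reinterpretation
        static_assert(sizeof(X) == sizeof(Y), "wrong! X and Y must be the same size");

        Y y;
        __builtin_memcpy(&y, &x, sizeof(Y));
        return y;
    }

    // e.g. feed an int8x4_t to an intrinsic that expects its 32 bits as int32_t:
    //   int32_t w = bit_cast_sketch<int32_t>(a);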
diff --git a/composable_kernel/include/utility/float_type.amd.hpp.in b/composable_kernel/include/utility/float_type.amd.hpp.in
index 7ce0d18d61..f957f9aaa7 100644
--- a/composable_kernel/include/utility/float_type.amd.hpp.in
+++ b/composable_kernel/include/utility/float_type.amd.hpp.in
@@ -3,13 +3,106 @@
 namespace ck {
 
+using half_t = _Float16;
+
+// vector_type
 template <typename T, index_t N>
 struct vector_type;
 
+// Caution: DO NOT REMOVE
+// intentionally have only declaration but no definition to cause compilation failure when trying to
+// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
+// vectors"
+template <typename T, index_t V, index_t N>
+struct vector_type<T __attribute__((ext_vector_type(V))), N>;
+
+// Caution: DO NOT REMOVE
+// intentionally have only declaration but no definition to cause compilation failure when trying to
+// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
+// vectors"
+template <typename T, index_t V, index_t N>
+struct vector_type<vector_type<T, V>, N>;
+
+// vector_type_maker
+// This is the right way to handle "vector of vectors": making a bigger vector instead
+template <typename T, index_t N>
+struct vector_type_maker
+{
+    using type = vector_type<T, N>;
+};
+
+template <typename T, index_t N0, index_t N1>
+struct vector_type_maker<T __attribute__((ext_vector_type(N1))), N0>
+{
+    using type = vector_type<T, N0 * N1>;
+};
+
+template <typename T, index_t N0, index_t N1>
+struct vector_type_maker<vector_type<T, N1>, N0>
+{
+    using type = vector_type<T, N0 * N1>;
+};
+
+// scalar_type
+template <typename TV>
+struct scalar_type;
+
+template <typename T, index_t N>
+struct scalar_type<T __attribute__((ext_vector_type(N)))>
+{
+    using type                           = T;
+    static constexpr index_t vector_size = N;
+};
+
+template <typename T, index_t N>
+struct scalar_type<vector_type<T, N>>
+{
+    using type                           = T;
+    static constexpr index_t vector_size = N;
+};
+
+//
+template <>
+struct scalar_type<float>
+{
+    using type                           = float;
+    static constexpr index_t vector_size = 1;
+};
+
+template <>
+struct scalar_type<half_t>
+{
+    using type                           = half_t;
+    static constexpr index_t vector_size = 1;
+};
+
+template <>
+struct scalar_type<ushort>
+{
+    using type                           = ushort;
+    static constexpr index_t vector_size = 1;
+};
+
+template <>
+struct scalar_type<int32_t>
+{
+    using type                           = int32_t;
+    static constexpr index_t vector_size = 1;
+};
+
+template <>
+struct scalar_type<int8_t>
+{
+    using type                           = int8_t;
+    static constexpr index_t vector_size = 1;
+};
+
+//
 template <typename T>
 struct vector_type<T, 1>
 {
-    using type = T;
+    using d1_t = T;
+    using type = d1_t;
 
     union
     {
@@ -21,19 +114,21 @@ struct vector_type<T, 1>
 
     __host__ __device__ constexpr vector_type(type v) : data_{v} {}
 
-    __host__ __device__ static constexpr index_t Size() { return 1; }
+    template <typename X>
+    __host__ __device__ constexpr const auto& AsType() const
+    {
+        static_assert(is_same<X, d1_t>::value, "wrong!");
 
-    __host__ __device__ constexpr const auto& Vector() const { return data_.d1_; }
+        return data_.d1x1_;
+    }
 
-    __host__ __device__ constexpr auto& Vector() { return data_.d1_; }
+    template <typename X>
+    __host__ __device__ constexpr auto& AsType()
+    {
+        static_assert(is_same<X, d1_t>::value, "wrong!");
 
-    __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x1_; }
-
-    __host__ __device__ constexpr auto& Scalars() { return data_.d1x1_; }
-
-    __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x1_; }
-
-    __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x1_; }
+        return data_.d1x1_;
+    }
 };
 
 template <typename T>
@@ -55,23 +150,35 @@ struct vector_type<T, 2>
 
     __host__ __device__ constexpr vector_type(type v) : data_{v} {}
 
-    __host__ __device__ static constexpr index_t Size() { return 2; }
+    template <typename X>
+    __host__ __device__ constexpr const auto& AsType() const
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!");
 
-    __host__ __device__ constexpr const auto& Vector() const { return data_.d2_; }
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return 
data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + } - __host__ __device__ constexpr auto& Vector() { return data_.d2_; } + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, "wrong!"); - __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x2_; } - - __host__ __device__ constexpr auto& Scalars() { return data_.d1x2_; } - - __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x2_; } - - __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x1_; } - - __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x2_; } - - __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x1_; } + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + } }; template @@ -95,27 +202,45 @@ struct vector_type __host__ __device__ constexpr vector_type(type v) : data_{v} {} - __host__ __device__ static constexpr index_t Size() { return 4; } + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "wrong!"); - __host__ __device__ constexpr const auto& Vector() const { return data_.d4_; } + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + } - __host__ __device__ constexpr auto& Vector() { return data_.d4_; } + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "wrong!"); - __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x4_; } - - __host__ __device__ constexpr auto& Scalars() { return data_.d1x4_; } - - __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x4_; } - - __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x2_; } - - __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x1_; } - - __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x4_; } - - __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x2_; } - - __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; } + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + } }; template @@ -141,31 +266,55 @@ struct vector_type __host__ __device__ constexpr vector_type(type v) : data_{v} {} - __host__ __device__ static constexpr index_t Size() { return 8; } + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); - __host__ __device__ constexpr const auto& Vector() const { return data_.d8_; } + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + } - __host__ __device__ constexpr auto& Vector() { return data_.d8_; } + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || 
is_same::value, + "wrong!"); - __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x8_; } - - __host__ __device__ constexpr auto& Scalars() { return data_.d1x8_; } - - __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x8_; } - - __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x4_; } - - __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x2_; } - - __host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x1_; } - - __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x8_; } - - __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x4_; } - - __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x2_; } - - __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; } + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + } }; template @@ -193,35 +342,65 @@ struct vector_type __host__ __device__ constexpr vector_type(type v) : data_{v} {} - __host__ __device__ static constexpr index_t Size() { return 16; } + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); - __host__ __device__ constexpr const auto& Vector() const { return data_.d16_; } + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + } - __host__ __device__ constexpr auto& Vector() { return data_.d16_; } + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); - __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x16_; } - - __host__ __device__ constexpr auto& Scalars() { return data_.d1x16_; } - - __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x16_; } - - __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x8_; } - - __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x4_; } - - __host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x2_; } - - __host__ __device__ constexpr const auto& Vectors(Number<16>) const { return data_.d16x1_; } - - __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x16_; } - - __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x8_; } - - __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x4_; } - - __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x2_; } - - __host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; } + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if 
constexpr(is_same::value) + { + return data_.d16x1_; + } + } }; // fp32 @@ -230,7 +409,6 @@ using float4_t = typename vector_type::type; using float8_t = typename vector_type::type; // fp16 -using half_t = _Float16; using half2_t = typename vector_type::type; using half4_t = typename vector_type::type; using half8_t = typename vector_type::type; @@ -246,197 +424,8 @@ using int32x2_t = typename vector_type::type; using int32x4_t = typename vector_type::type; using int32x8_t = typename vector_type::type; -template <> -struct vector_type -{ - using d1_t = int8_t; - typedef int16_t d2_t; - - using type = d2_t; - - union - { - d2_t d2_; - StaticallyIndexedArray d1x2_; - StaticallyIndexedArray d2x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - __host__ __device__ static constexpr index_t Size() { return 2; } - - __host__ __device__ constexpr const auto& Vector() const { return data_.d2_; } - - __host__ __device__ constexpr auto& Vector() { return data_.d2_; } - - __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x2_; } - - __host__ __device__ constexpr auto& Scalars() { return data_.d1x2_; } - - __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x2_; } - - __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x1_; } - - __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x2_; } - - __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x1_; } -}; - -template <> -struct vector_type -{ - using d1_t = int8_t; - typedef int16_t d2_t; - typedef int32_t d4_t; - - using type = d4_t; - - union - { - d4_t d4_; - StaticallyIndexedArray d1x4_; - StaticallyIndexedArray d2x2_; - StaticallyIndexedArray d4x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - __host__ __device__ static constexpr index_t Size() { return 4; } - - __host__ __device__ constexpr const auto& Vector() const { return data_.d4_; } - - __host__ __device__ constexpr auto& Vector() { return data_.d4_; } - - __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x4_; } - - __host__ __device__ constexpr auto& Scalars() { return data_.d1x4_; } - - __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x4_; } - - __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x2_; } - - __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x1_; } - - __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x4_; } - - __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x2_; } - - __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; } -}; - -template <> -struct vector_type -{ - using d1_t = int8_t; - typedef int16_t d2_t; - typedef int32_t d4_t; - typedef int32x2_t d8_t; - - using type = d8_t; - - union - { - d8_t d8_; - StaticallyIndexedArray d1x8_; - StaticallyIndexedArray d2x4_; - StaticallyIndexedArray d4x2_; - StaticallyIndexedArray d8x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - __host__ __device__ static constexpr index_t Size() { return 8; } - - __host__ __device__ constexpr const auto& Vector() const { return data_.d8_; } - - __host__ __device__ 
constexpr auto& Vector() { return data_.d8_; } - - __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x8_; } - - __host__ __device__ constexpr auto& Scalars() { return data_.d1x8_; } - - __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x8_; } - - __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x4_; } - - __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x2_; } - - __host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x1_; } - - __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x8_; } - - __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x4_; } - - __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x2_; } - - __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; } -}; - -template <> -struct vector_type -{ - using d1_t = int8_t; - typedef int16_t d2_t; - typedef int32_t d4_t; - typedef int32x2_t d8_t; - typedef int32x4_t d16_t; - - using type = d16_t; - - union - { - d16_t d16_; - StaticallyIndexedArray d1x16_; - StaticallyIndexedArray d2x8_; - StaticallyIndexedArray d4x4_; - StaticallyIndexedArray d8x2_; - StaticallyIndexedArray d16x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - __host__ __device__ static constexpr index_t Size() { return 16; } - - __host__ __device__ constexpr const auto& Vector() const { return data_.d16_; } - - __host__ __device__ constexpr auto& Vector() { return data_.d16_; } - - __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x16_; } - - __host__ __device__ constexpr auto& Scalars() { return data_.d1x16_; } - - __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x16_; } - - __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x8_; } - - __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x4_; } - - __host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x2_; } - - __host__ __device__ constexpr const auto& Vectors(Number<16>) const { return data_.d16x1_; } - - __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x16_; } - - __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x8_; } - - __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x4_; } - - __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x2_; } - - __host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; } -}; - // i8 -// hack for int8x4_t, because compiler does not have native support for int8x4_t -// int8x4_t is defined as int32_t +using int8x2_t = typename vector_type::type; using int8x4_t = typename vector_type::type; using int8x8_t = typename vector_type::type; using int8x16_t = typename vector_type::type; @@ -489,8 +478,6 @@ struct inner_product_with_conversion __device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); } - // hack for int8x4_t, because compiler does not have native support for int8x4_t - // int8x4_t is defined as int32_t __device__ T operator()(int8x4_t a, int8x4_t b) const { const vector_type a_vector{a}; @@ -499,7 +486,7 @@ struct inner_product_with_conversion T acc = 0; static_for<0, 4, 1>{}([&](auto i) { - acc += convert(a_vector.Scalars()[i]) * 
convert(b_vector.Scalars()[i]); + acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); }); return acc; @@ -513,7 +500,7 @@ struct inner_product_with_conversion T acc = 0; static_for<0, 8, 1>{}([&](auto i) { - acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]); + acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); }); return acc; @@ -527,7 +514,7 @@ struct inner_product_with_conversion T acc = 0; static_for<0, 16, 1>{}([&](auto i) { - acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]); + acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); }); return acc; diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 53a8e7ac4b..1aa187dfcf 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -40,7 +40,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); -#if 0 +#if 1 // run-time variables const auto in_n_c_hi_wi_desc = make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths())); @@ -167,7 +167,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; -#elif 1 +#elif 0 // cdata = 64, BlockSize 64, 16x256x4 constexpr index_t BlockSize = 64; diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp index fee97dddbc..ccb8b29a77 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( constexpr auto C0 = C / Number{}; constexpr auto C1 = Number{}; -#if 0 +#if 1 // run-time variables constexpr auto in_n_hi_wi_c0_desc = make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0)); @@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); -#if 1 +#if 0 // cdata = 16, BlockSize = 64, 16x64x4 constexpr index_t BlockSize = 64; @@ -211,7 +211,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; -#elif 1 +#elif 0 // cdata = 64, BlockSize = 64, 16x256x4 constexpr index_t BlockSize = 64; @@ -310,7 +310,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; -#elif 0 +#elif 1 // cdata = 64, BlockSize = 256, 128x128x8 constexpr index_t BlockSize = 256; diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp 
b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp index 8d3c0d10b1..60ebe76da4 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp @@ -83,10 +83,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( const auto out_n_k0_ho_wo_k1_desc = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)); - const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); - const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); - const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{}); - const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); + const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); + const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); + const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{}); + const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); #endif Tensor in_n_c0_hi_wi_c1(make_HostTensorDescriptor( diff --git a/driver/src/conv_driver.cpp b/driver/src/conv_driver.cpp index 1e9487287d..2f490a323f 100644 --- a/driver/src/conv_driver.cpp +++ b/driver/src/conv_driver.cpp @@ -48,8 +48,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; #elif 0 constexpr index_t N = 1; constexpr index_t C = 16; @@ -62,9 +62,9 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 1 + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; +#elif 0 constexpr index_t N = 1; constexpr index_t C = 16; constexpr index_t HI = 1080; @@ -150,7 +150,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; -#elif 0 +#elif 1 // 3x3, 71x71 constexpr index_t N = 128; constexpr index_t C = 192; @@ -630,7 +630,7 @@ int main(int argc, char* argv[]) print_array("ConvStrides", to_multi_index(ConvStrides{})); print_array("ConvDilations", to_multi_index(ConvDilations{})); -#if 0 +#if 1 using in_data_t = float; constexpr index_t in_vector_size = 1; using acc_data_t = float; @@ -724,7 +724,7 @@ int main(int argc, char* argv[]) LeftPads{}, RightPads{}, nrepeat); -#elif 0 +#elif 1 device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw
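Taken together, the float_type changes replace the Vector()/Scalars()/Vectors(Number<N>{}) accessors with a single AsType<X>() accessor that reinterprets one block of storage through a union, and vector_type_maker folds any "vector of vectors" request into a single wider vector_type. The standalone sketch below illustrates the union idea with plain clang ext_vector_type extensions; the names Storage, float2_x, float4_x and the host-side main are illustrative, not composable_kernel API, and the union read relies on the same compiler-extension type-punning semantics (allowed by clang/gcc, not strict ISO C++) that the library itself assumes.

    #include <cstdio>

    using float2_x = float __attribute__((ext_vector_type(2)));
    using float4_x = float __attribute__((ext_vector_type(4)));

    // One 16-byte payload viewed at three granularities, like vector_type<float, 4>.
    union Storage
    {
        float4_x d4x1;    // one 4-wide vector   (cf. AsType<float4_t>())
        float2_x d2x2[2]; // two 2-wide vectors  (cf. AsType<float2_t>())
        float    d1x4[4]; // four scalars        (cf. AsType<float>())
    };

    int main()
    {
        float4_x v = {0.f, 1.f, 2.f, 3.f};

        Storage s;
        s.d4x1 = v;

        // Read the upper 2-wide chunk, analogous to AsType<float2_t>()[Number<1>{}].
        float2_x hi = s.d2x2[1];

        std::printf("%g %g\n", static_cast<double>(hi[0]), static_cast<double>(hi[1]));

        return 0;
    }

Requesting "two float4 vectors" through vector_type_maker<float4_t, 2> therefore yields vector_type<float, 8>, whose AsType<float4_t>() view hands back the two 4-wide halves, which is exactly how the amd_buffer_load_v2/amd_buffer_store_v2 wrappers above decompose wide transfers into legal buffer instructions.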