From 562e1e27673bea4d9ce6793d418c7788138e49ed Mon Sep 17 00:00:00 2001
From: Chao Liu
Date: Mon, 4 Nov 2019 16:51:12 -0600
Subject: [PATCH] MIOpen integration: recent bug fixes from MIOpen (#5)

---
 .../tensor_operation/threadwise_gemm.hpp      |  20 +-
 .../threadwise_generic_tensor_slice_copy.hpp  |   8 +-
 ...e_generic_tensor_slice_copy_deprecated.hpp |   8 +-
 .../include/utility/amd_buffer_addressing.hpp |  67 ++--
 .../include/utility/amd_inline_asm.hpp        |  82 ++---
 .../include/utility/common_header.hpp         |   4 +
 .../include/utility/config.amd.hpp.in         |   4 +-
 .../include/utility/float_type.hpp            | 311 ++++++++++++++++++
 driver/src/driver.cpp                         |   8 +-
 9 files changed, 415 insertions(+), 97 deletions(-)
 create mode 100644 composable_kernel/include/utility/float_type.hpp

diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm.hpp
index 00d81410e9..7cfd54e050 100644
--- a/composable_kernel/include/tensor_operation/threadwise_gemm.hpp
+++ b/composable_kernel/include/tensor_operation/threadwise_gemm.hpp
@@ -114,7 +114,7 @@ struct ThreadwiseGemmTransANormalBNormalC
             const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
             const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
 
-            __outer_product_1x2(
+            amd_assembly_outer_product_1x2(
                 p_a[aindex], p_b[bindex_0], p_b[bindex_1], p_c[cindex_0], p_c[cindex_1]);
         });
@@ -129,15 +129,15 @@ struct ThreadwiseGemmTransANormalBNormalC
             const index_t cindex_2 = MatrixC::CalculateOffset(m, 2);
             const index_t cindex_3 = MatrixC::CalculateOffset(m, 3);
 
-            __outer_product_1x4(p_a[aindex],
-                                p_b[bindex_0],
-                                p_b[bindex_1],
-                                p_b[bindex_2],
-                                p_b[bindex_3],
-                                p_c[cindex_0],
-                                p_c[cindex_1],
-                                p_c[cindex_2],
-                                p_c[cindex_3]);
+            amd_assembly_outer_product_1x4(p_a[aindex],
+                                           p_b[bindex_0],
+                                           p_b[bindex_1],
+                                           p_b[bindex_2],
+                                           p_b[bindex_3],
+                                           p_c[cindex_0],
+                                           p_c[cindex_1],
+                                           p_c[cindex_2],
+                                           p_c[cindex_3]);
         });
     }
 }
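Note: the rename from __outer_product_* to amd_assembly_outer_product_* is mechanical (likely because leading-double-underscore identifiers are reserved to the implementation in C++); the computation is unchanged. As a reference for what these helpers compute, a minimal scalar sketch (illustration only, not part of the patch; the real helpers keep the accumulation in v_mac_f32 / v_dot2_f32_f16 instructions):

    // Reference semantics of amd_assembly_outer_product_1x4: one row of the
    // rank-1 update C[i][j] += A[i] * B[j], with a = A[i] held in a register.
    inline void outer_product_1x4_reference(
        float a, float b0, float b1, float b2, float b3,
        float& c0, float& c1, float& c2, float& c3)
    {
        c0 += a * b0;
        c1 += a * b1;
        c2 += a * b2;
        c3 += a * b3;
    }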
diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
index 1e3095d72e..6a98c7836f 100644
--- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
+++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
@@ -123,7 +123,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         static_if{}([&](auto fwd) {
 #if CK_USE_AMD_BUFFER_ADDRESSING
             *reinterpret_cast(&p_src_long_vector[buffer_offset]) =
-                __buffer_load(
+                amd_intrinsic_buffer_load(
                     fwd(p_src), src_coord.GetOffset(), 0);
 #else
             *reinterpret_cast(&p_src_long_vector[buffer_offset]) =
@@ -162,7 +162,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         {
             static_if{}([&](auto fwd) {
 #if CK_USE_AMD_BUFFER_ADDRESSING
-                __buffer_store(
+                amd_intrinsic_buffer_store(
                     *reinterpret_cast(&p_dst_long_vector[buffer_offset]),
                     fwd(p_dst),
                     dst_coord.GetOffset(),
@@ -311,7 +311,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         static_if{}([&](auto) {
 #if CK_USE_AMD_BUFFER_ADDRESSING
             *reinterpret_cast(&p_src_long_vector[buffer_offset]) =
-                __buffer_load(
+                amd_intrinsic_buffer_load(
                     p_src, src_nonlinear_coord.GetOffset(), src_linear_offset);
 #else
             *reinterpret_cast(&p_src_long_vector[buffer_offset]) =
@@ -503,7 +503,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         {
             static_if{}([&](auto) {
 #if CK_USE_AMD_BUFFER_ADDRESSING
-                __buffer_store(
+                amd_intrinsic_buffer_store(
                     *reinterpret_cast(&p_dst_long_vector[buffer_offset]),
                     p_dst,
                     dst_nonlinear_coord.GetOffset(),
diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp
index f28ac1892c..f28ef935b1 100644
--- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp
+++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp
@@ -335,7 +335,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
             // 3. src_merged_offset can be runtime value (no assumption imposed)
             static_if{}([&](auto fwd) {
 #if CK_USE_AMD_BUFFER_ADDRESSING
-                vector_data = __buffer_load(
+                vector_data = amd_intrinsic_buffer_load(
                     fwd(p_src), src_merged_offset, src_normal_offset);
 #else
                 vector_data = *reinterpret_cast(
@@ -375,7 +375,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
         // copy data from buffer into dst
         {
-            using dst_vector_t = typename vector_type::MemoryType;
+            using dst_vector_t = typename vector_type::MemoryType;
 
             constexpr auto dst_vector_access_dim = Number{};
 
             constexpr auto dst_data_per_access = Number{};
@@ -420,7 +420,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
                     const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
                         dst_merged_dim_data_id + dst_normal_dim_data_id + scalar_id);
 
-                    reinterpret_cast(&vector_data)[i] = p_dst_buffer[buffer_offset];
+                    reinterpret_cast(&vector_data)[i] = p_dst_buffer[buffer_offset];
                 }
 
                 // offset w.r.t. normal dimension is known at compile-time
@@ -444,7 +444,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
             // 3. dst_merged_offset can be runtime value (no assumption imposed)
             static_if{}([&](auto fwd) {
 #if CK_USE_AMD_BUFFER_ADDRESSING
-                __buffer_store(
+                amd_intrinsic_buffer_store(
                     vector_data, fwd(p_dst), dst_merged_offset, dst_normal_offset);
 #else
                 *reinterpret_cast(
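Both copy kernels keep the same dispatch shape around the renamed wrappers: with CK_USE_AMD_BUFFER_ADDRESSING enabled, global-memory traffic goes through the buffer intrinsics, otherwise through an ordinary vector-typed pointer. A condensed sketch of that load path for a four-float access (hypothetical helper, simplified away from the static_if/tensor-coordinate machinery above):

    // Simplified model of the v4r2 source-load path (illustration only).
    // thread_offset is an element offset, as in the real call sites.
    __device__ ck::float4_t load_src_float4(const float* p_src, ck::index_t thread_offset)
    {
    #if CK_USE_AMD_BUFFER_ADDRESSING
        // global memory: buffer-addressing wrapper, constant offset 0
        return ck::amd_intrinsic_buffer_load<float, 4>(p_src, thread_offset, 0);
    #else
        // fallback: plain vector-typed load
        return *reinterpret_cast<const ck::float4_t*>(p_src + thread_offset);
    #endif
    }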
diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/composable_kernel/include/utility/amd_buffer_addressing.hpp
index 4bb6f26935..27b72206de 100644
--- a/composable_kernel/include/utility/amd_buffer_addressing.hpp
+++ b/composable_kernel/include/utility/amd_buffer_addressing.hpp
@@ -19,55 +19,56 @@
 __device__ float __llvm_amdgcn_buffer_load(int32x4_t rsrc,
                                            index_t vindex,
                                            index_t offset,
                                            bool glc,
-                                           bool slc) __asm("llvm.amdgcn.buffer.load");
+                                           bool slc) __asm("llvm.amdgcn.buffer.load.f32");
 
 __device__ float2_t __llvm_amdgcn_buffer_loadx2(int32x4_t rsrc,
                                                 index_t vindex,
                                                 index_t offset,
                                                 bool glc,
-                                                bool slc) __asm("llvm.amdgcn.buffer.load.dwordx2");
+                                                bool slc) __asm("llvm.amdgcn.buffer.load.v2f32");
 
 __device__ float4_t __llvm_amdgcn_buffer_loadx4(int32x4_t rsrc,
                                                 index_t vindex,
                                                 index_t offset,
                                                 bool glc,
-                                                bool slc) __asm("llvm.amdgcn.buffer.load.dwordx4");
+                                                bool slc) __asm("llvm.amdgcn.buffer.load.v4f32");
 
 __device__ void __llvm_amdgcn_buffer_store(float vdata,
                                            int32x4_t rsrc,
                                            index_t vindex,
                                            index_t offset,
                                            bool glc,
-                                           bool slc) __asm("llvm.amdgcn.buffer.store");
+                                           bool slc) __asm("llvm.amdgcn.buffer.store.f32");
 
 __device__ void __llvm_amdgcn_buffer_storex2(float2_t vdata,
                                              int32x4_t rsrc,
                                              index_t vindex,
                                              index_t offset,
                                              bool glc,
-                                             bool slc) __asm("llvm.amdgcn.buffer.store.dwordx2");
+                                             bool slc) __asm("llvm.amdgcn.buffer.store.v2f32");
 
 __device__ void __llvm_amdgcn_buffer_storex4(float4_t vdata,
                                              int32x4_t rsrc,
                                              index_t vindex,
                                              index_t offset,
                                              bool glc,
-                                             bool slc) __asm("llvm.amdgcn.buffer.store.dwordx4");
+                                             bool slc) __asm("llvm.amdgcn.buffer.store.v4f32");
 
 template <typename T, index_t VectorSize>
-__device__ typename vector_type::MemoryType
-__buffer_load(const T* p_src_block, index_t src_thread_data_offset, index_t src_const_data_offset);
+__device__ typename vector_type::MemoryType amd_intrinsic_buffer_load(
+    const T* p_src_block, index_t src_thread_data_offset, index_t src_const_data_offset);
 
 template <typename T, index_t VectorSize>
-__device__ void __buffer_store(const typename vector_type::MemoryType& src,
-                               T* p_dst_block,
-                               index_t dst_thread_data_offset,
-                               index_t dst_const_data_offset);
+__device__ void
+amd_intrinsic_buffer_store(const typename vector_type::MemoryType& src,
+                           T* p_dst_block,
+                           index_t dst_thread_data_offset,
+                           index_t dst_const_data_offset);
 
 template <>
-__device__ float __buffer_load(const float* p_src_block,
-                               index_t src_thread_data_offset,
-                               index_t src_const_data_offset)
+__device__ float amd_intrinsic_buffer_load<float, 1>(const float* p_src_block,
+                                                     index_t src_thread_data_offset,
+                                                     index_t src_const_data_offset)
 {
     float dst;
@@ -100,9 +101,9 @@ __device__ float __buffer_load(const float* p_src_block,
 }
 
 template <>
-__device__ float2_t __buffer_load(const float* p_src_block,
-                                  index_t src_thread_data_offset,
-                                  index_t src_const_data_offset)
+__device__ float2_t amd_intrinsic_buffer_load<float, 2>(const float* p_src_block,
+                                                        index_t src_thread_data_offset,
+                                                        index_t src_const_data_offset)
 {
     float2_t dst;
@@ -135,9 +136,9 @@ __device__ float2_t __buffer_load(const float* p_src_block,
 }
 
 template <>
-__device__ float4_t __buffer_load(const float* p_src_block,
-                                  index_t src_thread_data_offset,
-                                  index_t src_const_data_offset)
+__device__ float4_t amd_intrinsic_buffer_load<float, 4>(const float* p_src_block,
+                                                        index_t src_thread_data_offset,
+                                                        index_t src_const_data_offset)
 {
     float4_t dst;
@@ -170,10 +171,10 @@ __device__ float4_t __buffer_load(const float* p_src_block,
 }
 
 template <>
-__device__ void __buffer_store(const float& src,
-                               float* p_dst_block,
-                               index_t dst_thread_data_offset,
-                               index_t dst_const_data_offset)
+__device__ void amd_intrinsic_buffer_store<float, 1>(const float& src,
+                                                     float* p_dst_block,
+                                                     index_t dst_thread_data_offset,
+                                                     index_t dst_const_data_offset)
 {
     index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
     index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
@@ -207,10 +208,10 @@ __device__ void __buffer_store(const float& src,
 }
 
 template <>
-__device__ void __buffer_store(const float2_t& src,
-                               float* p_dst_block,
-                               index_t dst_thread_data_offset,
-                               index_t dst_const_data_offset)
+__device__ void amd_intrinsic_buffer_store<float, 2>(const float2_t& src,
+                                                     float* p_dst_block,
+                                                     index_t dst_thread_data_offset,
+                                                     index_t dst_const_data_offset)
 {
     index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
     index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
@@ -244,10 +245,10 @@ __device__ void __buffer_store(const float2_t& src,
 }
 
 template <>
-__device__ void __buffer_store(const float4_t& src,
-                               float* p_dst_block,
-                               index_t dst_thread_data_offset,
-                               index_t dst_const_data_offset)
+__device__ void amd_intrinsic_buffer_store<float, 4>(const float4_t& src,
+                                                     float* p_dst_block,
+                                                     index_t dst_thread_data_offset,
+                                                     index_t dst_const_data_offset)
 {
     index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
     index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
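The substantive fix in this file is in the __asm names: llvm.amdgcn.buffer.load/store are overloaded LLVM intrinsics, so each declaration must carry the suffix of its value type (.f32, .v2f32, .v4f32). The old .dwordx2/.dwordx4 spellings follow the ISA mnemonics (buffer_load_dwordx2, etc.) rather than LLVM's overload naming, and the bare llvm.amdgcn.buffer.load lacked a suffix entirely. A usage sketch for the typed wrappers (hypothetical kernel, offsets given in elements as the wrappers expect):

    // Hypothetical kernel: copy four floats per thread via the renamed wrappers.
    __global__ void copy_x4(const float* p_src, float* p_dst)
    {
        const ck::index_t tid = blockIdx.x * blockDim.x + threadIdx.x;

        // loads through llvm.amdgcn.buffer.load.v4f32 ...
        ck::float4_t v = ck::amd_intrinsic_buffer_load<float, 4>(p_src, 4 * tid, 0);

        // ... and stores back through llvm.amdgcn.buffer.store.v4f32
        ck::amd_intrinsic_buffer_store<float, 4>(v, p_dst, 4 * tid, 0);
    }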
diff --git a/composable_kernel/include/utility/amd_inline_asm.hpp b/composable_kernel/include/utility/amd_inline_asm.hpp
index 28eaf1f443..7be6b9fe46 100644
--- a/composable_kernel/include/utility/amd_inline_asm.hpp
+++ b/composable_kernel/include/utility/amd_inline_asm.hpp
@@ -6,7 +6,7 @@ namespace ck {
 
 // outer-product: c[i,j] += inner_product(a[i], b[j])
-__device__ void __outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
+__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
 {
     // disable inline asm due to the compiler issue: SWDEV-202749
     ///\to-do: enable the inline asm after the compiler fix
@@ -24,7 +24,7 @@ __device__ void __outer_product_1x2(float a, float b0, float b1, float& c0, floa
 }
 
 // outer-product: c[i,j] += inner_product(a[i], b[j])
-__device__ void __outer_product_1x4(
+__device__ void amd_assembly_outer_product_1x4(
     float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
 {
     asm volatile("\n \
@@ -38,11 +38,12 @@ __device__ void __outer_product_1x4(
 }
 
 // outer-product: c[i,j] += inner_product(a[i], b[j])
-__device__ void __outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
+__device__ void
+amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
 {
     asm volatile("\n \
-                 v_dot2_f32_f16 %0, %2, %3 %0\n \
-                 v_dot2_f32_f16 %1, %2, %4 %1\n \
+                 v_dot2_f32_f16 %0, %2, %3, %0\n \
+                 v_dot2_f32_f16 %1, %2, %4, %1\n \
                  "
                  : "=v"(c0), "=v"(c1) // Dest registers
                  : "v"(a),            // 1st Src register for 1 half2 registers
@@ -53,7 +54,8 @@ __device__ void __outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0
 }
 
 // outer-product: c[i,j] += inner_product(a[i], b[j])
-__device__ void __outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
+__device__ void
+amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
 {
     const half2_t* p_a_half2  = reinterpret_cast(&a);
     const half2_t* p_b0_half2 = reinterpret_cast(&b0);
@@ -61,10 +63,10 @@ __device__ void __outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0
 
     // do dot2 two times
     asm volatile("\n \
-                 v_dot2_f32_f16 %0, %2, %4 %0\n \
-                 v_dot2_f32_f16 %1, %2, %6 %1\n \
-                 v_dot2_f32_f16 %0, %3, %5 %0\n \
-                 v_dot2_f32_f16 %1, %3, %7 %1\n \
+                 v_dot2_f32_f16 %0, %2, %4, %0\n \
+                 v_dot2_f32_f16 %1, %2, %6, %1\n \
+                 v_dot2_f32_f16 %0, %3, %5, %0\n \
+                 v_dot2_f32_f16 %1, %3, %7, %1\n \
                  "
                  : "=v"(c0), "=v"(c1) // Dest registers
                  : "v"(p_a_half2[0]),
@@ -78,21 +80,21 @@ __device__ void __outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0
 }
 
 // outer-product: c[i,j] += inner_product(a[i], b[j])
-__device__ void __outer_product_1x4(half2_t a,
-                                    half2_t b0,
-                                    half2_t b1,
-                                    half2_t b2,
-                                    half2_t b3,
-                                    float& c0,
-                                    float& c1,
-                                    float& c2,
-                                    float& c3)
+__device__ void amd_assembly_outer_product_1x4(half2_t a,
+                                               half2_t b0,
+                                               half2_t b1,
+                                               half2_t b2,
+                                               half2_t b3,
+                                               float& c0,
+                                               float& c1,
+                                               float& c2,
+                                               float& c3)
 {
     asm volatile("\n \
-                 v_dot2_f32_f16 %0, %4, %5 %0\n \
-                 v_dot2_f32_f16 %1, %4, %6 %1\n \
-                 v_dot2_f32_f16 %2, %4, %7 %2\n \
-                 v_dot2_f32_f16 %3, %4, %8 %3\n \
+                 v_dot2_f32_f16 %0, %4, %5, %0\n \
+                 v_dot2_f32_f16 %1, %4, %6, %1\n \
+                 v_dot2_f32_f16 %2, %4, %7, %2\n \
+                 v_dot2_f32_f16 %3, %4, %8, %3\n \
                  "
                  : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers
                  : "v"(a), // 1st Src register for 1 half2 registers
@@ -107,15 +109,15 @@ __device__ void __outer_product_1x4(half2_t a,
 }
 
 // outer-product: c[i,j] += inner_product(a[i], b[j])
-__device__ void __outer_product_1x4(half4_t a,
-                                    half4_t b0,
-                                    half4_t b1,
-                                    half4_t b2,
-                                    half4_t b3,
-                                    float& c0,
-                                    float& c1,
-                                    float& c2,
-                                    float& c3)
+__device__ void amd_assembly_outer_product_1x4(half4_t a,
+                                               half4_t b0,
+                                               half4_t b1,
+                                               half4_t b2,
+                                               half4_t b3,
+                                               float& c0,
+                                               float& c1,
+                                               float& c2,
+                                               float& c3)
 {
     const half2_t* p_a_half2  = reinterpret_cast(&a);
     const half2_t* p_b0_half2 = reinterpret_cast(&b0);
@@ -125,14 +127,14 @@ __device__ void __outer_product_1x4(half4_t a,
 
     // do dot2 two times
     asm volatile("\n \
-                 v_dot2_f32_f16 %0, %4, %6 %0\n \
-                 v_dot2_f32_f16 %1, %4, %8 %1\n \
-                 v_dot2_f32_f16 %2, %4, %10 %2\n \
-                 v_dot2_f32_f16 %3, %4, %12 %3\n \
-                 v_dot2_f32_f16 %0, %5, %7 %0\n \
-                 v_dot2_f32_f16 %1, %5, %9 %1\n \
-                 v_dot2_f32_f16 %2, %5, %11 %2\n \
-                 v_dot2_f32_f16 %3, %5, %13 %3\n \
+                 v_dot2_f32_f16 %0, %4, %6, %0\n \
+                 v_dot2_f32_f16 %1, %4, %8, %1\n \
+                 v_dot2_f32_f16 %2, %4, %10, %2\n \
+                 v_dot2_f32_f16 %3, %4, %12, %3\n \
+                 v_dot2_f32_f16 %0, %5, %7, %0\n \
+                 v_dot2_f32_f16 %1, %5, %9, %1\n \
+                 v_dot2_f32_f16 %2, %5, %11, %2\n \
+                 v_dot2_f32_f16 %3, %5, %13, %3\n \
                  "
                  : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers
                  : "v"(p_a_half2[0]),
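The asm fix in this file is syntactic: v_dot2_f32_f16 takes three source operands, D = S0.x*S1.x + S0.y*S1.y + S2, and the old strings were missing the comma before the trailing accumulator operand. For reference, the semantics of one such instruction modeled in plain C++ (scalar model only, not part of the patch):

    // Reference model of "v_dot2_f32_f16 dst, a, b, c": dst = a.x*b.x + a.y*b.y + c
    inline float dot2_f32_f16_reference(ck::half2_t a, ck::half2_t b, float c)
    {
        const _Float16* pa = reinterpret_cast<const _Float16*>(&a);
        const _Float16* pb = reinterpret_cast<const _Float16*>(&b);

        return static_cast<float>(pa[0]) * static_cast<float>(pb[0]) +
               static_cast<float>(pa[1]) * static_cast<float>(pb[1]) + c;
    }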
diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp
index 588efca083..00964e0fe1 100644
--- a/composable_kernel/include/utility/common_header.hpp
+++ b/composable_kernel/include/utility/common_header.hpp
@@ -24,4 +24,8 @@
 #include "amd_buffer_addressing.hpp"
 #endif
 
+#if CK_USE_AMD_XDLOPS
+#include "amd_xdlops.hpp"
+#endif
+
 #endif
diff --git a/composable_kernel/include/utility/config.amd.hpp.in b/composable_kernel/include/utility/config.amd.hpp.in
index 3e19b56769..896679c856 100644
--- a/composable_kernel/include/utility/config.amd.hpp.in
+++ b/composable_kernel/include/utility/config.amd.hpp.in
@@ -31,11 +31,11 @@
 
 // AMD XDLOPS
 #ifndef CK_USE_AMD_XDLOPS
-#define CK_USE_AMD_XDLOPS 1
+#define CK_USE_AMD_XDLOPS 0
 #endif
 
 #ifndef CK_USE_AMD_XDLOPS_INLINE_ASM
-#define CK_USE_AMD_XDLOPS_INLINE_ASM 1
+#define CK_USE_AMD_XDLOPS_INLINE_ASM 0
 #endif
 
 // experimental implementation
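Flipping the defaults to 0 compiles the XDLOPS paths out unless a build opts back in; because the macros are only given default values under #ifndef guards, a definition on the compiler command line (e.g. hipcc -DCK_USE_AMD_XDLOPS=1, build-system specifics vary) still re-enables them, at which point common_header.hpp pulls in amd_xdlops.hpp as added above:

    // Sketch of how the two hunks above interact (illustration only).
    #ifndef CK_USE_AMD_XDLOPS
    #define CK_USE_AMD_XDLOPS 0 // new default: off, unless -DCK_USE_AMD_XDLOPS=1
    #endif

    #if CK_USE_AMD_XDLOPS
    #include "amd_xdlops.hpp" // only reached when the flag is forced on
    #endif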
diff --git a/composable_kernel/include/utility/float_type.hpp b/composable_kernel/include/utility/float_type.hpp
new file mode 100644
index 0000000000..fd9c0029bc
--- /dev/null
+++ b/composable_kernel/include/utility/float_type.hpp
@@ -0,0 +1,311 @@
+#ifndef CK_FLOAT_TYPE_AMD_HPP
+#define CK_FLOAT_TYPE_AMD_HPP
+
+namespace ck {
+
+// For some reason, the HIP compiler needs this definition to generate optimal ISA
+// float
+typedef float float2_t __attribute__((ext_vector_type(2)));
+typedef float float4_t __attribute__((ext_vector_type(4)));
+typedef float float16_t __attribute__((ext_vector_type(16)));
+typedef float float32_t __attribute__((ext_vector_type(32)));
+
+// float16
+typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
+typedef _Float16 half4_t __attribute__((ext_vector_type(4)));
+
+// bfloat16
+typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
+typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
+
+template <typename T, index_t N>
+struct vector_type
+{
+    typedef struct
+    {
+        T scalar[N];
+    } MemoryType;
+};
+
+template <>
+struct vector_type<float, 1>
+{
+    using MemoryType = float;
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
+    {
+        static_assert(I < 1, "wrong");
+        *(reinterpret_cast<float*>(&v) + I) = s;
+    }
+};
+
+template <>
+struct vector_type<float, 2>
+{
+    using MemoryType = float2_t;
+
+    union DataType
+    {
+        MemoryType vector;
+        float scalar[2];
+    };
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
+    {
+        static_assert(I < 2, "wrong");
+        *(reinterpret_cast<float*>(&v) + I) = s;
+    }
+
+    __host__ __device__ static MemoryType Pack(float s0, float s1)
+    {
+        DataType data;
+        data.scalar[0] = s0;
+        data.scalar[1] = s1;
+        return data.vector;
+    }
+};
+
+template <>
+struct vector_type<float, 4>
+{
+    using MemoryType = float4_t;
+
+    __host__ __device__ static constexpr index_t GetSize() { return 4; }
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
+    {
+        static_assert(I < 4, "wrong");
+        *(reinterpret_cast<float*>(&v) + I) = s;
+    }
+};
+
+template <>
+struct vector_type<half, 1>
+{
+    using MemoryType = half;
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
+    {
+        static_assert(I < 1, "wrong");
+        *(reinterpret_cast<half*>(&v) + I) = s;
+    }
+};
+
+template <>
+struct vector_type<half, 2>
+{
+    using MemoryType = half2_t;
+
+    union DataType
+    {
+        MemoryType vector;
+        half scalar[2];
+    };
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
+    {
+        static_assert(I < 2, "wrong");
+        *(reinterpret_cast<half*>(&v) + I) = s;
+    }
+
+    __host__ __device__ static MemoryType Pack(half s0, half s1)
+    {
+        DataType data;
+        data.scalar[0] = s0;
+        data.scalar[1] = s1;
+        return data.vector;
+    }
+};
+
+template <>
+struct vector_type<half, 4>
+{
+    using MemoryType = half4_t;
+
+    union DataType
+    {
+        MemoryType vector;
+        half scalar[4];
+    };
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
+    {
+        static_assert(I < 4, "wrong");
+        *(reinterpret_cast<half*>(&v) + I) = s;
+    }
+
+    __host__ __device__ static MemoryType Pack(half s0, half s1, half s2, half s3)
+    {
+        DataType data;
+        data.scalar[0] = s0;
+        data.scalar[1] = s1;
+        data.scalar[2] = s2;
+        data.scalar[3] = s3;
+        return data.vector;
+    }
+};
+
+template <>
+struct vector_type<ushort, 1>
+{
+    using MemoryType = ushort;
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
+    {
+        static_assert(I < 1, "wrong");
+        *(reinterpret_cast<ushort*>(&v) + I) = s;
+    }
+};
+
+template <>
+struct vector_type<ushort, 2>
+{
+    using MemoryType = ushort2_t;
+
+    union DataType
+    {
+        MemoryType vector;
+        ushort scalar[2];
+    };
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
+    {
+        static_assert(I < 2, "wrong");
+        *(reinterpret_cast<ushort*>(&v) + I) = s;
+    }
+
+    __host__ __device__ static MemoryType Pack(ushort s0, ushort s1)
+    {
+        DataType data;
+        data.scalar[0] = s0;
+        data.scalar[1] = s1;
+        return data.vector;
+    }
+};
+
+template <>
+struct vector_type<ushort, 4>
+{
+    using MemoryType = ushort4_t;
+
+    union DataType
+    {
+        MemoryType vector;
+        ushort scalar[4];
+    };
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
+    {
+        static_assert(I < 4, "wrong");
+        *(reinterpret_cast<ushort*>(&v) + I) = s;
+    }
+
+    __host__ __device__ static MemoryType Pack(ushort s0, ushort s1, ushort s2, ushort s3)
+    {
+        DataType data;
+        data.scalar[0] = s0;
+        data.scalar[1] = s1;
+        data.scalar[2] = s2;
+        data.scalar[3] = s3;
+        return data.vector;
+    }
+};
+
+// data type conversion
+template <typename T>
+struct type_convert
+{
+    template <typename X>
+    __device__ T operator()(X x) const
+    {
+        return static_cast<T>(x);
+    }
+};
+
+template <>
+template <>
+__device__ float type_convert<float>::operator()(ushort x) const
+{
+    return bfloat16_to_float(x);
+}
+
+template <>
+template <>
+__device__ ushort type_convert<ushort>::operator()(float x) const
+{
+    return float_to_bfloat16(x);
+}
+
+template <typename T>
+struct inner_product_with_conversion
+{
+    static constexpr auto convert = type_convert<T>();
+
+    __device__ T operator()(float a, float b) const { return convert(a) * convert(b); }
+
+    __device__ T operator()(half2_t a, half2_t b) const
+    {
+        const half* p_a_half = reinterpret_cast<const half*>(&a);
+        const half* p_b_half = reinterpret_cast<const half*>(&b);
+
+        T acc = 0;
+        for(index_t v = 0; v < 2; ++v)
+        {
+            acc += convert(p_a_half[v]) * convert(p_b_half[v]);
+        }
+
+        return acc;
+    }
+
+    __device__ T operator()(half4_t a, half4_t b) const
+    {
+        const half* p_a_half = reinterpret_cast<const half*>(&a);
+        const half* p_b_half = reinterpret_cast<const half*>(&b);
+
+        T acc = 0;
+        for(index_t v = 0; v < 4; ++v)
+        {
+            acc += convert(p_a_half[v]) * convert(p_b_half[v]);
+        }
+        return acc;
+    }
+
+    __device__ T operator()(ushort2_t a, ushort2_t b) const
+    {
+        const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
+        const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
+
+        T acc = 0;
+        for(index_t v = 0; v < 2; ++v)
+        {
+            acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
+        }
+
+        return acc;
+    }
+
+    __device__ T operator()(ushort4_t a, ushort4_t b) const
+    {
+        const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
+        const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
+
+        T acc = 0;
+        for(index_t v = 0; v < 4; ++v)
+        {
+            acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
+        }
+        return acc;
+    }
+};
+
+} // namespace ck
+#endif
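The new float_type.hpp centralizes the vector typedefs and per-type helpers (Pack, SetScalar, type_convert, inner_product_with_conversion) that the kernels above rely on. A short device-side usage sketch (illustration only, not part of the patch; assumes ck::Number from the existing utility headers and the bfloat16 conversion routines the header itself depends on):

    // Illustrative use of the new helpers.
    using namespace ck;

    __device__ float float_type_demo()
    {
        // pack two scalars into a float2_t via the union-based Pack
        float2_t v = vector_type<float, 2>::Pack(1.0f, 2.0f);

        // overwrite lane 1; the lane index is validated at compile time
        vector_type<float, 2>::SetScalar(v, 3.0f, Number<1>{});

        // half2_t inner product, each element converted to float first
        half2_t a = {_Float16(1), _Float16(2)};
        half2_t b = {_Float16(3), _Float16(4)};
        float dot = inner_product_with_conversion<float>{}(a, b); // 1*3 + 2*4 = 11

        return v.x + v.y + dot; // 1 + 3 + 11
    }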
diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp
index dccad8a5ee..720d592006 100644
--- a/driver/src/driver.cpp
+++ b/driver/src/driver.cpp
@@ -297,7 +297,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
 
-#elif 1
+#elif 0
     // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
     // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
     constexpr index_t N = 128;
@@ -343,7 +343,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<3, 0>;
    using RightPads = Sequence<3, 0>;
 
-#elif 0
+#elif 1
     // 1x7 filter, 0x3 pad, 17x17 input
     constexpr index_t N = 128;
     constexpr index_t C = 128;
@@ -450,7 +450,7 @@ int main(int argc, char* argv[])
                                                          ConvStrides{},
                                                          ConvDilations{},
                                                          nrepeat);
-#elif 0
+#elif 1
     device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
                                                          in_nchw,
                                                          wei_kcyx_desc,
@@ -492,7 +492,7 @@
                                                          ConvStrides{},
                                                          ConvDilations{},
                                                          nrepeat);
-#elif 1
+#elif 0
     device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
                                                          in_nchw,
                                                          wei_kcyx_desc,