diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp index 5bc004427c..68b4db1a43 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -12,18 +12,19 @@ enum struct MfmaInstr mfma_f32_32x32x1xf32 = 0, mfma_f32_16x16x1xf32, mfma_f32_4x4x1xf32, - mfma_f32_32x32x2xf32, // k reduction - mfma_f32_16x16x4xf32, // k reduction + mfma_f32_32x32x2xf32, + mfma_f32_16x16x4xf32, mfma_f32_32x32x4f16, mfma_f32_16x16x4f16, mfma_f32_4x4x4f16, - mfma_f32_32x32x8f16, // k reduction - mfma_f32_16x16x16f16, // k reduction - mfma_f32_32x32x2bf16, - mfma_f32_16x16x2bf16, - mfma_f32_4x4x2bf16, - mfma_f32_32x32x4bf16, // k reduction - mfma_f32_16x16x8bf16, // k reduction + mfma_f32_32x32x8f16, + mfma_f32_16x16x16f16, + mfma_f32_32x32x8bf16_1k, + mfma_f32_16x16x16bf16_1k, + mfma_f32_32x32x4bf16, + mfma_f32_16x16x8bf16, + mfma_i32_32x32x8i8, + mfma_i32_16x16x16i8, }; template @@ -250,9 +251,8 @@ struct mfma_type } }; -#if 0 template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 4; @@ -260,26 +260,38 @@ struct mfma_type static constexpr index_t num_threads_per_blk = 32; static constexpr index_t wave_size = 64; static constexpr index_t num_input_blks = 2; - static constexpr index_t num_output_blks = 2; + static constexpr index_t num_output_blks = 1; static constexpr index_t m_per_blk = 32; static constexpr index_t n_per_blk = 32; - static constexpr index_t k_per_blk = 2; - static constexpr bool is_k_reduction = false; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - const auto p_a = c_style_pointer_cast(a); - const auto p_b = 
c_style_pointer_cast(b); + intrin_mfma_f32_32x32x8bf16_1k::Run(a, b, reg_c); + } +}; - return intrin_mfma_f32_32x32x2bf16::run( - p_a, p_b, reg_c); +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x16bf16_1k::Run(a, b, reg_c); } }; @@ -298,19 +310,10 @@ struct mfma_type static constexpr index_t k_per_blk = 2; static constexpr bool is_k_reduction = true; - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - const auto p_a = c_style_pointer_cast(a); - const auto p_b = c_style_pointer_cast(b); - - return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c); + intrin_mfma_f32_32x32x4bf16::Run(a, b, reg_c); } }; @@ -329,24 +332,37 @@ struct mfma_type static constexpr index_t k_per_blk = 2; static constexpr bool is_k_reduction = true; - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - const auto p_a = c_style_pointer_cast(a); - const auto p_b = c_style_pointer_cast(b); - - return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c); + intrin_mfma_f32_16x16x8bf16::Run(a, b, reg_c); } }; template <> -struct mfma_type +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + 
static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_32x32x8i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 1; @@ -354,60 +370,19 @@ struct mfma_type static constexpr index_t num_threads_per_blk = 16; static constexpr index_t wave_size = 64; static constexpr index_t num_input_blks = 4; - static constexpr index_t num_output_blks = 4; + static constexpr index_t num_output_blks = 1; static constexpr index_t m_per_blk = 16; static constexpr index_t n_per_blk = 16; - static constexpr index_t k_per_blk = 2; - static constexpr bool is_k_reduction = false; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - const auto p_a = c_style_pointer_cast(a); - const auto p_b = c_style_pointer_cast(b); - - return intrin_mfma_f32_16x16x2bf16(p_a, p_b, reg_c); + intrin_mfma_i32_16x16x16i8::Run(a, b, reg_c); } }; -template <> -struct mfma_type -{ - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_per_blk = 1; - static constexpr index_t num_regs_per_blk = 4; - static constexpr index_t num_threads_per_blk = 64; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; - static constexpr index_t num_output_blks = 1; - static constexpr index_t 
m_per_blk = 4; - static constexpr index_t n_per_blk = 64; - static constexpr index_t k_per_blk = 2; - static constexpr bool is_k_reduction = false; - - template - __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const - { - const auto p_a = c_style_pointer_cast(a); - const auto p_b = c_style_pointer_cast(b); - - return intrin_mfma_f32_4x4x2bf16::run(p_a, p_b, reg_c); - } -}; -#endif - template struct MfmaSelector { @@ -498,73 +473,37 @@ struct MfmaSelector return MfmaInstr::mfma_f32_4x4x4f16; } -#if 0 - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetMfma() - { - return xdlops_info{}; - } - template <> static constexpr auto GetMfma() { - return xdlops_info{}; +#if defined(CK_AMD_GPU_GFX90A) + return MfmaInstr::mfma_f32_32x32x8bf16_1k; +#else + return MfmaInstr::mfma_f32_32x32x4bf16; +#endif } template <> static constexpr auto GetMfma() { - return xdlops_info{}; - } +#if defined(CK_AMD_GPU_GFX90A) + return MfmaInstr::mfma_f32_16x16x16bf16_1k; +#else + return MfmaInstr::mfma_f32_16x16x8bf16; #endif + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_32x32x8i8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_16x16x16i8; + } static constexpr auto selected_mfma = mfma_type()>{}; @@ -686,8 +625,8 @@ struct XdlopsGemm __device__ void Run(const 
FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { static_assert(is_same::value || is_same::value || - is_same::value, - "base base_type must be float, half, ushort!"); + is_same::value || is_same::value, + "base base_type must be float, half, ushort, and int8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { mfma_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/composable_kernel/include/utility/amd_buffer_addressing.hpp index 3df53bda44..c481df180b 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing.hpp +++ b/composable_kernel/include/utility/amd_buffer_addressing.hpp @@ -50,11 +50,24 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8"); -__device__ int16_t +__device__ ushort llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, index_t voffset, index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16"); + +__device__ ushort2_t +llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16"); + +__device__ ushort4_t +llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16"); + __device__ int32_t llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, index_t voffset, @@ -133,12 +146,26 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8"); __device__ void -llvm_amdgcn_raw_buffer_store_i16(int16_t vdata, +llvm_amdgcn_raw_buffer_store_i16(ushort vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); +__device__ void +llvm_amdgcn_raw_buffer_store_i16x2(ushort2_t vdata, + int32x4_t rsrc, + 
index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i16x4(ushort4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16"); + __device__ void llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, int32x4_t rsrc, @@ -228,6 +255,7 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w (is_same::value && (N == 1 || N == 2 || N == 4)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); @@ -326,6 +354,31 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w return as_type(tmp); } } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i16( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_i16x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_i16x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 8) + { + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); + } + } else if constexpr(is_same::value) { if constexpr(N == 1) @@ -458,6 +511,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src (is_same::value && (N == 1 || N == 2)) || (is_same::value && (N == 1 || N == 2 || N == 4)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) 
|| + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); @@ -552,6 +606,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src 0); } } + else if constexpr(is_same::value) /* 16-bit integer element path: N = 1/2/4 map onto the raw_buffer_store_i16, _i16x2 and _i16x4 intrinsics; N = 8 is written as two 4-element halves */ + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i16(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(half_t), + 0); + } + } else if constexpr(is_same::value) { if constexpr(N == 1) diff --git a/composable_kernel/include/utility/amd_xdlops.hpp b/composable_kernel/include/utility/amd_xdlops.hpp index 083e47fbf1..a87c42ddd7 100644 --- a/composable_kernel/include/utility/amd_xdlops.hpp +++ b/composable_kernel/include/utility/amd_xdlops.hpp @@ -6,6 +6,7 @@ namespace ck { // A, B, C, cbsz, abid, blgp +// fp32 extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32( float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32"); @@ -21,6 +22,7 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32( extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32( float, float, float4_t,
int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32"); +// fp16 extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16( half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16"); @@ -36,6 +38,13 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16( extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16( half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16"); +// bfp16 +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k( + ushort4_t, ushort4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8bf16.1k"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k( + ushort4_t, ushort4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16bf16.1k"); + extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16( ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16"); @@ -51,6 +60,23 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16( extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16( ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16"); +// int8 +extern "C" __device__ int32x32_t llvm_intrin_amdgcn_mfma_i32_32x32x4i8( + int, int, int32x32_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x4i8"); + +extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_16x16x4i8( + int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x4i8"); + +extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_4x4x4i8( + int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.4x4x4i8"); + +extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_32x32x8i8( + int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x8i8"); + +extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_16x16x16i8( + int, int, int32x4_t, int, int, int) 
__asm("llvm.amdgcn.mfma.i32.16x16x16i8"); + +// fp32 template struct intrin_mfma_f32_32x32x1f32; @@ -148,6 +174,7 @@ struct intrin_mfma_f32_4x4x1f32<8, 64> } }; +// fp16 template struct intrin_mfma_f32_32x32x4f16; @@ -244,147 +271,102 @@ struct intrin_mfma_f32_4x4x4f16<8, 64> } }; -#if 0 -template -struct intrin_mfma_f32_32x32x2bf16; +// bfp16 +template +struct intrin_mfma_f32_32x32x8bf16_1k; -template -struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride> +template <> +struct intrin_mfma_f32_32x32x8bf16_1k<32, 32> { - __device__ static c_vec32_4_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c) + template + __device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c) { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); - reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); - - reg_c.s.z = - llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0); - reg_c.s.w = - llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0); - - return reg_c; + reg_c.template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; -template -struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride> -{ - __device__ static c_vec32_4_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); - reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); - - reg_c.s.z = - llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0); - reg_c.s.w = - llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0); - - return reg_c; - } -}; - -template -struct 
intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride> -{ - __device__ static c_vec32_2_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); - reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); - - return reg_c; - } -}; - -template -struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride> -{ - __device__ static c_vec32_1_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1); - - return reg_c; - } -}; - -template -struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride> -{ - __device__ static c_vec32_1_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c) - { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); - return reg_c; - } -}; - -__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a, - const ushort2_t* reg_b, - c_vec16_1_t::VecType reg_c) -{ - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); - return reg_c; -} - -__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a, - const ushort2_t* reg_b, - c_vec4_1_t::VecType reg_c) -{ - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); - return reg_c; -} - template -__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a, - const ushort2_t* reg_b, - c_vec16_1_t::VecType reg_c); -template <> -__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a, - const ushort2_t* reg_b, - c_vec16_1_t::VecType reg_c) -{ - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0); - return reg_c; -} 
+struct intrin_mfma_f32_16x16x16bf16_1k; template <> -__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a, - const ushort2_t* reg_b, - c_vec16_1_t::VecType reg_c) +struct intrin_mfma_f32_16x16x16bf16_1k<16, 16> { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4); - return reg_c; -} + template + __device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; template -struct intrin_mfma_f32_4x4x2bf16; +struct intrin_mfma_f32_32x32x4bf16; template <> -struct intrin_mfma_f32_4x4x2bf16<4, 64> +struct intrin_mfma_f32_32x32x4bf16<32, 32> { - __device__ static c_vec4_1_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c) + template + __device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c) { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); - return reg_c; + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; +template +struct intrin_mfma_f32_16x16x8bf16; + template <> -struct intrin_mfma_f32_4x4x2bf16<8, 64> +struct intrin_mfma_f32_16x16x8bf16<16, 16> { - __device__ static c_vec4_2_t::VecType - run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c) + template + __device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c) { - reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); - reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0); - return reg_c; + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16( + reg_a, reg_b, reg_c.template 
AsType()[Number<0>{}], 0, 0, 0); } }; -#endif +template +struct intrin_mfma_i32_32x32x8i8; + +template <> +struct intrin_mfma_i32_32x32x8i8<32, 32> +{ + template + __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_i32_32x32x8i8(as_type(reg_a), + as_type(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_i32_16x16x16i8; + +template <> +struct intrin_mfma_i32_16x16x16i8<16, 16> +{ + template + __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_i32_16x16x16i8(as_type(reg_a), + as_type(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; } // namespace ck #endif diff --git a/external/rocm/include/bfloat16_dev.hpp b/external/rocm/include/bfloat16_dev.hpp index 52d00346cf..304d8406a8 100644 --- a/external/rocm/include/bfloat16_dev.hpp +++ b/external/rocm/include/bfloat16_dev.hpp @@ -31,7 +31,7 @@ extern "C" { #endif #ifdef __HIP_PLATFORM_HCC__ -#define EXECUTION_SPECIFIER __device__ +#define EXECUTION_SPECIFIER __device__ __host__ #else #define EXECUTION_SPECIFIER #endif // MIOPEN_BACKEND_HIP diff --git a/host/driver_offline/src/conv_bwd_driver_offline.cpp b/host/driver_offline/src/conv_bwd_driver_offline.cpp index b52585fb85..7082f1050c 100644 --- a/host/driver_offline/src/conv_bwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_bwd_driver_offline.cpp @@ -325,30 +325,30 @@ int main(int argc, char* argv[]) // no initialization break; case 1: - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 2: - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - 
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; case 3: - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 4: - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; case 5: - out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); break; default: - out.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); auto gen_wei = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); + return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); }; wei.GenerateTensorValue(gen_wei, num_thread); } diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index 881df7762d..e63f176d4b 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -80,13 +80,29 @@ void host_convolution_forward(const Tensor& in, if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && wi < in.mDesc.GetLengths()[3]) { - v += static_cast(in(n, c, hi, wi)) * - static_cast(wei(k, c, y, x)); + if constexpr(is_same::value) + { + v += bfloat16_to_float(in(n, c, hi, wi)) * + bfloat16_to_float(wei(k, c, y, x)); + } + else + { + v += static_cast(in(n, c, hi, wi)) * + static_cast(wei(k, c, y, x)); + } } } } } - out(n, k, ho, wo) = v; + + if constexpr(is_same::value) + { + out(n, k, ho, wo) = float_to_bfloat16(v); + } + else + { + out(n, k, ho, wo) = v; + } }; auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) { @@ -102,13 +118,28 @@ void host_convolution_forward(const Tensor& in, if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && wi < in.mDesc.GetLengths()[2]) { - v += static_cast(in(n, hi, wi, c)) * - static_cast(wei(k, y, x, c)); + if constexpr(is_same::value) + { + v += bfloat16_to_float(in(n, hi, wi, c)) * + bfloat16_to_float(wei(k, y, x, c)); + } + else + { + v += static_cast(in(n, hi, wi, c)) * + static_cast(wei(k, y, x, c)); + } } } } } - out(n, ho, wo, k) = v; + if constexpr(is_same::value) + { + out(n, ho, wo, k) = float_to_bfloat16(v); + } + else + { + out(n, ho, wo, k) = v; + } }; if(layout == ConvTensorLayout::NCHW) @@ -226,10 +257,14 @@ int main(int argc, char* argv[]) using in_data_t = float; using acc_data_t = float; using out_data_t = float; -#elif 1 +#elif 0 using in_data_t = half_t; using acc_data_t = float; using out_data_t = half_t; +#elif 1 + using in_data_t = ushort; + using acc_data_t = float; + using out_data_t = ushort; #elif 1 using in_data_t = 
int8_t; using acc_data_t = int32_t; @@ -295,30 +330,30 @@ int main(int argc, char* argv[]) // no initialization break; case 1: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 2: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; case 3: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 4: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; case 5: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); break; default: - in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); auto gen_wei = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); + return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); }; wei.GenerateTensorValue(gen_wei, num_thread); } diff --git a/host/driver_offline/src/conv_wrw_driver_offline.cpp b/host/driver_offline/src/conv_wrw_driver_offline.cpp index 2d63f0272b..0151fea9e5 100644 --- a/host/driver_offline/src/conv_wrw_driver_offline.cpp +++ b/host/driver_offline/src/conv_wrw_driver_offline.cpp @@ -297,30 +297,30 @@ int main(int argc, char* argv[]) // no initialization break; case 1: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 2: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; case 3: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 4: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; case 5: - in.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); - out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); + in.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); + out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); break; default: - in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); auto gen_out = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); + return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); }; out.GenerateTensorValue(gen_out, num_thread); } diff --git a/host/driver_offline/src/gemm_driver_offline.cpp b/host/driver_offline/src/gemm_driver_offline.cpp index be784c01a2..23158b7b66 100644 --- a/host/driver_offline/src/gemm_driver_offline.cpp +++ b/host/driver_offline/src/gemm_driver_offline.cpp @@ -239,10 +239,14 @@ int main(int argc, char* argv[]) using ab_data_t = float; using acc_data_t = float; using c_data_t = float; -#elif 1 +#elif 0 using ab_data_t = half_t; using acc_data_t = float; using c_data_t = half_t; +#elif 1 + using ab_data_t = ushort; + using acc_data_t = float; + using c_data_t = ushort; #elif 1 using ab_data_t = int8_t; using acc_data_t = int32_t; @@ -321,24 +325,24 @@ int main(int argc, char* argv[]) // no initialization break; case 1: - a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 2: - a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; case 3: - a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); break; case 4: - a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; default: - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - 
b.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); } #if USE_GEMM_XDL_MK_KN_MN diff --git a/host/host_tensor/include/host_gemm.hpp b/host/host_tensor/include/host_gemm.hpp index 010091fe1f..70f1c4dfa3 100644 --- a/host/host_tensor/include/host_gemm.hpp +++ b/host/host_tensor/include/host_gemm.hpp @@ -1,6 +1,162 @@ #pragma once #include "host_tensor.hpp" +template <> +void host_gemm(const Tensor& a, + const Tensor& b, + Tensor& c, + const GemmMatrixLayout layout) +{ + if(layout == GemmMatrixLayout::MK_KN_MN) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(k, n)); + } + + c(m, n) = float_to_bfloat16(v); + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::MK_NK_MN) + { + auto f_mk_nk_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(n, k)); + } + + c(m, n) = float_to_bfloat16(v); + }; + + make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_KN_MN) + { + auto f_km_kn_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(k, n)); + } + + c(m, n) = float_to_bfloat16(v); + }; + + make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_NK_MN) + { + auto f_km_nk_mn = [&](auto 
m, auto n) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(n, k)); + } + + c(m, n) = float_to_bfloat16(v); + }; + + make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::MK_KN_NM) + { + auto f_mk_kn_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(k, n)); + } + + c(n, m) = float_to_bfloat16(v); + }; + + make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::MK_NK_NM) + { + auto f_mk_nk_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(n, k)); + } + + c(n, m) = float_to_bfloat16(v); + }; + + make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_KN_NM) + { + auto f_km_kn_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(k, n)); + } + + c(n, m) = float_to_bfloat16(v); + }; + + make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_NK_NM) + { + auto f_km_nk_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(n, k)); + } + + c(n, m) = float_to_bfloat16(v); + }; + + make_ParallelTensorFunctor(f_km_nk_nm, 
c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! not supported layout"); + } +} + template void host_gemm_mk_kn_mn(const Tensor& a_m_k, const Tensor& b_k_n, diff --git a/host/host_tensor/include/host_tensor.hpp b/host/host_tensor/include/host_tensor.hpp index cf89423769..853261103c 100644 --- a/host/host_tensor/include/host_tensor.hpp +++ b/host/host_tensor/include/host_tensor.hpp @@ -321,4 +321,41 @@ void check_error(const Tensor& ref, const Tensor& result) std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl; } +float bf16_to_f32(ushort src_val) +{ + typedef union + { + ushort x, y; + float f32; + } bf16_f32_t; + + bf16_f32_t v; + v.x = 0; + v.y = src_val; + return v.f32; +} + +template <> +void check_error(const Tensor& ref, const Tensor& result) +{ + float error = 0; + float max_diff = -1; + float ref_value = 0, result_value = 0; + for(int i = 0; i < ref.mData.size(); ++i) + { + error += std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i])); + float diff = std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i])); + if(max_diff < diff) + { + max_diff = diff; + ref_value = bf16_to_f32(ref.mData[i]); + result_value = bf16_to_f32(result.mData[i]); + } + } + + std::cout << "error: " << error << std::endl; + std::cout << "max_diff: " << max_diff << ", ref: " << ref_value << ", res: " << result_value + << std::endl; +} + #endif diff --git a/host/host_tensor/include/host_tensor_generator.hpp b/host/host_tensor/include/host_tensor_generator.hpp index b0d53995ed..c7b3fb0fb7 100644 --- a/host/host_tensor/include/host_tensor_generator.hpp +++ b/host/host_tensor/include/host_tensor_generator.hpp @@ -4,6 +4,7 @@ #include #include "config.hpp" +template struct GeneratorTensor_1 { int value = 1; @@ -15,6 +16,30 @@ struct GeneratorTensor_1 } }; +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + 
ushort operator()(Is...) + { + return float_to_bfloat16(value); + } +}; + +template <> +struct GeneratorTensor_1 +{ + int8_t value = 1; + + template + int8_t operator()(Is...) + { + return value; + } +}; + struct GeneratorTensor_0 { int value = 0; @@ -26,6 +51,7 @@ struct GeneratorTensor_0 } }; +template struct GeneratorTensor_2 { int min_value = 0; @@ -38,6 +64,33 @@ struct GeneratorTensor_2 } }; +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ushort operator()(Is...) + { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + return float_to_bfloat16(tmp); + } +}; + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + int8_t operator()(Is...) + { + return (std::rand() % (max_value - min_value)) + min_value; + } +}; + template struct GeneratorTensor_3 { @@ -53,6 +106,39 @@ struct GeneratorTensor_3 } }; +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ushort operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp = min_value + tmp * (max_value - min_value); + + return float_to_bfloat16(fp32_tmp); + } +}; + +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + int8_t operator()(Is...) + { + int8_t min_tmp = static_cast(min_value); + int8_t max_tmp = static_cast(max_value); + + return (std::rand() % (max_tmp - min_tmp)) + min_tmp; + } +}; + struct GeneratorTensor_Checkboard { template