From 7a2e49df1921cfcbbcb7d2d24ca5361a8f6c43e2 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Wed, 2 Feb 2022 23:13:09 -0600 Subject: [PATCH] Replace llvm Intrinsics with clang buildins (#65) * test mfma builtins * add fp16 buildins * add int8 buildins * add bfl16 buildins * simplify host conv forward * clean * clean [ROCm/composable_kernel commit: 6d92959ad3642754d0f6de85388a922d33651578] --- .../include/utility/amd_xdlops.hpp | 146 +++++------------- .../include/utility/dynamic_buffer.hpp | 10 ++ .../include/driver_gemm_xdlops_v2r3.hpp | 34 ++-- 3 files changed, 69 insertions(+), 121 deletions(-) diff --git a/composable_kernel/include/utility/amd_xdlops.hpp b/composable_kernel/include/utility/amd_xdlops.hpp index dadeb5cac4..e37529a757 100644 --- a/composable_kernel/include/utility/amd_xdlops.hpp +++ b/composable_kernel/include/utility/amd_xdlops.hpp @@ -5,77 +5,6 @@ namespace ck { -// A, B, C, cbsz, abid, blgp -// fp32 -extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32( - float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32( - float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32( - float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32"); - -// fp16 -extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16( - half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16( - half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16( - half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16"); - -// bfp16 -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k( - ushort4_t, ushort4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8bf16.1k"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k( - ushort4_t, ushort4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16bf16.1k"); - -extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16( - ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16( - ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16( - ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16"); - -extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16( - ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16"); - -extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16( - ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16"); - -// int8 -extern "C" __device__ int32x32_t llvm_intrin_amdgcn_mfma_i32_32x32x4i8( - int, int, int32x32_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x4i8"); - -extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_16x16x4i8( - int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x4i8"); - -extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_4x4x4i8( - int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.4x4x4i8"); - -extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_32x32x8i8( - int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x8i8"); - -extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_16x16x16i8( - int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x16i8"); - // fp32 template struct intrin_mfma_f32_32x32x1f32; @@ -86,9 +15,9 @@ struct intrin_mfma_f32_32x32x1f32<64, 64> template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); - reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); } }; @@ -99,7 +28,7 @@ struct intrin_mfma_f32_32x32x1f32<32, 64> template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); } }; @@ -113,7 +42,7 @@ struct intrin_mfma_f32_32x32x2f32<32, 32> template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; @@ -127,7 +56,7 @@ struct intrin_mfma_f32_16x16x4f32<16, 16> template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f32( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; @@ -141,8 +70,7 @@ struct intrin_mfma_f32_16x16x1f32<16, 64> template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x1f32( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); } }; @@ -156,7 +84,7 @@ struct intrin_mfma_f32_4x4x1f32<4, 64> template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); } }; @@ -167,9 +95,9 @@ struct intrin_mfma_f32_4x4x1f32<8, 64> template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); - reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); } }; @@ -184,9 +112,9 @@ struct intrin_mfma_f32_32x32x4f16<64, 64> template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); - reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16( reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); } }; @@ -197,7 +125,7 @@ struct intrin_mfma_f32_32x32x4f16<32, 64> template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); } }; @@ -211,7 +139,7 @@ struct intrin_mfma_f32_32x32x8f16<32, 32> template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x8f16( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; @@ -225,7 +153,7 @@ struct intrin_mfma_f32_16x16x16f16<16, 16> template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x16f16( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; @@ -239,7 +167,7 @@ struct intrin_mfma_f32_16x16x4f16<16, 64> template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f16( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); } }; @@ -253,7 +181,7 @@ struct intrin_mfma_f32_4x4x4f16<4, 64> template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); } }; @@ -264,9 +192,9 @@ struct intrin_mfma_f32_4x4x4f16<8, 64> template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); - reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16( reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); } }; @@ -281,9 +209,8 @@ struct intrin_mfma_f32_32x32x8bf16_1k<32, 32> template __device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; @@ -296,9 +223,8 @@ struct intrin_mfma_f32_16x16x16bf16_1k<16, 16> template __device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; @@ -311,7 +237,7 @@ struct intrin_mfma_f32_32x32x4bf16<32, 32> template __device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; @@ -325,7 +251,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16> template __device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16( + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; @@ -340,12 +266,12 @@ struct intrin_mfma_i32_32x32x8i8<32, 32> __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) { reg_c.template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_i32_32x32x8i8(bit_cast(reg_a), - bit_cast(reg_b), - reg_c.template AsType()[Number<0>{}], - 0, - 0, - 0); + __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); } }; @@ -359,12 +285,12 @@ struct intrin_mfma_i32_16x16x16i8<16, 16> __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) { reg_c.template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_i32_16x16x16i8(bit_cast(reg_a), - bit_cast(reg_b), - reg_c.template AsType()[Number<0>{}], - 0, - 0, - 0); + __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); } }; diff --git a/composable_kernel/include/utility/dynamic_buffer.hpp b/composable_kernel/include/utility/dynamic_buffer.hpp index 7bde23f834..63e3ecabb3 100644 --- a/composable_kernel/include/utility/dynamic_buffer.hpp +++ b/composable_kernel/include/utility/dynamic_buffer.hpp @@ -169,6 +169,8 @@ struct DynamicBuffer is_same, int8x2_t>::value) || (is_same, int8_t>::value && is_same, int8x4_t>::value) || + (is_same, int8_t>::value && + is_same, int8x8_t>::value) || (is_same, int8x4_t>::value && is_same, int8x4_t>::value) || (is_same, int8x8_t>::value && @@ -202,6 +204,14 @@ struct DynamicBuffer *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } + else if constexpr(is_same, int8_t>::value && + is_same, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } else if constexpr(is_same, int8x4_t>::value && is_same, int8x4_t>::value) { diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp index beb06866bc..3aeb91a004 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -5,6 +5,7 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp" +#include "element_wise_operation.hpp" template {}; constexpr auto I2 = Number<2>{}; + using ElementwiseOperation = ck::tensor_operation::element_wise::PassThrough; + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + CThreadTransferDstScalarPerVector>; { std::cout << "a_grid_desc_k0_m_k1{" << a_grid_desc_k0_m_k1.GetLength(I0) << ", " @@ -152,6 +150,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, float ave_time = 0; + auto element_op_ = ElementwiseOperation{}; + #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE if(has_main_k0_block_loop) { @@ -162,6 +162,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, remove_reference_t, remove_reference_t, remove_reference_t, + ElementwiseOperation, + ElementwiseOperation, + ElementwiseOperation, remove_reference_t, true>; @@ -176,6 +179,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + element_op_, + element_op_, + element_op_, block_2_ctile_map); } else @@ -187,6 +193,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, remove_reference_t, remove_reference_t, remove_reference_t, + ElementwiseOperation, + ElementwiseOperation, + ElementwiseOperation, remove_reference_t, false>; @@ -201,6 +210,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + element_op_, + element_op_, + element_op_, block_2_ctile_map); } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER