From 271236466dc543286c241f5417dfdb54de62cba1 Mon Sep 17 00:00:00 2001 From: Haocong WANG Date: Sat, 11 Mar 2023 07:04:28 +0800 Subject: [PATCH] [Navi3x] Multiple issue fix (#612) * Change gridwise gemm mD blockwise gemm to naive * RRR Gemm fix * Fix RCR gemm bug * Isolate wmma instructions * Update amd_inline_asm.hpp * Update amd_wmma.hpp * Update amd_wmma.hpp * fix syntax and update Jenkinsfile --------- Co-authored-by: zjing14 Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin [ROCm/composable_kernel commit: 087e3105892dec2fa10d2ef6c3419a4f9fe0b84f] --- Jenkinsfile | 2 +- .../device_gemm_multiple_d_wmma_cshuffle.hpp | 12 ++-- .../gpu/device/impl/device_gemm_wmma.hpp | 12 ++-- ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 3 +- include/ck/utility/amd_inline_asm.hpp | 18 +++-- include/ck/utility/amd_wmma.hpp | 69 +++++++++++++++++-- 6 files changed, 90 insertions(+), 26 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index bf50033593..bb0b352d75 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -685,7 +685,7 @@ pipeline { agent{ label rocmnode("navi21") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DCMAKE_CXX_FLAGS=" --offload-arch=gfx1030 --offload-arch=gfx1100 --offload-arch=gfx1101 --offload-arch=gfx1102" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030;gfx1100;gfx1101;gfx1102" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ } steps{ diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp index 44ff068355..b59bb2b303 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp @@ -121,16 +121,16 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); - } - else if constexpr(is_same_v) + if constexpr(is_same_v) { return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } }(); const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index d645c03113..9b0ff3f46f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -114,16 +114,16 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); - } - else if constexpr(is_same_v) + if constexpr(is_same_v) { return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } }(); const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 347f3e5f1f..1e8f8ff9fe 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -676,7 +676,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); auto blockwise_gemm = - BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO(a), bit_cast(b0), c0, false); - c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); #endif } @@ -257,10 +257,10 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a, "2"(c2), "3"(c3)); #else - c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); - c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); - c2 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b2), c2, false); - c3 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b3), c3, false); + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); + c2 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b2), c2, false); + c3 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b3), c3, false); #endif } @@ -358,7 +358,13 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, // Ranged input operand __device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c) { +#if defined(__gfx11__) asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +#else + ignore = a; + ignore = b; + ignore = c; +#endif } } // namespace ck diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index a0e79220e0..7598fe667b 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -23,11 +23,16 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> { // * Inline assembly need to elimate the duplicated data load, compiler won't help you // delete them. - amd_assembly_wmma_f32_16x16x16_f16_w32( - reg_a, reg_b, reg_c.template AsType()(Number<0>{})); - // reg_c.template AsType()(Number<0>{}) = - // __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template - // AsType()[Number<0>{}]); + // amd_assembly_wmma_f32_16x16x16_f16_w32( + // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); +#if defined(__gfx11__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -41,9 +46,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16> template __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -60,8 +71,14 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -78,9 +95,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -94,6 +117,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> template __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) { +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( neg_a, @@ -102,6 +126,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], clamp); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -116,8 +145,14 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16> template __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) { +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -131,9 +166,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16> template __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -150,8 +191,14 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -168,9 +215,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -184,6 +237,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> template __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) { +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64( neg_a, @@ -192,6 +246,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], clamp); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } };