From 65cfb2a15c773c3fe2117ba64b284b3325c19f76 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 21 Oct 2024 12:26:13 -0700 Subject: [PATCH] format --- .../element/unary_element_wise_operation.hpp | 19 +++++++++++-------- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 1 - .../threadwise_tensor_slice_transfer.hpp | 8 ++++---- .../threadwise_tensor_slice_transfer_v3r1.hpp | 17 ++++++++--------- .../cpu/reference_gemm.hpp | 4 ++-- 5 files changed, 25 insertions(+), 24 deletions(-) diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 4b5b572dfe..a6f75279fb 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -55,8 +55,8 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q) #else uint8_t x_u8 = ck::bit_cast(q); - int x_l = (x_u8 & 0x0f); - int x_h = (x_u8 & 0xf0) << 12; + int x_l = (x_u8 & 0x0f); + int x_h = (x_u8 & 0xf0) << 12; const int EX = 0x64006400; @@ -66,7 +66,6 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q) return amd_assembly_pk_add_f16(bit_cast(lo), bit_cast(SUB)); #endif - } struct PassThroughPack8 @@ -87,12 +86,16 @@ struct PassThroughPack8 vector_type dst; vector_type src{x}; - dst.template AsType()(Number<0>{}) = pki4_to_half2(src.template AsType()[Number<0>{}]); - dst.template AsType()(Number<1>{}) = pki4_to_half2(src.template AsType()[Number<1>{}]); - dst.template AsType()(Number<2>{}) = pki4_to_half2(src.template AsType()[Number<2>{}]); - dst.template AsType()(Number<3>{}) = pki4_to_half2(src.template AsType()[Number<3>{}]); + dst.template AsType()(Number<0>{}) = + pki4_to_half2(src.template AsType()[Number<0>{}]); + dst.template AsType()(Number<1>{}) = + pki4_to_half2(src.template AsType()[Number<1>{}]); + dst.template AsType()(Number<2>{}) = + pki4_to_half2(src.template AsType()[Number<2>{}]); + dst.template AsType()(Number<3>{}) = + pki4_to_half2(src.template AsType()[Number<3>{}]); - y = dst.template AsType()[Number<0>{}]; + y = dst.template AsType()[Number<0>{}]; #endif } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index a782268470..4a7695ed1d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -1370,7 +1370,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 c_thread_buf, num_k_block_main_loop); - // shuffle C and write out { static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index cff1628564..6fcf53984a 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1025,8 +1025,7 @@ struct ThreadwiseTensorSliceTransfer_v4 if constexpr(is_same_v, pk_i4_t>) { - static_assert(SrcScalarPerVector % PackedSize == 0, - "pk data N cannot be 1"); + static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1"); } } @@ -1126,8 +1125,9 @@ struct ThreadwiseTensorSliceTransfer_v4 using src_vector_t = typename decltype(src_tmp_vector)::type; - //const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( - //src_desc, src_data_coord); + // const bool is_src_valid = + // coordinate_has_valid_offset_assuming_visible_index_is_valid( src_desc, + // src_data_coord); const bool is_src_valid = true; // copy data from src_buf into src_tmp_vector diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 925c391827..47c9c5c6db 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -80,14 +80,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 if constexpr(is_same_v, pk_i4_t>) { static_assert(is_same_v, remove_cvref_t>, - "SrcData != DstData"); + "SrcData != DstData"); - static_assert(SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, - "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1"); + static_assert(SrcScalarPerVector_ % PackedSize == 0 && + DstScalarPerVector_ % PackedSize == 0, + "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1"); - static_assert( - SrcVectorDim == DstVectorDim, - "pk_i4_t does not support transpose"); + static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose"); } } @@ -446,7 +445,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 else { constexpr auto packed_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access; @@ -875,8 +874,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 private: static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; - //static constexpr auto src_oob_thread_scratch_desc_ = - //decltype(GetSrcThreadScratchDescriptor()){}; + // static constexpr auto src_oob_thread_scratch_desc_ = + // decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; using SrcThreadScratch = diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 8430ffff23..210437399e 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -82,7 +82,7 @@ struct ReferenceGemm : public device::BaseOperator i4 = (i4x2 >> 0) & 0xf; else i4 = (i4x2 >> 4) & 0xf; - i4 = i4 - 8; + i4 = i4 - 8; v_a = type_convert(i4); } else @@ -103,7 +103,7 @@ struct ReferenceGemm : public device::BaseOperator i4 = (i4x2 >> 0) & 0xf; else i4 = (i4x2 >> 4) & 0xf; - i4 = i4 - 8; + i4 = i4 - 8; v_b = type_convert(i4); } else