diff --git a/CHANGELOG.md b/CHANGELOG.md index c721039523..4e3feed2df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ None None ### Additions -* Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126) +* Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126, #1139) ### Changes None diff --git a/docs/wrapper.rst b/docs/wrapper.rst index 79b6c75580..c64c0bf17f 100644 --- a/docs/wrapper.rst +++ b/docs/wrapper.rst @@ -89,3 +89,4 @@ Operations ------------------------------------- .. doxygenfile:: copy.hpp +.. doxygenfile:: gemm.hpp diff --git a/include/ck/wrapper/layout.hpp b/include/ck/wrapper/layout.hpp index 39b5c79c67..71c512e136 100644 --- a/include/ck/wrapper/layout.hpp +++ b/include/ck/wrapper/layout.hpp @@ -248,6 +248,9 @@ struct Layout using DefaultIdxsTupleType = remove_cvref_t; public: + using LayoutShape = Shape; + using LayoutUnrolledDescriptorType = UnrolledDescriptorType; + /** * \brief Transform descriptor to align to passed indexes. * diff --git a/include/ck/wrapper/operations/copy.hpp b/include/ck/wrapper/operations/copy.hpp index 7b00fe5500..614dfd758e 100644 --- a/include/ck/wrapper/operations/copy.hpp +++ b/include/ck/wrapper/operations/copy.hpp @@ -3,45 +3,18 @@ #pragma once -#include "../utils/tensor_utils.hpp" +#include "ck/wrapper/utils/tensor_utils.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" namespace ck { namespace wrapper { -/** - * \brief Perform generic copy between two tensors partitions (threadwise copy). - * Tensors must have the same size. - * - * \param src_tensor Source tensor. - * \param dst_tensor Destination tensor. - */ -template -__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) -{ - if constexpr(!SrcTensorType::IsDynamicBuffer) - { - using SizeType = decltype(size(src_tensor)); - static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); }); - } - else if constexpr(!DstTensorType::IsDynamicBuffer) - { - using SizeType = decltype(size(dst_tensor)); - static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); }); - } - else - { - for(int i = 0; i < size(src_tensor); i++) - { - dst_tensor(i) = src_tensor(i); - } - } -} - /** * \brief Perform optimized copy between two tensors partitions (threadwise copy). * Tensors must have the same size. @@ -167,9 +140,99 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) else { // Perform copy between StaticBuffers - copy(src_tensor, dst_tensor); + static_for<0, SrcShapeType::Size(), 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); }); } } +/** + * \brief Perform generic copy between two tensors partitions (threadwise copy). + * Tensors must have the same size. + * + * \param src_tensor Source tensor. + * \param dst_tensor Destination tensor. + */ +template +__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) +{ + // Generate default params + using SrcShapeType = remove_cvref_t; + constexpr index_t num_dims = SrcShapeType::Size(); + // Incrementing dims 0, 1, 2 ... 
num_dims - 1
+    constexpr auto dim_access_order_tuple =
+        generate_tuple([](auto i) { return Number<i.value>{}; }, Number<num_dims>{});
+    constexpr index_t vector_dim        = num_dims - 1;
+    constexpr index_t scalar_per_vector = 1;
+    copy<decltype(dim_access_order_tuple), vector_dim, scalar_per_vector>(src_tensor,
+                                                                          dst_tensor);
+}
+
+/**
+ * \brief Perform optimized blockwise copy between two tensors. Tensors must have the
+ * same size.
+ *
+ * \note Vgpr and Sgpr tensors are currently not supported.
+ *
+ * \tparam DimAccessOrderTuple Tuple with dimension access order.
+ * \tparam VectorDim Dimension used for vectorized reads and writes.
+ * \tparam ScalarPerVector Number of scalars per vectorized read/write.
+ * \param src_tensor Source tensor.
+ * \param dst_tensor Destination tensor.
+ * \param thread_layout Per-dimension thread layout for the copy.
+ */
+template <typename DimAccessOrderTuple,
+          index_t VectorDim,
+          index_t ScalarPerVector,
+          typename SrcTensorType,
+          typename DstTensorType,
+          typename ThreadLayoutTuple>
+__device__ void blockwise_copy(const SrcTensorType& src_tensor,
+                               DstTensorType& dst_tensor,
+                               [[maybe_unused]] ThreadLayoutTuple& thread_layout)
+{
+    static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
+    static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
+
+    const auto& in_grid_desc  = layout(src_tensor).GetUnrolledDescriptor();
+    const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();
+
+    using SrcShapeType         = remove_cvref_t<decltype(shape(src_tensor))>;
+    constexpr index_t num_dims = SrcShapeType::Size();
+
+    constexpr auto tile_lengths_seq = generate_sequence_v2(
+        [](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
+    constexpr auto thread_layout_seq = generate_sequence_v2(
+        [](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number<num_dims>{});
+    constexpr auto dim_access_order = generate_sequence_v2(
+        [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});
+
+    using ThisThreadBlock = ThisThreadBlock<size(ThreadLayoutTuple{})>;
+
+    // Perform copy between DynamicBuffers
+    auto transfer = ThreadGroupTensorSliceTransfer_v7<
+        ThisThreadBlock,
+        Tuple<typename SrcTensorType::TensorElementType>,
+        Tuple<typename DstTensorType::TensorElementType>,
+        decltype(tie(in_grid_desc)),
+        decltype(tie(out_grid_desc)),
+        tensor_operation::element_wise::PassThrough,
+        Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
+        std::remove_const_t<decltype(tile_lengths_seq)>,
+        std::remove_const_t<decltype(thread_layout_seq)>,
+        std::remove_const_t<decltype(dim_access_order)>,
+        std::remove_const_t<decltype(dim_access_order)>,
+        VectorDim,
+        ScalarPerVector,
+        Sequence<true>,
+        Sequence<true>>{in_grid_desc,
+                        make_tuple(src_tensor.GetMultiIdxOffsets()),
+                        out_grid_desc,
+                        make_tuple(dst_tensor.GetMultiIdxOffsets()),
+                        tensor_operation::element_wise::PassThrough{}};
+
+    transfer.Run(tie(in_grid_desc),
+                 tie(src_tensor.GetBuffer()),
+                 tie(out_grid_desc),
+                 tie(dst_tensor.GetBuffer()));
+}
+
 } // namespace wrapper
 } // namespace ck
diff --git a/include/ck/wrapper/operations/gemm.hpp b/include/ck/wrapper/operations/gemm.hpp
new file mode 100644
index 0000000000..9b8c0543fd
--- /dev/null
+++ b/include/ck/wrapper/operations/gemm.hpp
@@ -0,0 +1,337 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/wrapper/utils/tensor_utils.hpp"
+#include "ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp"
+
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
+
+namespace ck {
+namespace wrapper {
+
+namespace {
+namespace detail {
+/**
+ * \brief Create block descriptor (K0, MPerBlock or NPerBlock, K1).
+ *
+ * \tparam K1 The number of K-dim elements that are packed together as a
+ * separate logical dimension.
+ * \tparam TileLayout Tensor data tile layout (M,K) or (N,K).
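+ *
+ * For example (illustrative values): a (MPerBlock, KPerBlock) = (128, 64)
+ * tile with K1 = 8 is viewed as a (K0, MPerBlock, K1) = (8, 128, 8)
+ * descriptor, since K0 = KPerBlock / K1.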
+ *
+ * \return Block descriptor (K0, MPerBlock or NPerBlock, K1).
+ */
+template <index_t K1, typename TileLayout>
+__device__ constexpr auto GetBlockDescriptor()
+{
+    using TileLayoutShape      = typename TileLayout::LayoutShape;
+    using TileLayoutDescriptor = typename TileLayout::LayoutUnrolledDescriptorType;
+
+    constexpr auto K0PerBlock = Number<size<1>(TileLayoutShape{})>{} / Number<K1>{};
+    // MPerBlock or NPerBlock
+    constexpr auto Dim0 = Number<size<0>(TileLayoutShape{})>{};
+
+    constexpr auto a_block_desc_k0_m_k1 = transform_tensor_descriptor(
+        TileLayoutDescriptor{},
+        make_tuple(make_unmerge_transform(make_tuple(K0PerBlock, Number<K1>{})),
+                   make_pass_through_transform(Dim0)),
+        make_tuple(Sequence<1>{}, Sequence<0>{}),
+        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+    return a_block_desc_k0_m_k1;
+}
+
+} // namespace detail
+} // namespace
+
+/**
+ * \brief Perform blockwise gemm xdl on tensors stored in LDS. The result is
+ * stored in a Vgpr register. The A data layout must be (MPerBlock, KPerBlock)
+ * and the B data layout must be (NPerBlock, KPerBlock).
+ *
+ * \note C output Vgpr register layout (8D):
+ * - MXdlPerWave - The number of MFMA instructions run by a single wave in the M
+ * dimension per tile.
+ * - NXdlPerWave - The number of MFMA instructions run by a single wave in the N
+ * dimension per tile.
+ * - MWave - Equals 1, since this is for a single wave.
+ * - NWave - Equals 1, since this is for a single wave.
+ * - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
+ * instruction size).
+ * - NumInputsBlock - Mfma instruction internal layout (depends on the
+ * instruction size).
+ * - GroupSize - Mfma instruction internal layout (depends on the
+ * instruction size).
+ * - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
+ * instruction size).
+ *
+ * \tparam DataType Input data types.
+ * \tparam BlockSize Number of threads in block.
+ * \tparam GemmTraits Traits of gemm xdl operation.
+ * \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
+ * ((MPerBlock, KPerBlock) layout).
+ * \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
+ * ((NPerBlock, KPerBlock) layout).
+ * \param c_reg_tensor C tensor in Vgpr memory for blockwise gemm.
+ */
+template <typename DataType,
+          index_t BlockSize,
+          typename GemmTraits,
+          typename ATensorType,
+          typename BTensorType,
+          typename CTensorType>
+__device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
+                                   const BTensorType& b_local_tile_tensor,
+                                   CTensorType& c_reg_tensor)
+{
+    static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
+    static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
+    static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
+    static_assert(is_same_v<typename ATensorType::TensorElementType, DataType>);
+    static_assert(is_same_v<typename BTensorType::TensorElementType, DataType>);
+
+    constexpr bool is_integer =
+        is_same_v || is_same_v || is_same_v;
+    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
+
+    using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
+    using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;
+
+    using ABlockDesc_K0_M_K1_Type =
+        decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
+    using BBlockDesc_K0_N_K1_Type =
+        decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
+
+    BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                        DataType,
+                                                        GemmAccDataType,
+                                                        ABlockDesc_K0_M_K1_Type,
+                                                        BBlockDesc_K0_N_K1_Type,
+                                                        GemmTraits::MPerXDL,
+                                                        GemmTraits::NPerXDL,
+                                                        GemmTraits::MXdlPerWave,
+                                                        GemmTraits::NXdlPerWave,
+                                                        GemmTraits::K1>
+        blockwise_gemm_xdl_op{};
+
+    blockwise_gemm_xdl_op.Run(
+        a_local_tile_tensor.GetBuffer(), b_local_tile_tensor.GetBuffer(), c_reg_tensor.GetBuffer());
+}
+
+/**
+ * \brief Create local partition per thread for C tensor.
+ *
+ * \note C output global memory layout (8D):
+ * - MXdlPerWave - The number of MFMA instructions run by a single wave in the M
+ * dimension.
+ * - NXdlPerWave - The number of MFMA instructions run by a single wave in the N
+ * dimension.
+ * - MWave - The number of waves in the M dimension per tile.
+ * - NWave - The number of waves in the N dimension per tile.
+ * - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
+ * instruction size).
+ * - NumInputsBlock - Mfma instruction internal layout (depends on the
+ * instruction size).
+ * - GroupSize - Mfma instruction internal layout (depends on the
+ * instruction size).
+ * - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
+ * instruction size).
+ *
+ * \tparam DataType Input data types.
+ * \tparam ATileLayout A tensor layout.
+ * \tparam BTileLayout B tensor layout.
+ * \tparam BlockSize Number of threads in block.
+ * \tparam GemmTraits Traits of gemm xdl operation.
+ * \param c_local_tile_tensor C tensor tile in global memory with a
+ * (MPerBlock, NPerBlock) layout.
+ *
+ * \return C tensor partition for blockwise gemm.
+ */
+template <typename DataType,
+          typename ATileLayout,
+          typename BTileLayout,
+          index_t BlockSize,
+          typename GemmTraits,
+          typename CTensorType>
+__host__ __device__ constexpr auto
+make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+    constexpr auto I4 = Number<4>{};
+    constexpr auto I5 = Number<5>{};
+    constexpr auto I6 = Number<6>{};
+    constexpr auto I7 = Number<7>{};
+
+    constexpr bool is_integer =
+        is_same_v || is_same_v || is_same_v;
+    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
+
+    using ABlockDesc_K0_M_K1_Type =
+        decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
+    using BBlockDesc_K0_N_K1_Type =
+        decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
+
+    using BlockwiseGemmXdlops =
+        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                            DataType,
+                                                            GemmAccDataType,
+                                                            ABlockDesc_K0_M_K1_Type,
+                                                            BBlockDesc_K0_N_K1_Type,
+                                                            GemmTraits::MPerXDL,
+                                                            GemmTraits::NPerXDL,
+                                                            GemmTraits::MXdlPerWave,
+                                                            GemmTraits::NXdlPerWave,
+                                                            GemmTraits::K1>;
+
+    constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
+        BlockwiseGemmXdlops::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+    constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
+    constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
+    constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
+    constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
+    constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
+    constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
+    constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
+    constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
+
+    // Calculate offset on grid
+    const auto c_thread_mtx_on_block =
+        BlockwiseGemmXdlops::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+
+    const index_t m_thread_data_on_grid =
+        c_local_tile_tensor.GetMultiIdxOffsets()[I0] + c_thread_mtx_on_block[I0];
+
+    const index_t n_thread_data_on_grid =
+        c_local_tile_tensor.GetMultiIdxOffsets()[I1] + c_thread_mtx_on_block[I1];
+
+    const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
+        make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
+        make_tuple(Sequence<0, 1, 2, 3, 4>{}),
+        make_tuple(Sequence<0>{}));
+
+    const auto m_thread_data_on_grid_idx =
+        m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
+            make_multi_index(m_thread_data_on_grid));
+
+    const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor =
+        make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
+                                         make_tuple(Sequence<0, 1, 2>{}),
+                                         make_tuple(Sequence<0>{}));
+
+    const auto n_thread_data_on_grid_idx =
+        n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
+            make_multi_index(n_thread_data_on_grid));
+
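+    // The two merge adaptors above split this thread's flat M (resp. N) grid
+    // coordinate into the (M0, M1, M2, M3, M4) (resp. (N0, N1, N2)) sub-indexes
+    // of the 8d MFMA output layout, so each thread addresses its own C fragment.
+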
+    // Create partition shape based on descriptor dims.
+    const auto partition_shape = make_tuple(M0, N0, I1, I1, M2, I1, M4, I1);
+
+    const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
+        layout(c_local_tile_tensor).GetUnrolledDescriptor());
+    const auto partition_layout =
+        Layout<remove_reference_t<decltype(partition_shape)>, decltype(partition_desc)>(
+            partition_shape, partition_desc);
+    auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
+        c_local_tile_tensor.GetPointer(), partition_layout);
+    partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0],
+                                                        n_thread_data_on_grid_idx[I0],
+                                                        m_thread_data_on_grid_idx[I1],
+                                                        n_thread_data_on_grid_idx[I1],
+                                                        m_thread_data_on_grid_idx[I2],
+                                                        m_thread_data_on_grid_idx[I3],
+                                                        m_thread_data_on_grid_idx[I4],
+                                                        n_thread_data_on_grid_idx[I2]));
+    return partition_tensor;
+}
+
+/**
+ * \brief Create Vgpr C tensor for blockwise gemm.
+ *
+ * \note C output Vgpr register layout (8D):
+ * - MXdlPerWave - The number of MFMA instructions run by a single wave in the M
+ * dimension per tile.
+ * - NXdlPerWave - The number of MFMA instructions run by a single wave in the N
+ * dimension per tile.
+ * - MWave - Equals 1, since this is for a single wave.
+ * - NWave - Equals 1, since this is for a single wave.
+ * - NumGroupsPerBlock - Mfma instruction internal layout (depends on the
+ * instruction size).
+ * - NumInputsBlock - Mfma instruction internal layout (depends on the
+ * instruction size).
+ * - GroupSize - Mfma instruction internal layout (depends on the
+ * instruction size).
+ * - NumThreadsPerBlock - Mfma instruction internal layout (depends on the
+ * instruction size).
+ *
+ * \tparam DataType Input data types.
+ * \tparam ATileLayout A tensor layout.
+ * \tparam BTileLayout B tensor layout.
+ * \tparam BlockSize Number of threads in block.
+ * \tparam GemmTraits Traits of gemm xdl operation.
+ *
+ * \return Vgpr C tensor for blockwise gemm.
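+ *
+ * Usage sketch (mirrors test/wrapper/test_gemm.cpp; the tensor and layout
+ * names are illustrative):
+ * \code
+ * auto c_vgpr_reg = make_blockwise_gemm_xdl_c_vgpr<DataType,
+ *                                                  decltype(a_tile_layout),
+ *                                                  decltype(b_tile_layout),
+ *                                                  BlockSize,
+ *                                                  GemmTraits>();
+ * clear(c_vgpr_reg);
+ * blockwise_gemm_xdl<DataType, BlockSize, GemmTraits>(a_lds_tensor, b_lds_tensor, c_vgpr_reg);
+ * \endcode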
+ */ +template +__host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr() +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + + constexpr bool is_integer = + is_same_v || is_same_v || is_same_v; + using GemmAccDataType = std::conditional_t; + + using ABlockDesc_K0_M_K1_Type = + decltype(detail::GetBlockDescriptor()); + using BBlockDesc_K0_N_K1_Type = + decltype(detail::GetBlockDescriptor()); + + using BlockwiseGemmXdlops = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + // Calcualte descriptor, shape and layout + constexpr auto vgpr_desc = BlockwiseGemmXdlops::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + const auto vgpr_shape = make_tuple(vgpr_desc.GetLengths()[I0], + vgpr_desc.GetLengths()[I1], + vgpr_desc.GetLengths()[I2], + vgpr_desc.GetLengths()[I3], + vgpr_desc.GetLengths()[I4], + vgpr_desc.GetLengths()[I5], + vgpr_desc.GetLengths()[I6], + vgpr_desc.GetLengths()[I7]); + const auto vgpr_layout = Layout, decltype(vgpr_desc)>( + vgpr_shape, vgpr_desc); + // Get vector type for Vgpr + using BlockwiseGemmCThreadBufferType = + remove_reference_t; + using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V; + return ck::wrapper::make_register_tensor( + vgpr_layout); +} + +} // namespace wrapper +} // namespace ck diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index 57d79c5940..e344399dbf 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -10,8 +10,8 @@ namespace ck { namespace wrapper { -namespace detail { namespace { +namespace detail { /** * \brief Check if Tuple contains Slice object * @@ -187,8 +187,8 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple& const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){}; return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims); } -} // namespace } // namespace detail +} // namespace /** * \brief Tensor wrapper that performs static and dynamic buffer logic. @@ -209,7 +209,10 @@ struct Tensor public: using ElementSpaceSize = decltype(Layout{ Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer - using TensorElementType = ElementType; // DataType + using TensorElementType = std::conditional_t< + is_scalar_type::value, + ElementType, + typename scalar_type>::type>; // DataType static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace; static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr || @@ -280,7 +283,7 @@ struct Tensor * \return Requested value. */ template {}), bool> = false> - __host__ __device__ const ElementType& operator[](const Tuple& idx) const + __host__ __device__ const TensorElementType& operator[](const Tuple& idx) const { if constexpr(IsDynamicBuffer) { @@ -301,13 +304,13 @@ struct Tensor } template {}), bool> = false> - __host__ __device__ const ElementType& operator()(const Tuple& idx) const + __host__ __device__ const TensorElementType& operator()(const Tuple& idx) const { return this->operator[](idx); } template {}), bool> = false> - __host__ __device__ const ElementType& operator()(Idxs... idxs) const + __host__ __device__ const TensorElementType& operator()(Idxs... 
idxs) const { return this->operator[](make_tuple(idxs...)); } @@ -319,7 +322,7 @@ struct Tensor * \return Requested value. */ template {}), bool> = false> - __host__ __device__ ElementType& operator[](const Tuple& idx) + __host__ __device__ TensorElementType& operator[](const Tuple& idx) { if constexpr(IsDynamicBuffer) { @@ -340,13 +343,13 @@ struct Tensor } template {}), bool> = false> - __host__ __device__ ElementType& operator()(const Tuple& idx) + __host__ __device__ TensorElementType& operator()(const Tuple& idx) { return this->operator[](idx); } template {}), bool> = false> - __host__ __device__ ElementType& operator()(Idxs... idxs) + __host__ __device__ TensorElementType& operator()(Idxs... idxs) { return this->operator[](make_tuple(idxs...)); } @@ -366,7 +369,7 @@ struct Tensor * * \return Pointer. */ - __host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; } + __host__ __device__ TensorElementType* GetPointer() const { return buffer_.p_data_; } __host__ __device__ constexpr auto& GetBuffer() { return buffer_; } __host__ __device__ constexpr auto& GetBuffer() const { return buffer_; } @@ -395,10 +398,18 @@ struct Tensor ElementType, ElementSpaceSize, true /*InvalidElementUseNumericalZeroValue*/>; - using StaticBufferType = StaticBuffer; + using StaticBufferType = std::conditional_t< + is_scalar_type::value, + StaticBuffer, + StaticBufferTupleOfVector>::vector_size, + scalar_type>::vector_size, + true /*InvalidElementUseNumericalZeroValue*/>>; // If register use static buffer, else use dynamic buffer using Buffer = std::conditional_t; diff --git a/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp b/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp new file mode 100644 index 0000000000..24d863f5b1 --- /dev/null +++ b/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" + +namespace ck { +namespace wrapper { + +/** + * \brief Traits for blockwise gemm xdl. + * + * \tparam MPerXDLValue The MFMA instruction size in M dimension. + * \tparam NPerXDLValue The MFMA instruction size in N dimension. + * \tparam MXdlPerWaveValue The number of MFMA instructions run by single + * wave in M dimension. + * \tparam NXdlPerWaveValue The number of MFMA instructions run by single + * wave in N dimension. + * \tparam K1Value The number of K-dim elements that are packed together as + * a separate logical dimension. Usually aligns with vector load size. 
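+ *
+ * Any of the predefined alias structs below can be passed as the GemmTraits
+ * template parameter of the blockwise gemm helpers in gemm.hpp, e.g.:
+ * \code
+ * using GemmTraits = BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1;
+ * \endcode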
+ */ +template +struct BlockwisGemmXdlTraits +{ + static constexpr index_t MPerXDL = MPerXDLValue; + static constexpr index_t NPerXDL = NPerXDLValue; + static constexpr index_t MXdlPerWave = MXdlPerWaveValue; + static constexpr index_t NXdlPerWave = NXdlPerWaveValue; + static constexpr index_t K1 = K1Value; +}; + +struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4> +{ +}; +struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4> +{ +}; +struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4> +{ +}; + +} // namespace wrapper +} // namespace ck diff --git a/include/ck/wrapper/utils/tensor_partition.hpp b/include/ck/wrapper/utils/tensor_partition.hpp index 6aae5a92fe..5638382dba 100644 --- a/include/ck/wrapper/utils/tensor_partition.hpp +++ b/include/ck/wrapper/utils/tensor_partition.hpp @@ -6,6 +6,7 @@ #include "tensor_utils.hpp" #include "layout_utils.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_description/cluster_descriptor.hpp" @@ -14,6 +15,8 @@ namespace wrapper { namespace { +namespace detail { + /** * \brief Calculate shape for partition based on number of threads per each dim and * previous shape @@ -30,26 +33,109 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple{}; - const auto slice_len = size(shape) / thread_lengths.At(num_i); + const auto slice_len = + ck::math::integer_divide_ceil(size(shape), thread_lengths.At(num_i)); return slice_len; }, Number::Size()>{}); } +/** + * \brief Apply projection. + * + * \param base_tuple Tuple to apply projection. + * \param projection Projection to remove selected dim from partitioning. + * slice(X) to remove, where X is dim size, Number<1>{} to keep. + * \return Multi index after projection. + */ +template +__host__ __device__ constexpr auto +ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple, + [[maybe_unused]] const ProjectionTuple& projection) +{ + if constexpr(is_same_v>) + { + return Tuple<>{}; + } + else + { + auto base_tuple_after_projection = generate_tuple( + [&](auto i) { + const auto i_num = Number{}; + static_assert( + is_detected>::value || + is_same_v, Number<1>>); + if constexpr(is_detected>::value) + { + // When slice (to remove), then insert empty tuple (will be removed in next + // step). + return Tuple<>{}; + } + else + { + return base_tuple.At(i_num); + } + }, + Number{}); + // Remove empty tuples + return UnrollNestedTuple<0, 1>(base_tuple_after_projection); + } +} + +/** + * \brief Calculate shape with dims from projection. + * + * \param shape Base tensor shape. + * \param projection Projection to remove selected dim from partitioning. + * slice(X) to remove, where X is dim size, Number<1>{} to keep. + * \return Shape with dims from projection + */ +template +__host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple& shape, + const Tuple& projection) +{ + return generate_tuple( + [&](auto i) { + if constexpr(is_detected>>::value) + { + return size(projection).to_; + } + else + { + // number of shape element in actual fragment of shape and projection (method to + // calculate shape idx) + constexpr index_t shape_i = + detail::ApplyProjection(TupleSlice<0, i>(Tuple{}), + TupleSlice<0, i>(Tuple{})) + .Size(); + return size(shape); + } + }, + Number::Size()>{}); +} + /** * \brief Calculate total number of blocks. 
 *
 * \param shape Base tensor shape.
 * \param tile_shape Tile shape.
+ * \param projection Projection is used to remove a selected dim from
+ * partitioning. Use `slice(X)` to remove a dimension, where X is the dim
+ * size. Use `Number<1>{}` to keep it.
 * \return Tuple with blocks number.
 */
-template <typename... Ts, typename... Ls>
+template <typename... Ts, typename... Ls, typename... Ps>
 __host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
-                                                     const Tuple<Ls...>& tile_shape)
+                                                     const Tuple<Ls...>& tile_shape,
+                                                     const Tuple<Ps...>& projection)
 {
-    static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
-    return generate_tuple([&](auto i) { return size<i>(shape) / size<i>(tile_shape); },
-                          Number<Tuple<Ts...>::Size()>{});
+    auto shape_with_projection = CalculateShapeWithProjection(shape, projection);
+    return generate_tuple(
+        [&](auto i) {
+            return ck::math::integer_divide_ceil(size<i>(shape_with_projection),
+                                                 size<i>(tile_shape));
+        },
+        Number<Tuple<Ls...>::Size()>{});
 }
 
 /**
@@ -69,8 +155,75 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
     return thread_idxs * partition_lengths_seq + old_offset_idxs;
 }
 
+/**
+ * \brief Calculate default projection.
+ *
+ * \param tile_shape Tile shape.
+ * \return Default projection (filled with Number<1>{}).
+ */
+template <typename TileShape>
+__host__ __device__ constexpr auto
+GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
+{
+    return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
+}
+
+} // namespace detail
 } // namespace
 
+/**
+ * \brief Create local partition for a thread (currently only a packed
+ * partition is supported).
+ *
+ * \param tensor Tensor for partition.
+ * \param thread_lengths Layout of threads (cannot be nested).
+ * \param thread_id Thread index represented as an integer.
+ * \param projection Projection is used to remove a selected dim from
+ * partitioning. Use `slice(X)` to remove a dimension, where X is the dim
+ * size. Use `Number<1>{}` to keep it.
+ * \return Partition tensor.
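+ *
+ * Usage sketch (mirrors test/wrapper/test_partition.cpp): partition with a 3d
+ * thread layout whose first dimension is excluded via the projection:
+ * \code
+ * const auto thread_layout = make_tuple(Number<4>{}, Number<8>{}, Number<1>{});
+ * const auto projection    = make_tuple(slice(4), Number<1>{}, Number<1>{});
+ * auto partition = make_local_partition(tensor, thread_layout, thread_id, projection);
+ * \endcode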
+ */ +template +__host__ __device__ constexpr auto +make_local_partition(TensorType& tensor, + [[maybe_unused]] const ThreadLengthsTuple& thread_lengths, + const index_t thread_id, + const ProjectionTuple& projection) +{ + static_assert(!IsNestedTuple(ThreadLengthsTuple{})); + // Calculate new partition shape + const auto& tensor_shape = shape(tensor); + // Calculate projected thread lengths + constexpr auto projected_thread_lengths = + detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{}); + constexpr auto partition_shape = + detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths); + // Create Thread Cluster Descriptor + constexpr auto partition_shape_seq = + generate_sequence_v2([&](auto I) { return size(partition_shape); }, + Number{}); + constexpr auto thread_lengths_seq = + generate_sequence_v2([&](auto I) { return size(ThreadLengthsTuple{}); }, + Number{}); + constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq); + // Calculate thread idxs and offsets + const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id)); + // Apply projection on thread idxs to remove not needed idxs + const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection); + const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs( + projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets()); + // Create new layout and tensor + auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor(); + const auto partition_layout = + Layout, decltype(unrolled_desc)>( + partition_shape, unrolled_desc); + auto partition_tensor = + make_tensor(tensor.GetPointer(), partition_layout); + // Apply offsets + partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs)); + return partition_tensor; +} + /** * \brief Create local partition for thread (At now only packed partition * is supported). @@ -81,37 +234,12 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs, * \return Partition tensor. 
*/ template -__host__ __device__ constexpr auto -make_local_partition(TensorType& tensor, - [[maybe_unused]] const ThreadLengthsTuple& thread_lengths, - const index_t thread_id) +__host__ __device__ constexpr auto make_local_partition(TensorType& tensor, + const ThreadLengthsTuple& thread_lengths, + const index_t thread_id) { - static_assert(!IsNestedTuple(ThreadLengthsTuple{})); - // Calculate new partition shape - const auto& tensor_shape = shape(tensor); - constexpr auto partition_shape = - CalculateLocalPartitionShape(decltype(tensor_shape){}, ThreadLengthsTuple{}); - // Create Thread Cluster Descriptor - constexpr auto partition_lengths_seq = generate_sequence_v2( - [&](auto I) { return size(partition_shape); }, Number{}); - constexpr auto thread_lengths_seq = - generate_sequence_v2([&](auto I) { return size(ThreadLengthsTuple{}); }, - Number{}); - constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq); - // Calculate thread idxs and offsets - const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id)); - const auto offset_multi_idxs = - CalculateOffsetMultiIdxs(thread_idxs, partition_lengths_seq, tensor.GetMultiIdxOffsets()); - // Create new layout and tensor - auto& flatten_desc = layout(tensor).GetUnrolledDescriptor(); - const auto partition_layout = - Layout, decltype(flatten_desc)>( - partition_shape, flatten_desc); - auto partition_tensor = - make_tensor(tensor.GetPointer(), partition_layout); - // Apply offsets - partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs)); - return partition_tensor; + const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{}); + return make_local_partition(tensor, thread_lengths, thread_id, projection); } /** @@ -125,22 +253,29 @@ make_local_partition(TensorType& tensor, * \param tensor Tensor for partition. * \param tile_shape Shapes of requested tile. * \param block_id Block index represented as integer. - + * \param projection Projection to remove selected dim from partitioning. + * slice(X) to remove, where X is dim size, Number<1>{} to keep. * \return Tile tensor. */ -template -__host__ __device__ constexpr auto -make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id) +template +__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor, + const BlockShapeTuple& tile_shape, + const index_t block_id, + const ProjectionTuple& projection) { static_assert(!IsNestedTuple(BlockShapeTuple{})); + constexpr bool is_default_projection = + is_same_v; + constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor(); - if constexpr(BlockShapeTuple::Size() == I2) + // TODO: Enable block_2_tile_map partitioning for non-default projection. 
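+    // Note: the 2d default-projection path below computes the tile origin
+    // through a block-to-C-tile map, while the generic else-branch derives it
+    // from a plain cluster descriptor over the (possibly projected) grid dims.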
+    if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection)
     {
         // Optimized version for 2d tile shape [MxK]
         const auto block_2_tile_map =
@@ -169,20 +304,24 @@ make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, con
     {
         // Calculate offsets
        // Sequence with data to process per block
-        constexpr auto tile_shape_seq =
-            generate_sequence_v2([](auto I) { return size(BlockShapeTuple{}.At(I)); },
-                                 Number<BlockShapeTuple::Size()>{});
+        constexpr auto projected_tile_shape =
+            detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
+        using ProjectedTileShapeTuple = decltype(projected_tile_shape);
+        constexpr auto projected_tile_shape_seq =
+            generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
+                                 Number<ProjectedTileShapeTuple::Size()>{});
         // Tuple with number of blocks
-        const auto block_lengths = CalculateGridSize(shape(tensor), tile_shape);
-        constexpr auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
+        const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection);
+        const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
         const auto block_idxs =
             block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
-        const auto offset_multi_idxs =
-            CalculateOffsetMultiIdxs(block_idxs, tile_shape_seq, tensor.GetMultiIdxOffsets());
+        const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection);
+        const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
+            projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
         // Create new layout and tensor
         const auto tile_layout =
-            Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
-                                                                                     aligned_desc);
+            Layout<remove_reference_t<decltype(projected_tile_shape)>, decltype(aligned_desc)>(
+                projected_tile_shape, aligned_desc);
         auto tile_tensor =
             make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
         // Apply offsets
@@ -191,5 +330,61 @@ make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, con
         return tile_tensor;
     }
 }
 
+/**
+ * \brief Create local tile for a thread block (currently only a packed tile
+ * is supported).
+ *
+ * \note Currently, use a 2d shape to get the best performance.
+ *
+ * \param tensor Tensor to partition into tiles.
+ * \param tile_shape Shape of the requested tile.
+ * \param block_id Block index represented as an integer.
+ * \return Tile tensor.
+ */
+template <typename TensorType, typename BlockShapeTuple>
+__host__ __device__ constexpr auto
+make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
+{
+    const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
+    return make_local_tile(tensor, tile_shape, block_id, projection);
+}
+
+/**
+ * \brief Pad tensor shape to align with tile lengths.
+ *
+ * \param tensor Tensor to pad.
+ * \param tile_lengths Tile lengths to align tensor shape.
+ * \return Padded tensor.
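+ *
+ * Usage sketch (mirrors test/wrapper/test_gemm.cpp): pad each dimension of a
+ * global (M, N) tensor up to a multiple of the tile lengths:
+ * \code
+ * auto c_padded_global_tensor = pad(c_global_tensor, shape(c_tile_layout));
+ * \endcode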
+ */ +template +__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths) +{ + const auto& tensor_shape = shape(tensor); + using TensorShapeType = remove_reference_t; + auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor(); + // Generate sequence with ones to mark that all dims will be padded + constexpr auto do_pads_seq = + generate_sequence_v2([](auto) { return Number<1>{}; }, Number{}); + // Create descriptor with padding + auto padded_desc = + tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq); + // Generate padded shape + const auto padded_shape = generate_tuple( + [&](auto i) { + const auto& dim = size(tensor_shape); + const auto& tile_length = size(tile_lengths); + return ck::math::integer_divide_ceil(dim, tile_length) * tile_length; + }, + Number{}); + // Create layout and tensor + const auto padded_layout = + Layout(padded_shape, padded_desc); + auto partition_tensor = + make_tensor(tensor.GetPointer(), padded_layout); + partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets()); + return partition_tensor; +} + } // namespace wrapper } // namespace ck diff --git a/include/ck/wrapper/utils/tensor_utils.hpp b/include/ck/wrapper/utils/tensor_utils.hpp index 7ec080760a..ee9e438a40 100644 --- a/include/ck/wrapper/utils/tensor_utils.hpp +++ b/include/ck/wrapper/utils/tensor_utils.hpp @@ -5,6 +5,7 @@ #include "ck/ck.hpp" +#include "ck/utility/data_type.hpp" #include "ck/utility/number.hpp" #include "ck/utility/tuple.hpp" #include "ck/utility/tuple_helper.hpp" @@ -19,9 +20,9 @@ namespace wrapper { * \brief Memory type, allowed members: * - Generic, * - Global, - * - LDS, - * - SGPR, - * - VGPR, + * - Lds, + * - Sgpr, + * - Vgpr, */ using MemoryTypeEnum = AddressSpaceEnum; @@ -52,12 +53,8 @@ struct Slice __host__ __device__ constexpr auto range(const T& dim) const { if constexpr(is_same_v || is_same_v || - is_same_v) + is_same_v, index_t>) { - if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_))) - { - throw std::runtime_error("Invalid range"); - } if(to_ < 0) { return dim - from_ + to_ + 1; @@ -70,9 +67,10 @@ struct Slice } else { - static_assert(dim >= to_ && from_ >= Number<0>{} && (to_ < 0 || to_ > from_), + static_assert(T{} >= ToType{} && FromType{} >= Number<0>{} && + (ToType{} < 0 || ToType{} > FromType{}), "Invalid range"); - if constexpr(to_ < 0) + if constexpr(ToType{} < 0) { return dim - from_ + to_ + Number<1>{}; } @@ -130,6 +128,23 @@ constexpr auto make_register_tensor(const Layout& return Tensor(layout); } +/** + * \brief Clear tensor. (Only for Vpgr/Sgpr) + * + * \param tensor Tensor to be cleared. + */ +template +__host__ __device__ void +clear(Tensor& tensor) +{ + static_assert( + !Tensor::IsDynamicBuffer); + return tensor.GetBuffer().Clear(); +} + /** * \brief Get Tensor Layout. 
* diff --git a/test/wrapper/CMakeLists.txt b/test/wrapper/CMakeLists.txt index 6c3e29ab87..cadc146795 100644 --- a/test/wrapper/CMakeLists.txt +++ b/test/wrapper/CMakeLists.txt @@ -6,3 +6,9 @@ add_gtest_executable(test_copy test_copy.cpp) target_link_libraries(test_copy PRIVATE utility) add_gtest_executable(test_partition test_partition.cpp) target_link_libraries(test_partition PRIVATE utility) +if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR + GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR + GPU_TARGETS MATCHES "gfx942") + add_gtest_executable(test_gemm test_gemm.cpp) + target_link_libraries(test_gemm PRIVATE utility) +endif() diff --git a/test/wrapper/test_gemm.cpp b/test/wrapper/test_gemm.cpp new file mode 100644 index 0000000000..b26cd5fed6 --- /dev/null +++ b/test/wrapper/test_gemm.cpp @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/library/utility/host_tensor.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/operations/gemm.hpp" + +template +void CheckResult(const std::vector& a_data, + const std::vector& b_data, + std::vector& c_m_n_device_result, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + + Tensor a_m_k(HostTensorDescriptor({M, K})); + Tensor b_k_n(HostTensorDescriptor({K, N}, {1, K})); + Tensor c_m_n_host_result(HostTensorDescriptor({M, N})); + + a_m_k.mData = a_data; + b_k_n.mData = b_data; + + auto ref_op = ReferenceGemmInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + auto ref_argument = ref_op.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData)); +} + +template +__global__ void DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayoutShape thread_layout) +{ + constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); + constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); + constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); + + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + + constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + 
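+    // Both A and B tiles are declared with K as the fastest-varying dimension
+    // ((KPerBlock, 1) strides), matching the (MPerBlock, KPerBlock) and
+    // (NPerBlock, KPerBlock) layouts that blockwise_gemm_xdl expects and
+    // letting the blockwise copies vectorize along K (vector_dim = 1 below).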
constexpr auto c_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{})); + + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_global_layout); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_global_layout); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_global_layout); + + auto a_padded_global_tensor = ck::wrapper::pad(a_global_tensor, shape(a_tile_layout)); + auto b_padded_global_tensor = ck::wrapper::pad(b_global_tensor, shape(b_tile_layout)); + auto c_padded_global_tensor = ck::wrapper::pad(c_global_tensor, shape(c_tile_layout)); + + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)]; + + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); + + const ck::index_t block_idx = static_cast(blockIdx.x); + using DimAccessOrder = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t vector_dim = 1; + + auto c_global_local_tile = ck::wrapper::make_local_tile( + c_padded_global_tensor, + tile_shape, + block_idx, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + ck::wrapper::clear(c_vgpr_reg); + + const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock); + ck::index_t i = 0; + do + { + const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock); + auto a_padded_global_tensor_k_slice = a_padded_global_tensor(ck::wrapper::slice(), k_slice); + auto b_padded_global_tensor_k_slice = b_padded_global_tensor(ck::wrapper::slice(), k_slice); + auto a_global_local_tile = ck::wrapper::make_local_tile( + a_padded_global_tensor_k_slice, + tile_shape, + block_idx, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{})); + auto b_global_local_tile = ck::wrapper::make_local_tile( + b_padded_global_tensor_k_slice, + tile_shape, + block_idx, + make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{})); + + ck::wrapper::blockwise_copy( + a_global_local_tile, a_lds_tensor, thread_layout); + ck::wrapper::blockwise_copy( + b_global_local_tile, b_lds_tensor, thread_layout); + ck::block_sync_lds(); + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + + ++i; + } while(i < num_loop); + + ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +} + +template +void PerformGemm(const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape& tile_shape, + const ThreadLayoutShape& thread_layout) +{ + // Global memory buffers + DeviceMem a_mem(M * K * sizeof(DataType)); + DeviceMem b_mem(K * N * sizeof(DataType)); + DeviceMem c_mem(M * N * sizeof(DataType)); + + std::vector a_data(M * K); + std::vector b_data(K * N); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_data); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_data); + + a_mem.ToDevice(a_data.data()); + b_mem.ToDevice(b_data.data()); + c_mem.SetZero(); + + const ck::index_t grid_size = + ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)) * + ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); + + const auto kernel = + DeviceGemm; + launch_and_time_kernel(StreamConfig{nullptr}, 
+ kernel, + dim3(grid_size), + dim3(ck::wrapper::size(thread_layout)), + 0, + a_mem.GetDeviceBuffer(), + b_mem.GetDeviceBuffer(), + c_mem.GetDeviceBuffer(), + M, + N, + K, + tile_shape, + thread_layout); + + std::vector c_data(M * N); + c_mem.FromDevice(c_data.data()); + + CheckResult(a_data, b_data, c_data, M, N, K); +} + +TEST(TestGemm, Float) +{ + using DataType = float; + const auto thread_layout = ck::make_tuple(ck::Number<16>{}, ck::Number<16>{}); + const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout); + // Irregular case + PerformGemm( + 129, 129, 67, tile_shape, thread_layout); +} + +TEST(TestGemm, Int8) +{ + using DataType = int8_t; + const auto thread_layout = ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}); + const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout); + // Irregular case + PerformGemm( + 129, 129, 67, tile_shape, thread_layout); +} + +TEST(TestGemm, Half) +{ + using DataType = ck::half_t; + const auto thread_layout = ck::make_tuple(ck::Number<32>{}, ck::Number<8>{}); + const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout); + // Irregular case + PerformGemm( + 129, 129, 67, tile_shape, thread_layout); +} + +TEST(TestGemm, Float_2x4_4x2_XdlPerWave) +{ + using DataType = float; + const auto thread_layout_4x2_xdl_per_wave = ck::make_tuple(ck::Number<16>{}, ck::Number<8>{}); + const auto thread_layout_2x4_xdl_per_wave = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}); + const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout_4x2_xdl_per_wave); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout_2x4_xdl_per_wave); +} diff --git a/test/wrapper/test_partition.cpp b/test/wrapper/test_partition.cpp index cacbfe9d88..8b6d220cd7 100644 --- a/test/wrapper/test_partition.cpp +++ b/test/wrapper/test_partition.cpp @@ -29,17 +29,24 @@ TEST(TestPartition, LocalPartition) const auto tensor = ck::wrapper::make_tensor(data.data(), layout); - const auto thread_steps = ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}); - const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}); + const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{}); + const auto thread_layout = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{}); + // 3d partition on 2d shape (calculate partition on 3d thread layout, and then skip first dim) + const auto thread_projection = + ck::make_tuple(ck::wrapper::slice(4), ck::Number<1>{}, ck::Number<1>{}); + constexpr ck::index_t projection_thread_length = ck::Number<4>{}; - for(ck::index_t thread_id = 0; thread_id < ck::wrapper::size(thread_layout); thread_id++) + for(ck::index_t thread_id = 0; + thread_id < ck::wrapper::size(thread_layout) / projection_thread_length; + thread_id++) { const auto packed_partition = - ck::wrapper::make_local_partition(tensor, thread_layout, thread_id); + ck::wrapper::make_local_partition(tensor, thread_layout, thread_id, thread_projection); const auto expected_partition_size = - ck::wrapper::size(tensor) / ck::wrapper::size(thread_layout); - const auto expected_partition_first_val = thread_id * ck::wrapper::size<0>(thread_steps); + ck::wrapper::size(tensor) / + 
(ck::wrapper::size(thread_layout) / projection_thread_length); + const auto expected_partition_first_val = thread_id * ck::wrapper::size<1>(thread_steps); const auto expected_partition_second_val = expected_partition_first_val + 1; EXPECT_EQ(ck::wrapper::size(packed_partition), expected_partition_size); EXPECT_EQ(packed_partition(0), expected_partition_first_val); @@ -58,8 +65,12 @@ TEST(TestPartition, LocalTile) const auto tensor = ck::wrapper::make_tensor(data.data(), layout); - - const auto block_shape = ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}); + // 4d tile partitioning on 3d shape (calculate tile on 4d tile layout, and then skip last dim) + const auto block_shape = + ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}, ck::Number<2>{}); + const auto block_projection = + ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(2)); + constexpr ck::index_t projection_block_dim = ck::Number<2>{}; const auto num_blocks = ck::make_tuple(ck::wrapper::size<0>(shape) / ck::wrapper::size<0>(block_shape), ck::wrapper::size<1>(shape) / ck::wrapper::size<1>(block_shape), @@ -69,9 +80,10 @@ TEST(TestPartition, LocalTile) for(auto block_idx : block_idxs) { - const auto packed_tile = ck::wrapper::make_local_tile(tensor, block_shape, block_idx); + const auto packed_tile = + ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_projection); - const auto expected_tile_size = ck::wrapper::size(block_shape); + const auto expected_tile_size = ck::wrapper::size(block_shape) / projection_block_dim; auto expected_tile_first_val = (block_idx % ck::wrapper::size<2>(num_blocks)) * ck::wrapper::size<2>(block_shape) * ck::wrapper::size<2>(strides);