// SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include #include #include #include #include #include #include "ck/utility/data_type.hpp" #include "ck/utility/span.hpp" #include "ck/utility/type_convert.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/ranges.hpp" template std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) { bool first = true; for(auto&& v : range) { if(first) first = false; else os << delim; os << v; } return os; } template std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) { bool first = true; for(auto&& v : range) { if(first) first = false; else os << delim; using RangeType = ck::remove_cvref_t; if constexpr(std::is_same_v || std::is_same_v || std::is_same_v) { os << ck::type_convert(v); } else if constexpr(std::is_same_v) { const auto packed_floats = ck::type_convert(v); const ck::vector_type vector_of_floats{packed_floats}; os << vector_of_floats.template AsType()[ck::Number<0>{}] << delim << vector_of_floats.template AsType()[ck::Number<1>{}]; } else { os << static_cast(v); } } return os; } template auto call_f_unpack_args_impl(F f, T args, std::index_sequence) { return f(std::get(args)...); } template auto call_f_unpack_args(F f, T args) { constexpr std::size_t N = std::tuple_size{}; return call_f_unpack_args_impl(f, args, std::make_index_sequence{}); } template auto construct_f_unpack_args_impl(T args, std::index_sequence) { return F(std::get(args)...); } template auto construct_f_unpack_args(F, T args) { constexpr std::size_t N = std::tuple_size{}; return construct_f_unpack_args_impl(args, std::make_index_sequence{}); } struct HostTensorDescriptor { HostTensorDescriptor() = default; void CalculateStrides(); template >> HostTensorDescriptor(const std::initializer_list& lens) : mLens(lens.begin(), lens.end()) { this->CalculateStrides(); } HostTensorDescriptor(const std::initializer_list& lens) : mLens(lens.begin(), lens.end()) { this->CalculateStrides(); } template , std::size_t> || std::is_convertible_v, ck::long_index_t>>> HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end()) { this->CalculateStrides(); } template && std::is_convertible_v>> HostTensorDescriptor(const std::initializer_list& lens, const std::initializer_list& strides) : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) { } HostTensorDescriptor(const std::initializer_list& lens, const std::initializer_list& strides) : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) { } template , std::size_t> && std::is_convertible_v, std::size_t>) || (std::is_convertible_v, ck::long_index_t> && std::is_convertible_v, ck::long_index_t>)>> HostTensorDescriptor(const Lengths& lens, const Strides& strides) : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) { } std::size_t GetNumOfDimension() const; std::size_t GetElementSize() const; std::size_t GetElementSpaceSize() const; const std::vector& GetLengths() const; const std::vector& GetStrides() const; template std::size_t GetOffsetFromMultiIndex(Is... is) const { assert(sizeof...(Is) == this->GetNumOfDimension()); std::initializer_list iss{static_cast(is)...}; return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } std::size_t GetOffsetFromMultiIndex(std::vector iss) const { return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc); private: std::vector mLens; std::vector mStrides; }; template HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a, const New2Old& new2old) { std::vector new_lengths(a.GetNumOfDimension()); std::vector new_strides(a.GetNumOfDimension()); for(std::size_t i = 0; i < a.GetNumOfDimension(); i++) { new_lengths[i] = a.GetLengths()[new2old[i]]; new_strides[i] = a.GetStrides()[new2old[i]]; } return HostTensorDescriptor(new_lengths, new_strides); } struct joinable_thread : std::thread { template joinable_thread(Xs&&... xs) : std::thread(std::forward(xs)...) { } joinable_thread(joinable_thread&&) = default; joinable_thread& operator=(joinable_thread&&) = default; ~joinable_thread() { if(this->joinable()) this->join(); } }; template struct ParallelTensorFunctor { F mF; static constexpr std::size_t NDIM = sizeof...(Xs); std::array mLens; std::array mStrides; std::size_t mN1d; ParallelTensorFunctor(F f, Xs... xs) : mF(f), mLens({static_cast(xs)...}) { mStrides.back() = 1; std::partial_sum(mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies()); mN1d = mStrides[0] * mLens[0]; } std::array GetNdIndices(std::size_t i) const { std::array indices; for(std::size_t idim = 0; idim < NDIM; ++idim) { indices[idim] = i / mStrides[idim]; i -= indices[idim] * mStrides[idim]; } return indices; } void operator()(std::size_t num_thread = 1) const { std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread; std::vector threads(num_thread); for(std::size_t it = 0; it < num_thread; ++it) { std::size_t iw_begin = it * work_per_thread; std::size_t iw_end = std::min((it + 1) * work_per_thread, mN1d); auto f = [=] { for(std::size_t iw = iw_begin; iw < iw_end; ++iw) { call_f_unpack_args(mF, GetNdIndices(iw)); } }; threads[it] = joinable_thread(f); } } }; template auto make_ParallelTensorFunctor(F f, Xs... xs) { return ParallelTensorFunctor(f, xs...); } template struct Tensor { using Descriptor = HostTensorDescriptor; using Data = std::vector; template Tensor(std::initializer_list lens) : mDesc(lens), mData(GetElementSpaceSize()) { } template Tensor(std::initializer_list lens, std::initializer_list strides) : mDesc(lens, strides), mData(GetElementSpaceSize()) { } template Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize()) { } template Tensor(const Lengths& lens, const Strides& strides) : mDesc(lens, strides), mData(GetElementSpaceSize()) { } Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {} template Tensor CopyAsType() const { Tensor ret(mDesc); ck::ranges::transform( mData, ret.mData.begin(), [](auto value) { return ck::type_convert(value); }); return ret; } Tensor() = delete; Tensor(const Tensor&) = default; Tensor(Tensor&&) = default; ~Tensor() = default; Tensor& operator=(const Tensor&) = default; Tensor& operator=(Tensor&&) = default; template explicit Tensor(const Tensor& other) : Tensor(other.template CopyAsType()) { } decltype(auto) GetLengths() const { return mDesc.GetLengths(); } decltype(auto) GetStrides() const { return mDesc.GetStrides(); } std::size_t GetNumOfDimension() const { return mDesc.GetNumOfDimension(); } std::size_t GetElementSize() const { return mDesc.GetElementSize(); } std::size_t GetElementSpaceSize() const { if constexpr(ck::is_same_v, ck::pk_i4_t>) { return (mDesc.GetElementSpaceSize() + 1) / 2; } else { return mDesc.GetElementSpaceSize(); } } std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); } void SetZero() { ck::ranges::fill(mData, T{0}); } template void ForEach_impl(F&& f, std::vector& idx, size_t rank) { if(rank == mDesc.GetNumOfDimension()) { f(*this, idx); return; } // else for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++) { idx[rank] = i; ForEach_impl(std::forward(f), idx, rank + 1); } } template void ForEach(F&& f) { std::vector idx(mDesc.GetNumOfDimension(), 0); ForEach_impl(std::forward(f), idx, size_t(0)); } template void ForEach_impl(const F&& f, std::vector& idx, size_t rank) const { if(rank == mDesc.GetNumOfDimension()) { f(*this, idx); return; } // else for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++) { idx[rank] = i; ForEach_impl(std::forward(f), idx, rank + 1); } } template void ForEach(const F&& f) const { std::vector idx(mDesc.GetNumOfDimension(), 0); ForEach_impl(std::forward(f), idx, size_t(0)); } template void GenerateTensorValue(G g, std::size_t num_thread = 1) { switch(mDesc.GetNumOfDimension()) { case 1: { auto f = [&](auto i) { (*this)(i) = g(i); }; make_ParallelTensorFunctor(f, mDesc.GetLengths()[0])(num_thread); break; } case 2: { auto f = [&](auto i0, auto i1) { (*this)(i0, i1) = g(i0, i1); }; make_ParallelTensorFunctor(f, mDesc.GetLengths()[0], mDesc.GetLengths()[1])(num_thread); break; } case 3: { auto f = [&](auto i0, auto i1, auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); }; make_ParallelTensorFunctor( f, mDesc.GetLengths()[0], mDesc.GetLengths()[1], mDesc.GetLengths()[2])(num_thread); break; } case 4: { auto f = [&](auto i0, auto i1, auto i2, auto i3) { (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3); }; make_ParallelTensorFunctor(f, mDesc.GetLengths()[0], mDesc.GetLengths()[1], mDesc.GetLengths()[2], mDesc.GetLengths()[3])(num_thread); break; } case 5: { auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4) { (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4); }; make_ParallelTensorFunctor(f, mDesc.GetLengths()[0], mDesc.GetLengths()[1], mDesc.GetLengths()[2], mDesc.GetLengths()[3], mDesc.GetLengths()[4])(num_thread); break; } case 6: { auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) { (*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5); }; make_ParallelTensorFunctor(f, mDesc.GetLengths()[0], mDesc.GetLengths()[1], mDesc.GetLengths()[2], mDesc.GetLengths()[3], mDesc.GetLengths()[4], mDesc.GetLengths()[5])(num_thread); break; } case 12: { auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5, auto i6, auto i7, auto i8, auto i9, auto i10, auto i11) { (*this)(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11) = g(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11); }; make_ParallelTensorFunctor(f, mDesc.GetLengths()[0], mDesc.GetLengths()[1], mDesc.GetLengths()[2], mDesc.GetLengths()[3], mDesc.GetLengths()[4], mDesc.GetLengths()[5], mDesc.GetLengths()[6], mDesc.GetLengths()[7], mDesc.GetLengths()[8], mDesc.GetLengths()[9], mDesc.GetLengths()[10], mDesc.GetLengths()[11])(num_thread); break; } default: throw std::runtime_error("unspported dimension"); } } template std::size_t GetOffsetFromMultiIndex(Is... is) const { if constexpr(ck::is_same_v, ck::pk_i4_t>) { return mDesc.GetOffsetFromMultiIndex(is...) / 2; } else { return mDesc.GetOffsetFromMultiIndex(is...); } } template T& operator()(Is... is) { if constexpr(ck::is_same_v, ck::pk_i4_t>) { return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; } else { return mData[mDesc.GetOffsetFromMultiIndex(is...)]; } } template const T& operator()(Is... is) const { if constexpr(ck::is_same_v, ck::pk_i4_t>) { return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; } else { return mData[mDesc.GetOffsetFromMultiIndex(is...)]; } } T& operator()(std::vector idx) { if constexpr(ck::is_same_v, ck::pk_i4_t>) { return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; } else { return mData[mDesc.GetOffsetFromMultiIndex(idx)]; } } const T& operator()(std::vector idx) const { if constexpr(ck::is_same_v, ck::pk_i4_t>) { return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; } else { return mData[mDesc.GetOffsetFromMultiIndex(idx)]; } } typename Data::iterator begin() { return mData.begin(); } typename Data::iterator end() { return mData.end(); } typename Data::pointer data() { return mData.data(); } typename Data::const_iterator begin() const { return mData.begin(); } typename Data::const_iterator end() const { return mData.end(); } typename Data::const_pointer data() const { return mData.data(); } typename Data::size_type size() const { return mData.size(); } template auto AsSpan() const { constexpr std::size_t FromSize = sizeof(T); constexpr std::size_t ToSize = sizeof(U); using Element = std::add_const_t>; return ck::span{reinterpret_cast(data()), size() * FromSize / ToSize}; } template auto AsSpan() { constexpr std::size_t FromSize = sizeof(T); constexpr std::size_t ToSize = sizeof(U); using Element = std::remove_reference_t; return ck::span{reinterpret_cast(data()), size() * FromSize / ToSize}; } Descriptor mDesc; Data mData; };