diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/composable_kernel/include/tensor_description/multi_index_transform.hpp new file mode 100644 index 0000000000..831d4c5833 --- /dev/null +++ b/composable_kernel/include/tensor_description/multi_index_transform.hpp @@ -0,0 +1,221 @@ +#ifndef CK_MULTI_INDEX_TRANSFORM_HPP +#define CK_MULTI_INDEX_TRANSFORM_HPP + +#include "common_header.hpp" + +namespace ck { + +template +using MultiIndex = Array; + +// LowLengths: Sequence<...> +template +struct PassThrough +{ + static constexpr index_t nDim = LowLengths::GetSize(); + + using LowerIndex = MultiIndex; + using UpperIndex = LowerIndex; + + __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetNumOfUpperDimension() + { + return GetNumOfLowerDimension(); + } + + __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; } + + __host__ __device__ static constexpr auto GetUpperLengths() { return GetLowerLengths(); } + + __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) { return idx_up; } + + __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff) + { + return idx_up_diff; + } + + __host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; } +}; + +// LowLengths: Sequence<...> +template +struct Pad +{ + static constexpr index_t nDim = LowLengths::GetSize(); + + using LowerIndex = MultiIndex; + using UpperIndex = LowerIndex; + + __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetNumOfUpperDimension() + { + return GetNumOfLowerDimension(); + } + + __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; } + + __host__ __device__ static constexpr auto GetUpperLengths() + { + return GetLowerLengths() + LeftPads + RightPads; + } + + __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) + { + return idx_up - LeftPads; + } + + __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff) + { + return idx_up_diff; + } + + __host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; } +}; + +// LowLengths: Sequence<...> +template +struct Merge +{ + static constexpr index_t nDimLow = LowLengths::GetSize(); + static constexpr index_t nDimUp = 1; + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex; + + __host__ __device__ static constexpr auto GetNumOfUpperDimension(){return Number{}}; + + __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; } + + __host__ __device__ static constexpr auto GetUpperLengths() + { + return Sequence{}, Number<1>{})>{}; + } + + __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) + { + LowerIndex idx_low; + + // not implemeneted + + return idx_low; + } + + // idx_low_diff depends on idx_low_old, so idx_low need to be up-to-date + __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff, + LowerIndex idx_low_old) + { + LowerIndex idx_low_diff; + + // not implemeneted + + return idx_low_diff; + } + + __host__ __device__ static constexpr bool IsIndexTransformLinear() { return false; } +}; + +// UpLengths: Sequence<...> +template +struct Unmerge +{ + static constexpr index_t nDimLow = 1; + static constexpr index_t nDimUp = UpLengths::GetSize(); + + __host__ __device__ constexpr Unmerge() + { + static_assert(LowLength == accumulate_on_sequence( + UpLengths{}, math::multiplies{}, Number<1>{}), + "wrong! UpLengths need to be "); + } + + __host__ __device__ static constexpr auto GetNumOfUpperDimension(){return Number{}}; + + __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence{}; } + + __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; } + + __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) + { + constexpr auto scans = typename sequence_reverse_inclusive_scan, + 1>::type{}; + + LowerIndex idx_low{0}; + + static_for<0, nDim, 1>{}([&](auto idim) { idx_low[0] += idx_up[idim] * scans[idim]; }); + + return idx_low; + } + + __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff) + { + return GetLowerIndex(idx_up_diff); + } + + __host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; } +}; + +// UpLengths: Sequence<...> +// Coefficients: Sequence<...> +// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp] +template +struct Embed +{ + static constexpr index_t nDimLow = 1; + static constexpr index_t nDimUp = UpLengths::GetSize(); + + static constexpr auto mCoefficients = Coefficients{}; + + __host__ __device__ constexpr Embed() + { + static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1, + "wrong! # of dimensions not consistent"); + + constexpr index_t low_id_max = + Coefficents.Back() + accumulate_on_sequence(UpLengths{} * Coefficients::PopBack(), + math::plus{}, + Number<0>{}); + + static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range"); + } + + __host__ __device__ static constexpr auto GetNumOfUpperDimension(){return Number{}}; + + __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence{}; } + + __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; } + + __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) + { + LowerIndex idx_low{mCoefficients[nDimUp]}; + + static_for<0, nDimUp, 1>{}( + [&](auto idim) { idx_low[0] += idx_up[idim] * mCoefficients[idim]; }); + + return idx_low; + } + + __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff) + { + LowerIndex idx_low_diff{0}; + + static_for<0, nDimUp, 1>{}( + [&](auto idim) { idx_low_diff[0] += idx_up_diff[idim] * mCoefficients[idim]; }); + + return idx_low_diff; + } + + __host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/tensor_coordinate_v2.hpp b/composable_kernel/include/tensor_description/tensor_coordinate_v2.hpp new file mode 100644 index 0000000000..7c8f3a390e --- /dev/null +++ b/composable_kernel/include/tensor_description/tensor_coordinate_v2.hpp @@ -0,0 +1,160 @@ +#ifndef CK_TENSOR_COORDINATE_V2_HPP +#define CK_TENSOR_COORDINATE_V2_HPP + +#include "common_header.hpp" +#include "dimension.hpp" +#include "dimension_transform.hpp" +#include "tensor_descriptor.hpp" + +namespace ck { + +template +struct NativeTensorCoordinate +{ + using type = NativeTensorCoordinate; + using tensor_desc_type = NativeTensorDesc; + using Index = tensor_desc_type::Index; + + static constexpr index_t nDim = Index::GetSize(); + + __host__ __device__ constexpr NativeTensorCoordinate(Index idx) + : mOffset{GetTensorDesriptor().GetOffset(idx)} + { + } + + template + __host__ __device__ constexpr NativeTensorCoordinate(Xs... xs) + : NativeTensorCoordinate(Index{xs...}) + { + } + + template + __host__ __device__ constexpr NativeTensorCoordinate(Sequence) + : NativeTensorCoordinate(Index{Xs...}) + { + } + + __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; } + + __host__ __device__ constexpr index_t GetOffset() const { return mOffset; } + + __host__ __device__ type operator+=(Index idx_diff) + { + mOffset += tensor_desc_type::GetOffsetDiff(idx_diff); + + return *this; + } + + __host__ __device__ type operator-=(Index idx_diff) + { + mOffset -= tensor_desc_type::GetOffsetFromMultiIndex(idx_diff); + + return *this; + } + + __host__ __device__ constexpr type operator+(Index idx_diff) const + { + type coord = *this; + coord += idx_diff; + return coord; + } + + __host__ __device__ constexpr type operator-(Index idx_diff) const + { + type coord = *this; + coord -= idx_diff; + return coord; + } + + private: + index_t mOffset; +}; + +template +struct TransformedTensorCoordinate +{ + using type = TransformedTensorCoordinate; + using tensor_desc_type = TransformedTensorDesc; + using Index = tensor_desc_type::UpperIndex; + + using lower_coordinate_type = + TensorCoordiante_v2::type; + + static constexpr index_t nDim = Index::GetSize(); + + __host__ __device__ constexpr TransformedTensorCoordinate(Index idx) + : mIndex{idx}, mCoordLow{GetTensorDescriptor().GetLowerIndex(idx)} + { + } + + template + __host__ __device__ constexpr TransformedTensorCoordinate(Xs... xs) + : TransformedTensorCoordinate(Index{xs...}) + { + } + + template + __host__ __device__ constexpr TransformedTensorCoordinate(Sequence) + : TransformedTensorCoordinate(Index{Xs...}) + { + } + + __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; } + + __host__ __device__ constexpr index_t GetOffset() const { return mCoordLow.GetOffset(); } + + __host__ __device__ constexpr Index GetIndex() const { return mIndex; } + + __host__ __device__ type operator+=(Index idx_up_diff) + { + // For transformation of multi-index difference, not all transformation functions need to + // know the old lower-index or the old upper-index. We pass both of them to the + // transformation function. The transformation function itself decides to use them or not. + mCoordLow += + tensor_desc_type::GetLowerIndexDiff(idx_up_diff, mIndexUp, mCoordLow.GetIndex()); + + // mIndexUp is updated here, but some (or all) of its entries may never be used + mIndexUp += idx_up_diff; + + return *this; + } + + __host__ __device__ constexpr type operator+(Index idx_up_diff) const + { + type coord = *this; + coord += idx_diff; + return coord; + } + + private: + // mIndexUp may be calculated and update, however, the value of some (or all) of its entries may + // never be used. Compiler should be able to remove these entries as well as its calculation + // as dead code. + // TODO: make sure compiler indeed remove these dead code + Index mIndexUp; + lower_coordinate_type mCoordLow; +}; + +template +struct TensorCoordinate_v2 +{ + private: + template + __host__ __device__ static constexpr auto + MakeDummyTensorCoordinate(NativeTensorDescriptor) + { + return NativeTensorCoordinate>(); + } + + template + __host__ __device__ static constexpr auto + MakeDummyTensorCoordinate(TransformedTensorDescriptor) + { + return TransformedTensorCoordinate>(); + } + + public: + using type = decltype(MakeDummyTensorCoordinate(TensorDesc{})); +}; +} +#endif diff --git a/composable_kernel/include/tensor_description/tensor_visit.hpp b/composable_kernel/include/tensor_description/tensor_visit.hpp new file mode 100644 index 0000000000..9aa9597e76 --- /dev/null +++ b/composable_kernel/include/tensor_description/tensor_visit.hpp @@ -0,0 +1,123 @@ +#ifndef CK_TENSOR_VISIT_HPP +#define CK_TENSOR_VISIT_HPP + +#include "common_header.hpp" +#include "dimension.hpp" +#include "dimension_transform.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_coordinate_v2.hpp" + +namespace ck { + +template +struct TensorVisit +{ + using Index = typename TensorDescriptor::Index; + using Coordinate = typename TensorCoordinate_v2::type; + + __host__ __device__ static void Run_v1(Index idx_begin) + { + const auto coord_begin = Coordinate(idx_begin); + + ford{}( + [&](auto idx_diff) { index_t offset = (coord_begin + idx_diff).GetOffset(); }); + } + + __host__ __device__ static void Run_v2(Index idx_begin) + { + const auto coord_begin = Coordinate(idx_begin); + + ford{}([&](auto idx_diff) { + index_t offset_diff = coord_begin.GetOffsetDiff(idx_diff); + index_t offset = coord_begin.GetOffset() + offset_diff; + }); + } + + __host__ __device__ static void Run_v3(Index idx_begin) + { + const auto coord_begin = Coordinate(idx_begin); + + constexpr auto linear_dimensions = TensorDescriptor::GetLinearDimensions(); + constexpr auto nonlinear_dimensions = TensorDescriptor::GetNonLinearDimensions(); + + constexpr auto lengths = TensorDescriptor::GetLengths(); + + constexpr auto linear_dimension_lengths_hack = + lambda_HackLengths{}(lengths, linear_dimensions); + constexpr auto nonlinear_dimension_lengths_hack = + lambda_HackLengths{}(lengths, nonlinear_dimensions); + + ford{}([&](auto idx_diff_nonlinear_hack) { + // run-time component + index_t offset_diff_nonlinear = coord_begin.GetOffsetDiff(idx_diff_nonlinear_hack); + + ford{}([&](auto idx_diff_linear_hack) { + // compile-time component + index_t offset_diff_linear = coord_begin.GetOffsetDiff(idx_diff_linear_hack); + + index_t offset = + coord_begin.GetOffset() + offset_diff_nonlinear + offset_diff_linear; + }); + }); + } + + __host__ __device__ static void Run_v4(Index idx_begin) + { + const auto coord_begin = Coordinate(idx_begin); + + constexpr auto linear_dimensions = TensorDescriptor::GetLinearDimensions(); + + constexpr auto nonlinear_independent_dimension_groups = + TensorDescriptor::GetNonLinearIndependentDimensionGroups(); + + constexpr auto lengths = TensorDescriptor::GetLengths(); + + constexpr auto linear_dimension_lengths = lambda_HackLengths{}(lengths, linear_dimensions); + + // run-time component + index_t offset_diff_nonlinear = 0; + + template + struct f_recursion + { + template + __host__ __device__ void Run(Number) + { + constexpr auto nonlinear_independent_dimensions_igroup = + nonlinear_independent_dimension_groups.Get(igroup); + constexpr auto nonlinear_independent_lengths_igroup = + lambda_HackLengths{}(lengths, nonlinear_independent_dimensions_igroup); + + ford{}( + [&](auto idx_diff_nonlinear_igroup_hack) { + // run-time component + offset_diff_nonlinear += + coord_begin.GetOffsetDiff(idx_diff_nonlinear_igroup_hack); + + Run(Number{}); + }); + }; + + // inner-most work + template <> + __host__ __device__ void Run(Number) + { + ford{}([&](auto idx_diff_linear_hack) { + // compile-time component + index_t offset_diff_linear = coord_begin.GetOffsetDiff(idx_diff_linear_hack); + + index_t offset = + coord_begin.GetOffset() + offset_diff_nonlinear + offset_diff_linear; + }); + } + }; + + // run-time component + index_t offset_diff_nonlinear = 0; + + f_recursion{}.Run(); + } +}; + +} // namespace ck +#endif