diff --git a/composable_kernel/include/tensor_description/dimension.hpp b/composable_kernel/include/tensor_description/dimension.hpp new file mode 100644 index 0000000000..cd897323a3 --- /dev/null +++ b/composable_kernel/include/tensor_description/dimension.hpp @@ -0,0 +1,28 @@ +#ifndef CK_DIMENSION_HPP +#define CK_DIMENSION_HPP + +#include "common_header.hpp" + +namespace ck { + +template +struct Dimension +{ + __host__ __device__ static constexpr auto GetLength() { return Number{}; } +}; + +template +struct NativeDimension : Dimension +{ + __host__ __device__ static constexpr auto GetStride() { return Number{}; } + + __host__ __device__ static constexpr index_t GetOffset(index_t id) { return id * Stride; } + + __host__ __device__ static constexpr index_t GetOffsetDiff(index_t id_diff) + { + return id_diff * Stride; + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/dimension_transform.hpp b/composable_kernel/include/tensor_description/dimension_transform.hpp new file mode 100644 index 0000000000..f8b513ef19 --- /dev/null +++ b/composable_kernel/include/tensor_description/dimension_transform.hpp @@ -0,0 +1,217 @@ +#ifndef CK_DIMENSION_TRANSFORM_HPP +#define CK_DIMENSION_TRANSFORM_HPP + +#include "common_header.hpp" + +namespace ck { + +template +using MultiIndex = Array; + +// LowLengths: Sequence<...> +template +struct PassThrough +{ + static constexpr index_t nDim = LowLengths::GetSize(); + + using LowerId = MultiIndex; + using UpperId = LowerId; + + __host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetUpperNumOfDimension() + { + return GetLowerNumOfDimension(); + } + + __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; } + + __host__ __device__ static constexpr auto GetUpperLengths() { return GetLowerLengths(); } + + __host__ __device__ static constexpr auto GetLowerId(UpperId id_up) { return id_up; } + + __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff) + { + return id_up_diff; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } +}; + +// LowLengths: Sequence<...> +template +struct Pad +{ + static constexpr index_t nDim = LowLengths::GetSize(); + + using LowerId = MultiIndex; + using UpperId = LowerId; + + __host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetUpperNumOfDimension() + { + return GetLowerNumOfDimension(); + } + + __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; } + + __host__ __device__ static constexpr auto GetUpperLengths() + { + return GetLowerLengths() + LeftPads + RightPads; + } + + __host__ __device__ static constexpr auto GetLowerId(UpperId id_up) { return id_up - LeftPads; } + + __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff) + { + return id_up_diff; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } +}; + +// LowLengths: Sequence<...> +template +struct Merge +{ + static constexpr index_t nDimLow = LowLengths::GetSize(); + static constexpr index_t nDimUp = 1; + + using LowerId = MultiIndex; + using UpperId = MultiIndex; + + __host__ __device__ static constexpr auto GetUpperNumOfDimension(){return Number{}}; + + __host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; } + + __host__ __device__ static constexpr auto GetUpperLengths() + { + return Sequence{}, Number<1>{})>{}; + } + + __host__ __device__ static constexpr auto GetLowerId(UpperId id_up) + { + LowerId id_low; + + // not implemeneted + + return id_low; + } + + // id_low_diff depends on id_low_old, so id_low need to be up-to-date + __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff, LowerId id_low_old) + { + LowerId id_low_diff; + + // not implemeneted + + return id_low_diff; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } +}; + +// UpLengths: Sequence<...> +template +struct Unmerge +{ + static constexpr index_t nDimLow = 1; + static constexpr index_t nDimUp = UpLengths::GetSize(); + + __host__ __device__ constexpr Unmerge() + { + static_assert(LowLength == accumulate_on_sequence( + UpLengths{}, math::multiplies{}, Number<1>{}), + "wrong! UpLengths need to be "); + } + + __host__ __device__ static constexpr auto GetUpperNumOfDimension(){return Number{}}; + + __host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence{}; } + + __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; } + + __host__ __device__ static constexpr auto GetLowerId(UpperId id_up) + { + constexpr auto scans = typename sequence_reverse_inclusive_scan, + 1>::type{}; + + LowerId id_low{0}; + + static_for<0, nDim, 1>{}([&](auto idim) { id_low[0] += id_up[idim] * scans[idim]; }); + + return id_low; + } + + __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff) + { + return GetLowerId(id_up_diff); + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } +}; + +// UpLengths: Sequence<...> +// Coefficients: Sequence<...> +// id_low = coefficients[0, ...nDimUp-1] * id_up[0, ...nDimUp-1] + coefficients[nDimUp] +template +struct Embed +{ + static constexpr index_t nDimLow = 1; + static constexpr index_t nDimUp = UpLengths::GetSize(); + + static constexpr auto mCoefficients = Coefficients{}; + + __host__ __device__ constexpr Embed() + { + static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1, + "wrong! # of dimensions not consistent"); + + constexpr index_t low_id_max = + Coefficents.Back() + accumulate_on_sequence(UpLengths{} * Coefficients::PopBack(), + math::plus{}, + Number<0>{}); + + static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range"); + } + + __host__ __device__ static constexpr auto GetUpperNumOfDimension(){return Number{}}; + + __host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence{}; } + + __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; } + + __host__ __device__ static constexpr auto GetLowerId(UpperId id_up) + { + LowerId id_low{mCoefficients[nDimUp]}; + + static_for<0, nDimUp, 1>{}( + [&](auto idim) { id_low[0] += id_up[idim] * mCoefficients[idim]; }); + + return id_low; + } + + __host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff) + { + LowerId id_low_diff{0}; + + static_for<0, nDimUp, 1>{}( + [&](auto idim) { id_low_diff[0] += id_up_diff[idim] * mCoefficients[idim]; }); + + return id_low_diff; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp new file mode 100644 index 0000000000..62f23b2bbd --- /dev/null +++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp @@ -0,0 +1,104 @@ +#ifndef CK_TENSOR_DESCRIPTOR_HPP +#define CK_TENSOR_DESCRIPTOR_HPP + +#include "common_header.hpp" +#include "dimension.hpp" + +namespace ck { + +template +struct NativeTensorDescriptor +{ + using type = NativeTensorDescriptor; + static constexpr index_t nDim = Lengths::GetSize(); + + using Id = MultiIndex; + + __host__ __device__ static constexpr auto GetNumOfDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; } + + __host__ __device__ static constexpr auto GetStrides() { return Strides{}; } + + __host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; } + + __host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; } + + __host__ __device__ static constexpr index_t GetOffset(Id id) + { + // not implemented + } +}; + +// LowerTensorDescriptor +// Transforms: std::tuple +// LowerIds: std::tuple> +// UpperIds: std::tuple> +template +struct TransformedTensorDescriptor +{ + using type = TransformedTensorDescriptor; + static constexpr index_t nDimUp = xxxx; + static constexpr index_t nDimLow = xxx; + + static constexpr index_t nTransform = Transforms::GetSize(); + + using UpperId = MultiIndex; + using LowerId = MultiIndex; + + __host__ __device__ static constexpr TransformedTensorDescriptor() + { + static_assert(nTransform == Transforms::GetSize() && + nTransform == LowDimensionMasks::GetSize() && + nTransform == UpDimensionMasks::GetSize(), + "wrong! # of transformations not the same"); + + // TODO: sanity check: LowDimensionMasks should include all low-dimensions, + // UpDimensionMasks should include all up-dimensions + + // TODO: sanity check: while a up-dimension could be associated with multille + // transformation, + // a low-dimension should be associated with only one transformation + } + + __host__ __device__ static constexpr auto GetNumOfDimension() + { + // not implemented + } + + __host__ __device__ static constexpr auto GetLengths() + { + // not implemented + } + + __host__ __device__ static constexpr auto GetLowerTensorDescriptor() + { + return LowTensorDescriptor{}; + } + + __host__ __device__ static constexpr index_t GetLowerId(UpperId id_up) + { + // not implemented + } + + __host__ __device__ static constexpr index_t GetOffset(UpperId id_up) + { + return GetLowerTensorDescriptor().GetOffset(GetLowerId(id_up)); + } + + __host__ __device__ static constexpr auto AreUpperId2OffsetLinear(); + { + // not implemented + } + + __host__ __device__ static constexpr auto GetIndependentDimensionGroups() + { + // not implemented + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index b70b61a7a0..a1ec782c9a 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -4,6 +4,7 @@ #include "config.hpp" #include "utility.hpp" #include "integral_constant.hpp" +#include "tuple.hpp" #include "math.hpp" #include "vector_type.hpp" #include "Sequence.hpp" diff --git a/composable_kernel/include/utility/tuple.hpp b/composable_kernel/include/utility/tuple.hpp new file mode 100644 index 0000000000..1d91f59ba6 --- /dev/null +++ b/composable_kernel/include/utility/tuple.hpp @@ -0,0 +1,68 @@ +#ifndef CK_TUPLE_HPP +#define CK_TUPLE_HPP + +#include "integral_constant.hpp" + +namespace ck { + +template +struct tuple : public std::tuple +{ + using type = tuple; + + __host__ __device__ static constexpr index_t GetSize() { return std::tuple_size(tuple{}); } + + template + __host__ __device__ constexpr auto Get(Number) const + { + return std::get(*this); + } + + template + __host__ __device__ constexpr auto operator[](Number) const + { + return Get(Number{}) : + } +}; + +// merge tuple +template +__host__ __device__ constexpr auto merge_tuple(Tuples&&... xs) +{ + return std::tuple_cat(xs...); +}; + +// generate sequence +template +struct tuple_gen_impl +{ + static constexpr index_t NRemainLeft = NRemain / 2; + static constexpr index_t NRemainRight = NRemain - NRemainLeft; + static constexpr index_t IMiddle = IBegin + NRemainLeft; + + using type = + typename tuple_merge::type, + typename tuple_gen_impl::type>::type; +}; + +template +struct tuple_gen_impl +{ + static constexpr auto x = F{}(Number{}); + using type = tuple; +}; + +template +struct sequence_gen_impl +{ + using type = Sequence<>; +}; + +template +struct sequence_gen +{ + using type = typename sequence_gen_impl<0, NSize, F>::type; +}; + +} // namespace ck +#endif