mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
adding dimension transformation
This commit is contained in:
@@ -0,0 +1,221 @@
|
||||
#ifndef CK_MULTI_INDEX_TRANSFORM_HPP
|
||||
#define CK_MULTI_INDEX_TRANSFORM_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t N>
|
||||
using MultiIndex = Array<index_t, N>;
|
||||
|
||||
// Identity dimension transform: the upper index space is exactly the lower one.
// LowLengths: Sequence<...>
template <class LowLengths>
struct PassThrough
{
    static constexpr index_t nDim = LowLengths::GetSize();

    using LowerIndex = MultiIndex<nDim>;
    using UpperIndex = LowerIndex;

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension()
    {
        // Same dimensionality on both sides of the transform.
        return GetNumOfLowerDimension();
    }

    __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return GetLowerLengths(); }

    // An upper index maps to itself.
    __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) { return idx_up; }

    // A difference of upper indices likewise maps to itself.
    __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
    {
        return idx_up_diff;
    }

    __host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; }
};
|
||||
|
||||
// Pads every dimension with LeftPads in front and RightPads behind.
// LowLengths: Sequence<...>
template <class LowLengths, class LeftPads, class RightPads>
struct Pad
{
    static constexpr index_t nDim = LowLengths::GetSize();

    using LowerIndex = MultiIndex<nDim>;
    using UpperIndex = LowerIndex;

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension()
    {
        // Padding does not change the number of dimensions.
        return GetNumOfLowerDimension();
    }

    __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }

    // Each upper length is the lower length extended by both pads.
    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        return GetLowerLengths() + LeftPads + RightPads;
    }

    // Shift the upper index back by the left pads to land in the lower tensor.
    __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
    {
        return idx_up - LeftPads;
    }

    // The constant left-pad shift cancels in an index difference.
    __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
    {
        return idx_up_diff;
    }

    __host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; }
};
|
||||
|
||||
// Merges multiple lower dimensions into a single upper dimension whose length
// is the product of the lower lengths.
// LowLengths: Sequence<...>
template <class LowLengths>
struct Merge
{
    static constexpr index_t nDimLow = LowLengths::GetSize();
    static constexpr index_t nDimUp  = 1;

    using LowerIndex = MultiIndex<nDimLow>;
    using UpperIndex = MultiIndex<nDimUp>;

    // Fixed: original read "(){return Number<nDimUp>{}};" — the return statement
    // lacked its terminating semicolon, which does not compile.
    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

    __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }

    // The single upper length is the product of all lower lengths.
    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        return Sequence<accumulate_on_sequence(
            GetLowerLengths(), math::multiplies<index_t>{}, Number<1>{})>{};
    }

    __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
    {
        LowerIndex idx_low;

        // TODO: not implemented

        return idx_low;
    }

    // idx_low_diff depends on idx_low_old, so idx_low needs to be up-to-date
    __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff,
                                                                LowerIndex idx_low_old)
    {
        LowerIndex idx_low_diff;

        // TODO: not implemented

        return idx_low_diff;
    }

    // Non-linear: the same upper-index step produces different lower-index steps
    // depending on the current position (carry/borrow across merged dimensions).
    __host__ __device__ static constexpr bool IsIndexTransformLinear() { return false; }
};
|
||||
|
||||
// Splits one lower dimension of length LowLength into multiple upper
// dimensions; the product of UpLengths must equal LowLength.
// UpLengths: Sequence<...>
template <index_t LowLength, class UpLengths>
struct Unmerge
{
    static constexpr index_t nDimLow = 1;
    static constexpr index_t nDimUp  = UpLengths::GetSize();

    // Fixed: these aliases were missing although UpperIndex/LowerIndex are used
    // below, which does not compile.
    using LowerIndex = MultiIndex<nDimLow>;
    using UpperIndex = MultiIndex<nDimUp>;

    __host__ __device__ constexpr Unmerge()
    {
        static_assert(LowLength == accumulate_on_sequence(
                                       UpLengths{}, math::multiplies<index_t>{}, Number<1>{}),
                      "wrong! product of UpLengths must equal LowLength");
    }

    // Fixed: original lacked the semicolon after the return statement.
    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

    __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }

    // idx_low = sum_i idx_up[i] * stride[i], where the strides come from the
    // reverse inclusive product-scan of UpLengths (row-major packing).
    __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
    {
        constexpr auto scans = typename sequence_reverse_inclusive_scan<UpLengths,
                                                                        math::multiplies<index_t>,
                                                                        1>::type{};

        LowerIndex idx_low{0};

        // Fixed: loop bound was the undefined name nDim; the upper index has
        // nDimUp entries.
        static_for<0, nDimUp, 1>{}([&](auto idim) { idx_low[0] += idx_up[idim] * scans[idim]; });

        return idx_low;
    }

    // Linear transform: a difference maps the same way as an index.
    __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
    {
        return GetLowerIndex(idx_up_diff);
    }

    __host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; }
};
|
||||
|
||||
// Affine embedding of multiple upper dimensions into one lower dimension.
// UpLengths: Sequence<...>
// Coefficients: Sequence<...> with nDimUp + 1 entries
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template <index_t LowLength, class UpLengths, class Coefficients>
struct Embed
{
    static constexpr index_t nDimLow = 1;
    static constexpr index_t nDimUp  = UpLengths::GetSize();

    // Fixed: these aliases were missing although UpperIndex/LowerIndex are used
    // below, which does not compile.
    using LowerIndex = MultiIndex<nDimLow>;
    using UpperIndex = MultiIndex<nDimUp>;

    static constexpr auto mCoefficients = Coefficients{};

    __host__ __device__ constexpr Embed()
    {
        static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
                      "wrong! # of dimensions not consistent");

        // Upper bound on the reachable lower index: the constant term plus the
        // weighted sum of the upper lengths. Fixed: "Coefficents.Back()" was a
        // typo referring to an undeclared name, and PopBack() was called as if
        // it were static; both are instance members of the sequence objects.
        // NOTE(review): this uses lengths rather than (lengths - 1), so it is a
        // conservative over-estimate — confirm that is intended.
        constexpr index_t low_id_max =
            Coefficients{}.Back() + accumulate_on_sequence(UpLengths{} * Coefficients{}.PopBack(),
                                                           math::plus<index_t>{},
                                                           Number<0>{});

        static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range");
    }

    // Fixed: original lacked the semicolon after the return statement.
    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

    __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }

    // idx_low = coefficients[nDimUp] + sum_i idx_up[i] * coefficients[i]
    __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
    {
        LowerIndex idx_low{mCoefficients[nDimUp]};

        static_for<0, nDimUp, 1>{}(
            [&](auto idim) { idx_low[0] += idx_up[idim] * mCoefficients[idim]; });

        return idx_low;
    }

    // Affine map: the constant term cancels in a difference.
    __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
    {
        LowerIndex idx_low_diff{0};

        static_for<0, nDimUp, 1>{}(
            [&](auto idim) { idx_low_diff[0] += idx_up_diff[idim] * mCoefficients[idim]; });

        return idx_low_diff;
    }

    __host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; }
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,160 @@
|
||||
#ifndef CK_TENSOR_COORDINATE_V2_HPP
|
||||
#define CK_TENSOR_COORDINATE_V2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dimension.hpp"
|
||||
#include "dimension_transform.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Coordinate into a native (directly strided) tensor: only the flat memory
// offset needs to be tracked.
template <class NativeTensorDesc>
struct NativeTensorCoordinate
{
    using type             = NativeTensorCoordinate;
    using tensor_desc_type = NativeTensorDesc;
    // Fixed: dependent name requires the typename keyword.
    using Index = typename tensor_desc_type::Index;

    static constexpr index_t nDim = Index::GetSize();

    // Fixed: "GetTensorDesriptor" was a typo for GetTensorDescriptor.
    __host__ __device__ constexpr NativeTensorCoordinate(Index idx)
        : mOffset{GetTensorDescriptor().GetOffset(idx)}
    {
    }

    // Convenience constructor from a list of scalar indices.
    template <class... Xs>
    __host__ __device__ constexpr NativeTensorCoordinate(Xs... xs)
        : NativeTensorCoordinate(Index{xs...})
    {
    }

    // Convenience constructor from a compile-time Sequence.
    template <index_t... Xs>
    __host__ __device__ constexpr NativeTensorCoordinate(Sequence<Xs...>)
        : NativeTensorCoordinate(Index{Xs...})
    {
    }

    __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }

    __host__ __device__ constexpr index_t GetOffset() const { return mOffset; }

    __host__ __device__ type operator+=(Index idx_diff)
    {
        mOffset += tensor_desc_type::GetOffsetDiff(idx_diff);

        return *this;
    }

    __host__ __device__ type operator-=(Index idx_diff)
    {
        // Fixed for consistency with operator+=: this called
        // GetOffsetFromMultiIndex, which treats idx_diff as an absolute index
        // rather than a difference. NOTE(review): for a native descriptor the
        // two may coincide — confirm against the descriptor implementation.
        mOffset -= tensor_desc_type::GetOffsetDiff(idx_diff);

        return *this;
    }

    __host__ __device__ constexpr type operator+(Index idx_diff) const
    {
        type coord = *this;
        coord += idx_diff;
        return coord;
    }

    __host__ __device__ constexpr type operator-(Index idx_diff) const
    {
        type coord = *this;
        coord -= idx_diff;
        return coord;
    }

    private:
    // Flat offset of this coordinate in the underlying memory.
    index_t mOffset;
};
|
||||
|
||||
// Forward declaration: TensorCoordinate_v2 is defined later in this header but
// is needed by the alias below. (The original referred to the undeclared typo
// "TensorCoordiante_v2".)
template <class TensorDesc>
struct TensorCoordinate_v2;

// Coordinate into a transformed tensor: tracks the upper (transformed) index
// and the coordinate of the lower tensor it maps onto.
template <class TransformedTensorDesc>
struct TransformedTensorCoordinate
{
    using type             = TransformedTensorCoordinate;
    using tensor_desc_type = TransformedTensorDesc;
    // Fixed: dependent name requires the typename keyword.
    using Index = typename tensor_desc_type::UpperIndex;

    // Fixed: the original alias called the member function GetTensorDescriptor()
    // before its declaration, which is not allowed in an alias-declaration;
    // construct the descriptor type directly instead.
    using lower_coordinate_type = typename TensorCoordinate_v2<decltype(
        tensor_desc_type{}.GetLowerTensorDescriptor())>::type;

    static constexpr index_t nDim = Index::GetSize();

    // Fixed: the initializer named the nonexistent member "mIndex"; the data
    // member is mIndexUp.
    __host__ __device__ constexpr TransformedTensorCoordinate(Index idx)
        : mIndexUp{idx}, mCoordLow{GetTensorDescriptor().GetLowerIndex(idx)}
    {
    }

    // Convenience constructor from a list of scalar indices.
    template <class... Xs>
    __host__ __device__ constexpr TransformedTensorCoordinate(Xs... xs)
        : TransformedTensorCoordinate(Index{xs...})
    {
    }

    // Convenience constructor from a compile-time Sequence.
    template <index_t... Xs>
    __host__ __device__ constexpr TransformedTensorCoordinate(Sequence<Xs...>)
        : TransformedTensorCoordinate(Index{Xs...})
    {
    }

    __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }

    // Offset is delegated to the lower coordinate.
    __host__ __device__ constexpr index_t GetOffset() const { return mCoordLow.GetOffset(); }

    // Fixed: returned the nonexistent member "mIndex".
    __host__ __device__ constexpr Index GetIndex() const { return mIndexUp; }

    __host__ __device__ type operator+=(Index idx_up_diff)
    {
        // For transformation of multi-index difference, not all transformation
        // functions need to know the old lower-index or the old upper-index. We
        // pass both of them to the transformation function, which itself decides
        // whether to use them.
        mCoordLow +=
            tensor_desc_type::GetLowerIndexDiff(idx_up_diff, mIndexUp, mCoordLow.GetIndex());

        // mIndexUp is updated here, but some (or all) of its entries may never be used
        mIndexUp += idx_up_diff;

        return *this;
    }

    __host__ __device__ constexpr type operator+(Index idx_up_diff) const
    {
        type coord = *this;
        // Fixed: original added the undeclared name "idx_diff".
        coord += idx_up_diff;
        return coord;
    }

    private:
    // mIndexUp may be calculated and updated; however, some (or all) of its
    // entries may never be used. The compiler should be able to remove those
    // entries, as well as their calculation, as dead code.
    // TODO: make sure the compiler indeed removes this dead code
    Index mIndexUp;
    lower_coordinate_type mCoordLow;
};
|
||||
|
||||
// Metafunction mapping a tensor-descriptor type to its matching coordinate
// type, selected by overload resolution on the descriptor's template.
template <class TensorDesc>
struct TensorCoordinate_v2
{
    private:
    // Chosen when TensorDesc is a NativeTensorDescriptor.
    template <class... Ts>
    __host__ __device__ static constexpr auto
    MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
    {
        return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>();
    }

    // Chosen when TensorDesc is a TransformedTensorDescriptor.
    template <class... Ts>
    __host__ __device__ static constexpr auto
    MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
    {
        return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>();
    }

    public:
    // The coordinate type corresponding to TensorDesc.
    using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
};
|
||||
}
|
||||
#endif
|
||||
123
composable_kernel/include/tensor_description/tensor_visit.hpp
Normal file
123
composable_kernel/include/tensor_description/tensor_visit.hpp
Normal file
@@ -0,0 +1,123 @@
|
||||
#ifndef CK_TENSOR_VISIT_HPP
|
||||
#define CK_TENSOR_VISIT_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dimension.hpp"
|
||||
#include "dimension_transform.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_coordinate_v2.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Sketches of four strategies for visiting every element of a tensor and
// computing its memory offset. The bodies only compute offsets (the results
// are intentionally unused) — this file explores code-generation patterns.
template <class TensorDescriptor>
struct TensorVisit
{
    using Index      = typename TensorDescriptor::Index;
    using Coordinate = typename TensorCoordinate_v2<TensorDescriptor>::type;

    // v1: re-derive a full coordinate for every element by adding the
    // multi-index difference to the begin-coordinate.
    __host__ __device__ static void Run_v1(Index idx_begin)
    {
        const auto coord_begin = Coordinate(idx_begin);

        ford<TensorDescriptor::GetLengths()>{}(
            [&](auto idx_diff) { index_t offset = (coord_begin + idx_diff).GetOffset(); });
    }

    // v2: compute only the offset difference and add it to the begin-offset.
    __host__ __device__ static void Run_v2(Index idx_begin)
    {
        const auto coord_begin = Coordinate(idx_begin);

        ford<TensorDescriptor::GetLengths()>{}([&](auto idx_diff) {
            index_t offset_diff = coord_begin.GetOffsetDiff(idx_diff);
            index_t offset      = coord_begin.GetOffset() + offset_diff;
        });
    }

    // v3: split dimensions into linear and non-linear ones, so the linear part
    // of the offset can be resolved as a compile-time component inside the
    // inner loop.
    __host__ __device__ static void Run_v3(Index idx_begin)
    {
        const auto coord_begin = Coordinate(idx_begin);

        constexpr auto linear_dimensions    = TensorDescriptor::GetLinearDimensions();
        constexpr auto nonlinear_dimensions = TensorDescriptor::GetNonLinearDimensions();

        constexpr auto lengths = TensorDescriptor::GetLengths();

        constexpr auto linear_dimension_lengths_hack =
            lambda_HackLengths{}(lengths, linear_dimensions);
        constexpr auto nonlinear_dimension_lengths_hack =
            lambda_HackLengths{}(lengths, nonlinear_dimensions);

        ford<nonlinear_dimension_lengths_hack>{}([&](auto idx_diff_nonlinear_hack) {
            // run-time component
            index_t offset_diff_nonlinear = coord_begin.GetOffsetDiff(idx_diff_nonlinear_hack);

            ford<linear_dimension_lengths_hack>{}([&](auto idx_diff_linear_hack) {
                // compile-time component
                index_t offset_diff_linear = coord_begin.GetOffsetDiff(idx_diff_linear_hack);

                index_t offset =
                    coord_begin.GetOffset() + offset_diff_nonlinear + offset_diff_linear;
            });
        });
    }

    private:
    // Recursion over the non-linear independent dimension groups for Run_v4:
    // level IGroup loops over group IGroup, accumulating the non-linear offset
    // difference, and recurses into the next group. Fixed: the original
    // declared this template (and an explicit specialization) inside the body
    // of Run_v4, which is not valid C++; it also mixed up `igroup`/`IGroup`,
    // declared offset_diff_nonlinear twice, and called Run() with no argument.
    template <index_t IGroup, index_t NGroup>
    struct RunV4GroupRecursion
    {
        template <class Coord>
        __host__ __device__ static void Run(const Coord& coord_begin,
                                            index_t offset_diff_nonlinear)
        {
            constexpr auto nonlinear_independent_dimension_groups =
                TensorDescriptor::GetNonLinearIndependentDimensionGroups();

            constexpr auto lengths = TensorDescriptor::GetLengths();

            constexpr auto nonlinear_independent_dimensions_igroup =
                nonlinear_independent_dimension_groups.Get(Number<IGroup>{});
            constexpr auto nonlinear_independent_lengths_igroup =
                lambda_HackLengths{}(lengths, nonlinear_independent_dimensions_igroup);

            ford<nonlinear_independent_lengths_igroup>{}([&](auto idx_diff_nonlinear_hack) {
                // run-time component: fold this group's contribution in, then
                // descend into the next group.
                RunV4GroupRecursion<IGroup + 1, NGroup>::Run(
                    coord_begin,
                    offset_diff_nonlinear +
                        coord_begin.GetOffsetDiff(idx_diff_nonlinear_hack));
            });
        }
    };

    // Innermost level: all groups consumed; loop over the linear dimensions.
    template <index_t NGroup>
    struct RunV4GroupRecursion<NGroup, NGroup>
    {
        template <class Coord>
        __host__ __device__ static void Run(const Coord& coord_begin,
                                            index_t offset_diff_nonlinear)
        {
            constexpr auto lengths           = TensorDescriptor::GetLengths();
            constexpr auto linear_dimensions = TensorDescriptor::GetLinearDimensions();

            constexpr auto linear_dimension_lengths =
                lambda_HackLengths{}(lengths, linear_dimensions);

            ford<linear_dimension_lengths>{}([&](auto idx_diff_linear_hack) {
                // compile-time component
                index_t offset_diff_linear = coord_begin.GetOffsetDiff(idx_diff_linear_hack);

                index_t offset =
                    coord_begin.GetOffset() + offset_diff_nonlinear + offset_diff_linear;
            });
        }
    };

    public:
    // v4: like v3, but the non-linear dimensions are further partitioned into
    // independent groups, each handled by one level of the recursion above.
    __host__ __device__ static void Run_v4(Index idx_begin)
    {
        const auto coord_begin = Coordinate(idx_begin);

        constexpr index_t ngroup =
            TensorDescriptor::GetNonLinearIndependentDimensionGroups().GetSize();

        // Start with a zero non-linear offset accumulator (run-time component).
        RunV4GroupRecursion<0, ngroup>::Run(coord_begin, 0);
    }
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
Reference in New Issue
Block a user