mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
refactoring ConstantTensorDescriptor
[ROCm/composable_kernel commit: a0584426ff]
This commit is contained in:
@@ -65,8 +65,8 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0
|
||||
template <class Lengths, class Strides>
|
||||
struct ConstantTensorDescriptor
|
||||
{
|
||||
using Type = ConstantTensorDescriptor<Lengths, Strides>;
|
||||
static constexpr unsigned nDim = Lengths::nDim;
|
||||
using NDimConstant = Number<nDim>;
|
||||
|
||||
__host__ __device__ constexpr ConstantTensorDescriptor()
|
||||
{
|
||||
@@ -91,293 +91,70 @@ struct ConstantTensorDescriptor
|
||||
return Strides{}.Get(Number<I>{});
|
||||
}
|
||||
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct GetElementSize_f
|
||||
{
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr unsigned operator()(IDim idim) const
|
||||
{
|
||||
return Type{}.GetLength(idim);
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ constexpr unsigned GetElementSize() const
|
||||
{
|
||||
static_assert(nDim >= 2 && nDim <= 8, "nDim");
|
||||
|
||||
if(nDim == 2)
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct multiply
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
|
||||
{
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
return GetLength(I0) * GetLength(I1);
|
||||
}
|
||||
else if(nDim == 3)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2);
|
||||
}
|
||||
else if(nDim == 4)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3);
|
||||
}
|
||||
else if(nDim == 5)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4);
|
||||
}
|
||||
else if(nDim == 6)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) *
|
||||
GetLength(I5);
|
||||
}
|
||||
else if(nDim == 7)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) *
|
||||
GetLength(I5) * GetLength(I6);
|
||||
}
|
||||
else if(nDim == 8)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) *
|
||||
GetLength(I5) * GetLength(I6) * GetLength(I7);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
return static_const_reduce_n<nDim>{}(GetElementSize_f{}, multiply{});
|
||||
}
|
||||
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct GetElementSpace_f
|
||||
{
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr unsigned operator()(IDim idim) const
|
||||
{
|
||||
return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
|
||||
}
|
||||
};
|
||||
|
||||
template <class Align = Number<1>>
|
||||
__host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const
|
||||
{
|
||||
static_assert(nDim >= 2 && nDim <= 8, "nDim");
|
||||
|
||||
constexpr unsigned align_size = align.Get();
|
||||
|
||||
if(nDim == 2)
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct add
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
|
||||
{
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
align_size;
|
||||
}
|
||||
else if(nDim == 3)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + align_size;
|
||||
}
|
||||
else if(nDim == 4)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
align_size;
|
||||
}
|
||||
else if(nDim == 5)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
(GetLength(I4) - 1) * GetStride(I4) + align_size;
|
||||
}
|
||||
else if(nDim == 6)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
(GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
|
||||
align_size;
|
||||
}
|
||||
else if(nDim == 7)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
(GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
|
||||
(GetLength(I6) - 1) * GetStride(I6) + align_size;
|
||||
}
|
||||
else if(nDim == 8)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
(GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
|
||||
(GetLength(I6) - 1) * GetStride(I6) + (GetLength(I7) - 1) * GetStride(I7) +
|
||||
align_size;
|
||||
}
|
||||
return static_const_reduce_n<nDim>{}(GetElementSpace_f{}, add{}) + align.Get();
|
||||
}
|
||||
|
||||
// this is ugly, only for 2d
|
||||
__host__ __device__ unsigned Get1dIndex(unsigned i0, unsigned i1) const
|
||||
template <class... Is>
|
||||
__host__ __device__ unsigned Get1dIndex(Is... is) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
|
||||
|
||||
static_assert(nDim == 2, "nDim is not 2");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1);
|
||||
}
|
||||
const auto multi_id = Array<unsigned, nDim>(is...);
|
||||
|
||||
// this is ugly, only for 3d
|
||||
__host__ __device__ unsigned Get1dIndex(unsigned i0, unsigned i1, unsigned i2) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
unsigned id = 0;
|
||||
|
||||
static_assert(nDim == 3, "nDim is not 3");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2);
|
||||
}
|
||||
static_loop_n<nDim>{}([&](auto IDim) {
|
||||
constexpr unsigned idim = IDim.Get();
|
||||
id += multi_id[idim] * GetStride(IDim);
|
||||
});
|
||||
|
||||
// this is ugly, only for 4d
|
||||
__host__ __device__ unsigned
|
||||
Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
static_assert(nDim == 4, "nDim is not 4");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3);
|
||||
}
|
||||
|
||||
// this is ugly, only for 5d
|
||||
__host__ __device__ unsigned
|
||||
Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3, unsigned i4) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
static_assert(nDim == 5, "nDim is not 5");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
|
||||
i4 * GetStride(I4);
|
||||
}
|
||||
|
||||
// this is ugly, only for 6d
|
||||
__host__ __device__ unsigned
|
||||
Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3, unsigned i4, unsigned i5) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
static_assert(nDim == 6, "nDim is not 6");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
|
||||
i4 * GetStride(I4) + i5 * GetStride(I5);
|
||||
}
|
||||
|
||||
// this is ugly, only for 7d
|
||||
__host__ __device__ unsigned Get1dIndex(unsigned i0,
|
||||
unsigned i1,
|
||||
unsigned i2,
|
||||
unsigned i3,
|
||||
unsigned i4,
|
||||
unsigned i5,
|
||||
unsigned i6) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
static_assert(nDim == 7, "nDim is not 7");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
|
||||
i4 * GetStride(I4) + i5 * GetStride(I5) + i6 * GetStride(I6);
|
||||
}
|
||||
|
||||
// this is ugly, only for 8d
|
||||
__host__ __device__ unsigned Get1dIndex(unsigned i0,
|
||||
unsigned i1,
|
||||
unsigned i2,
|
||||
unsigned i3,
|
||||
unsigned i4,
|
||||
unsigned i5,
|
||||
unsigned i6,
|
||||
unsigned i7) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
static_assert(nDim == 8, "nDim is not 8");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
|
||||
i4 * GetStride(I4) + i5 * GetStride(I5) + i6 * GetStride(I6) + i7 * GetStride(I7);
|
||||
return id;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto Condense() const
|
||||
@@ -385,6 +162,12 @@ struct ConstantTensorDescriptor
|
||||
constexpr auto default_strides = calculate_default_strides(Lengths{});
|
||||
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
|
||||
}
|
||||
|
||||
template <unsigned IDim, unsigned NVector>
|
||||
__host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const
|
||||
{
|
||||
assert(false); // not implemented
|
||||
}
|
||||
};
|
||||
|
||||
template <class Lengths>
|
||||
|
||||
Reference in New Issue
Block a user