mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 02:27:57 +00:00
@@ -2,35 +2,35 @@
|
||||
#include "common.hip.hpp"
|
||||
|
||||
// this is ugly, only for 2d
|
||||
template <unsigned L0, unsigned L1>
|
||||
template <index_t L0, index_t L1>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
|
||||
{
|
||||
return Sequence<L1, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
|
||||
template <index_t L0, index_t L1, index_t L2, index_t L3>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
|
||||
{
|
||||
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 6d
|
||||
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned L4, unsigned L5>
|
||||
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
|
||||
{
|
||||
return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 8d
|
||||
template <unsigned L0,
|
||||
unsigned L1,
|
||||
unsigned L2,
|
||||
unsigned L3,
|
||||
unsigned L4,
|
||||
unsigned L5,
|
||||
unsigned L6,
|
||||
unsigned L7>
|
||||
template <index_t L0,
|
||||
index_t L1,
|
||||
index_t L2,
|
||||
index_t L3,
|
||||
index_t L4,
|
||||
index_t L5,
|
||||
index_t L6,
|
||||
index_t L7>
|
||||
__host__ __device__ constexpr auto
|
||||
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
|
||||
{
|
||||
@@ -45,48 +45,48 @@ __host__ __device__ constexpr auto
|
||||
}
|
||||
|
||||
// this is ugly, only for 2d
|
||||
template <unsigned L0, unsigned L1, unsigned Align>
|
||||
template <index_t L0, index_t L1, index_t Align>
|
||||
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>,
|
||||
Number<Align>)
|
||||
{
|
||||
constexpr unsigned L1_align = Align * ((L1 + Align - 1) / Align);
|
||||
constexpr index_t L1_align = Align * ((L1 + Align - 1) / Align);
|
||||
return Sequence<L1_align, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned Align>
|
||||
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t Align>
|
||||
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>,
|
||||
Number<Align>)
|
||||
{
|
||||
constexpr unsigned L3_align = Align * ((L3 + Align - 1) / Align);
|
||||
constexpr index_t L3_align = Align * ((L3 + Align - 1) / Align);
|
||||
return Sequence<L1 * L2 * L3_align, L2 * L3_align, L3_align, 1>{};
|
||||
}
|
||||
|
||||
template <class Lengths, class Strides>
|
||||
struct ConstantTensorDescriptor
|
||||
{
|
||||
using Type = ConstantTensorDescriptor<Lengths, Strides>;
|
||||
static constexpr unsigned nDim = Lengths::nDim;
|
||||
using Type = ConstantTensorDescriptor<Lengths, Strides>;
|
||||
static constexpr index_t nDim = Lengths::nDim;
|
||||
|
||||
__host__ __device__ constexpr ConstantTensorDescriptor()
|
||||
{
|
||||
static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr unsigned GetDimension() const { return nDim; }
|
||||
__host__ __device__ constexpr index_t GetDimension() const { return nDim; }
|
||||
|
||||
__host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }
|
||||
|
||||
__host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }
|
||||
|
||||
template <unsigned I>
|
||||
__host__ __device__ constexpr unsigned GetLength(Number<I>) const
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr index_t GetLength(Number<I>) const
|
||||
{
|
||||
return Lengths{}.Get(Number<I>{});
|
||||
}
|
||||
|
||||
template <unsigned I>
|
||||
__host__ __device__ constexpr unsigned GetStride(Number<I>) const
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr index_t GetStride(Number<I>) const
|
||||
{
|
||||
return Strides{}.Get(Number<I>{});
|
||||
}
|
||||
@@ -95,18 +95,18 @@ struct ConstantTensorDescriptor
|
||||
struct GetElementSize_f
|
||||
{
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr unsigned operator()(IDim idim) const
|
||||
__host__ __device__ constexpr index_t operator()(IDim idim) const
|
||||
{
|
||||
return Type{}.GetLength(idim);
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ constexpr unsigned GetElementSize() const
|
||||
__host__ __device__ constexpr index_t GetElementSize() const
|
||||
{
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct multiply
|
||||
{
|
||||
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
|
||||
__host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
|
||||
{
|
||||
return a * b;
|
||||
}
|
||||
@@ -119,19 +119,19 @@ struct ConstantTensorDescriptor
|
||||
struct GetElementSpace_f
|
||||
{
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr unsigned operator()(IDim idim) const
|
||||
__host__ __device__ constexpr index_t operator()(IDim idim) const
|
||||
{
|
||||
return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
|
||||
}
|
||||
};
|
||||
|
||||
template <class Align = Number<1>>
|
||||
__host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const
|
||||
__host__ __device__ constexpr index_t GetElementSpace(Align align = Align{}) const
|
||||
{
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct add
|
||||
{
|
||||
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
|
||||
__host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
|
||||
{
|
||||
return a + b;
|
||||
}
|
||||
@@ -141,17 +141,21 @@ struct ConstantTensorDescriptor
|
||||
}
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ unsigned Get1dIndex(Is... is) const
|
||||
__host__ __device__ index_t Get1dIndex(Is... is) const
|
||||
{
|
||||
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
|
||||
|
||||
const auto multi_id = Array<unsigned, nDim>(is...);
|
||||
const auto multi_id = Array<index_t, nDim>(is...);
|
||||
|
||||
unsigned id = 0;
|
||||
index_t id = 0;
|
||||
|
||||
static_loop_n<nDim>{}([&](auto IDim) {
|
||||
constexpr unsigned idim = IDim.Get();
|
||||
constexpr index_t idim = IDim.Get();
|
||||
#if DEVICE_BACKEND_HIP
|
||||
id += __mul24(multi_id[idim], GetStride(IDim));
|
||||
#else
|
||||
id += multi_id[idim] * GetStride(IDim);
|
||||
#endif
|
||||
});
|
||||
|
||||
return id;
|
||||
@@ -163,7 +167,7 @@ struct ConstantTensorDescriptor
|
||||
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
|
||||
}
|
||||
|
||||
template <unsigned IDim, unsigned NVector>
|
||||
template <index_t IDim, index_t NVector>
|
||||
__host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const
|
||||
{
|
||||
assert(false); // not implemented
|
||||
@@ -183,7 +187,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
|
||||
return ConstantTensorDescriptor<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <class Lengths, unsigned Align>
|
||||
template <class Lengths, index_t Align>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
|
||||
{
|
||||
using Strides = decltype(calculate_default_strides_aligned(Lengths{}, Number<Align>{}));
|
||||
@@ -193,8 +197,8 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths
|
||||
template <class TDesc>
|
||||
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
{
|
||||
constexpr auto desc = TDesc{};
|
||||
constexpr unsigned ndim = desc.GetDimension();
|
||||
constexpr auto desc = TDesc{};
|
||||
constexpr index_t ndim = desc.GetDimension();
|
||||
|
||||
static_assert(ndim >= 2 && ndim <= 8, "wrong!");
|
||||
|
||||
|
||||
Reference in New Issue
Block a user