mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
refactor ConstantTensorDescriptor and functional
This commit is contained in:
@@ -115,46 +115,27 @@ struct ConstantTensorDescriptor
|
||||
static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr index_t GetDimension() const { return nDim; }
|
||||
__host__ __device__ static constexpr index_t GetDimension() { return nDim; }
|
||||
|
||||
__host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }
|
||||
__host__ __device__ static constexpr Lengths GetLengths() { return Lengths{}; }
|
||||
|
||||
__host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }
|
||||
__host__ __device__ static constexpr Strides GetStrides() { return Strides{}; }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr index_t GetLength(Number<I>) const
|
||||
__host__ __device__ static constexpr index_t GetLength(Number<I>)
|
||||
{
|
||||
return Lengths{}.Get(Number<I>{});
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr index_t GetStride(Number<I>) const
|
||||
__host__ __device__ static constexpr index_t GetStride(Number<I>)
|
||||
{
|
||||
return Strides{}.Get(Number<I>{});
|
||||
}
|
||||
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct GetElementSize_f
|
||||
__host__ __device__ static constexpr index_t GetElementSize()
|
||||
{
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr index_t operator()(IDim idim) const
|
||||
{
|
||||
return Type{}.GetLength(idim);
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ constexpr index_t GetElementSize() const
|
||||
{
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct multiply
|
||||
{
|
||||
__host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
|
||||
{
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
return static_const_reduce_n<nDim>{}(GetElementSize_f{}, multiply{});
|
||||
return accumulate_on_sequence(Lengths{}, mod_conv::multiplies<index_t>{}, Number<1>{});
|
||||
}
|
||||
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
@@ -168,25 +149,16 @@ struct ConstantTensorDescriptor
|
||||
};
|
||||
|
||||
template <class Align = Number<1>>
|
||||
__host__ __device__ constexpr index_t GetElementSpace(Align align = Align{}) const
|
||||
__host__ __device__ static constexpr index_t GetElementSpace(Align align = Align{})
|
||||
{
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct add
|
||||
{
|
||||
__host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
|
||||
{
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
index_t element_space_unaligned =
|
||||
static_const_reduce_n<nDim>{}(GetElementSpace_f{}, add{}) + 1;
|
||||
static_const_reduce_n<nDim>{}(GetElementSpace_f{}, mod_conv::plus<index_t>{}) + 1;
|
||||
|
||||
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
|
||||
}
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ index_t Get1dIndex(Is... is) const
|
||||
__host__ __device__ static index_t Get1dIndex(Is... is)
|
||||
{
|
||||
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
|
||||
|
||||
@@ -194,7 +166,7 @@ struct ConstantTensorDescriptor
|
||||
|
||||
index_t id = 0;
|
||||
|
||||
static_loop_n<nDim>{}([&](auto IDim) {
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr index_t idim = IDim.Get();
|
||||
#if DEVICE_BACKEND_HIP
|
||||
id += __mul24(multi_id[idim], GetStride(IDim));
|
||||
@@ -206,17 +178,26 @@ struct ConstantTensorDescriptor
|
||||
return id;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto Condense() const
|
||||
// Converts a linear offset `id` into an nDim-component multi-index using this
// descriptor's strides. Every dimension except the last takes the quotient of
// the remaining offset by its stride and leaves the remainder for the
// dimensions after it; the last dimension takes what is left, divided by its
// own stride.
// NOTE(review): this peeling order presumably expects strides that do not
// increase from dimension 0 to dimension nDim - 1 -- confirm against callers.
__host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)
{
    Array<index_t, nDim> coords;

    static_for<0, nDim - 1, 1>{}([&](auto IDim) {
        constexpr index_t d = IDim.Get();

        const index_t quotient = id / GetStride(IDim);
        coords[d]              = quotient;

        // keep only the part of the offset not accounted for by dimension d
        id -= quotient * GetStride(IDim);
    });

    // the last dimension is handled outside the loop and does not update id
    constexpr index_t last = nDim - 1;
    coords[last]           = id / GetStride(Number<last>{});

    return coords;
}
|
||||
|
||||
// Returns a descriptor with the same Lengths as this one, whose strides are
// whatever calculate_default_strides(Lengths{}) produces.
__host__ __device__ static constexpr auto Condense()
{
    constexpr auto condensed_strides = calculate_default_strides(Lengths{});

    // decltype is taken of the constexpr local (not of the call expression) so
    // the resulting descriptor type is exactly the one used before
    using CondensedStrides = decltype(condensed_strides);

    return ConstantTensorDescriptor<Lengths, CondensedStrides>{};
}
|
||||
|
||||
template <index_t IDim, index_t NVector>
|
||||
__host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const
|
||||
{
|
||||
assert(false); // not implemented
|
||||
}
|
||||
};
|
||||
|
||||
template <class Lengths>
|
||||
|
||||
Reference in New Issue
Block a user