mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
refactor
This commit is contained in:
@@ -57,17 +57,38 @@ struct ConstantTensorDescriptor
|
||||
return Strides{}.Get(Number<I>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool AreStridesNonAscending()
|
||||
struct lambda_AreDimensionsContinuous
|
||||
{
|
||||
bool flag = true;
|
||||
bool& is_continuous;
|
||||
|
||||
static_for<0, nDim - 1, 1>{}([&](auto IDim) {
|
||||
constexpr auto IDim_p1 = Number<IDim.Get() + 1>{};
|
||||
__host__ __device__ constexpr lambda_AreDimensionsContinuous(bool& is_continuous_)
|
||||
: is_continuous(is_continuous_)
|
||||
{
|
||||
}
|
||||
|
||||
flag = flag && (GetLength(IDim) >= GetLength(IDim_p1));
|
||||
});
|
||||
template <class X>
|
||||
__host__ __device__ constexpr void operator()(X IDim) const
|
||||
{
|
||||
constexpr auto IDim_p1 = IDim + Number<1>{};
|
||||
|
||||
return flag;
|
||||
is_continuous =
|
||||
is_continuous && (GetStride(IDim) >= GetStride(IDim_p1) &&
|
||||
GetStride(IDim) == GetStride(IDim_p1) * GetLength(IDim_p1));
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ static constexpr bool AreDimensionsContinuous()
|
||||
{
|
||||
bool is_continuous = true;
|
||||
|
||||
static_for<0, nDim - 1, 1>{}(lambda_AreDimensionsContinuous(is_continuous));
|
||||
|
||||
return is_continuous;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsPackedTensor()
|
||||
{
|
||||
return AreDimensionsContinuous() && GetStride(Number<nDim - 1>{}) == 1;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
@@ -92,40 +113,24 @@ struct ConstantTensorDescriptor
|
||||
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
|
||||
}
|
||||
|
||||
#if 0
|
||||
// emulate constexpr lambda
|
||||
template <index_t NSize>
|
||||
__host__ __device__ static constexpr index_t
|
||||
GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
|
||||
struct lambda_GetOffsetFromMultiIndex
|
||||
{
|
||||
static_assert(NSize == nDim, "wrong! Dimension not consistent");
|
||||
Array<index_t, NSize>& multi_id;
|
||||
index_t& offset;
|
||||
|
||||
index_t offset = 0;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr index_t idim = IDim.Get();
|
||||
offset += multi_id[idim] * GetStride(IDim);
|
||||
});
|
||||
|
||||
return offset;
|
||||
}
|
||||
#else
|
||||
template <index_t NSize>
|
||||
struct GetOffsetFromMultiIndex_impl
|
||||
{
|
||||
Array<index_t, NSize>& multi_id_ref;
|
||||
index_t& offset_ref;
|
||||
|
||||
__host__ __device__ constexpr GetOffsetFromMultiIndex_impl(Array<index_t, NSize>& multi_id,
|
||||
index_t& offset)
|
||||
: multi_id_ref(multi_id), offset_ref(offset)
|
||||
__host__
|
||||
__device__ constexpr lambda_GetOffsetFromMultiIndex(Array<index_t, NSize>& multi_id_,
|
||||
index_t& offset_)
|
||||
: multi_id(multi_id_), offset(offset_)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr bool operator()(Number<IDim>) const
|
||||
template <class X>
|
||||
__host__ __device__ constexpr void operator()(X IDim) const
|
||||
{
|
||||
offset_ref += multi_id_ref.Get(Number<IDim>{}) * Type::GetStride(Number<IDim>{});
|
||||
return true;
|
||||
offset += multi_id.Get(IDim) * Type::GetStride(IDim);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -137,11 +142,10 @@ struct ConstantTensorDescriptor
|
||||
|
||||
index_t offset = 0;
|
||||
|
||||
static_for<0, nDim, 1>{}(GetOffsetFromMultiIndex_impl<NSize>(multi_id, offset));
|
||||
static_for<0, nDim, 1>{}(lambda_GetOffsetFromMultiIndex<NSize>(multi_id, offset));
|
||||
|
||||
return offset;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
|
||||
@@ -160,47 +164,26 @@ struct ConstantTensorDescriptor
|
||||
multi_id * GetStrides(), mod_conv::plus<index_t>{}, Number<0>{});
|
||||
}
|
||||
|
||||
#if 0
|
||||
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
|
||||
// emulate constexpr lambda
|
||||
template <class PackedStrides>
|
||||
struct lambda_GetMultiIndexFrom1dIndex
|
||||
{
|
||||
Array<index_t, nDim> multi_id;
|
||||
index_t& id;
|
||||
Array<index_t, nDim>& multi_id;
|
||||
|
||||
constexpr auto dummy_strides = calculate_tensor_strides_packed(GetLengths());
|
||||
|
||||
// calculate index in each of the dimensions in the order of their dimension
|
||||
static_for<0, nDim - 1, 1>{}([&](auto IDim) {
|
||||
constexpr index_t idim = IDim.Get();
|
||||
constexpr index_t stride = dummy_strides.Get(Number<idim>{});
|
||||
multi_id[idim] = id / stride;
|
||||
id -= multi_id[idim] * stride;
|
||||
});
|
||||
|
||||
multi_id[nDim - 1] = id / dummy_strides.Get(Number<nDim - 1>{});
|
||||
|
||||
return multi_id;
|
||||
}
|
||||
#else
|
||||
struct GetMultiIndexFrom1dIndex_impl
|
||||
{
|
||||
using DummyStrides = decltype(calculate_tensor_strides_packed(GetLengths()));
|
||||
|
||||
index_t& id_ref;
|
||||
Array<index_t, nDim>& multi_id_ref;
|
||||
|
||||
__host__ __device__ constexpr GetMultiIndexFrom1dIndex_impl(index_t& id,
|
||||
Array<index_t, nDim>& multi_id)
|
||||
: id_ref(id), multi_id_ref(multi_id)
|
||||
__host__
|
||||
__device__ constexpr lambda_GetMultiIndexFrom1dIndex(index_t& id_,
|
||||
Array<index_t, nDim>& multi_id_)
|
||||
: id(id_), multi_id(multi_id_)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr bool operator()(Number<IDim>) const
|
||||
template <class X>
|
||||
__host__ __device__ constexpr void operator()(X IDim) const
|
||||
{
|
||||
constexpr index_t stride = DummyStrides::Get(Number<IDim>{});
|
||||
multi_id_ref.Set(Number<IDim>{}, id_ref / stride);
|
||||
id_ref -= multi_id_ref.Get(Number<IDim>{}) * stride;
|
||||
|
||||
return true;
|
||||
constexpr index_t stride = PackedStrides::Get(IDim);
|
||||
multi_id.Set(IDim, id / stride);
|
||||
id -= multi_id[IDim] * stride;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -208,27 +191,15 @@ struct ConstantTensorDescriptor
|
||||
{
|
||||
Array<index_t, nDim> multi_id;
|
||||
|
||||
constexpr auto dummy_strides = calculate_tensor_strides_packed(GetLengths());
|
||||
using PackedStrides = decltype(calculate_tensor_strides_packed(GetLengths()));
|
||||
|
||||
// calculate index in each of the dimensions in the order of their dimension
|
||||
static_for<0, nDim - 1, 1>{}(GetMultiIndexFrom1dIndex_impl(id, multi_id));
|
||||
static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id));
|
||||
|
||||
index_t itmp = id / dummy_strides.Get(Number<nDim - 1>{});
|
||||
|
||||
multi_id.Set(Number<nDim - 1>{}, itmp);
|
||||
multi_id.Set(Number<nDim - 1>{}, id / PackedStrides::Get(Number<nDim - 1>{}));
|
||||
|
||||
return multi_id;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
// return type is Sequence<...>
|
||||
template<index_t Id>
|
||||
__host__ __device__ static constexpr auto GetMultiIndexFrom1dIndex(Number<Id>)
|
||||
{
|
||||
return inclusive_scan_sequence(f_impl, GetStrides(), Number<Id>{});
|
||||
}
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
@@ -236,9 +207,10 @@ struct ConstantTensorDescriptor
|
||||
return multi_id;
|
||||
}
|
||||
|
||||
// This function doesn't do carry check on the highest dimension, for performance reason.
|
||||
// It is the user's responsibility to make sure the result "new_multi_id" is not out-of-bound
|
||||
// on the highest dimension
|
||||
// This function doesn't do carry check on the highest dimension for positive stepping (or
|
||||
// borrow check on the lowest dimension for negative stepping), for performance reasons. It is
|
||||
// the user's responsibility to make sure the result "new_multi_id" is not out-of-bound on the
|
||||
// highest dimension for positive stepping (or on the lowest dimension for negative stepping)
|
||||
template <bool PositiveDirection>
|
||||
__host__ __device__ static Array<index_t, nDim>
|
||||
UpdateMultiIndexGivenStepSizeOf1dIndex(Array<index_t, nDim> old_multi_id,
|
||||
@@ -262,14 +234,14 @@ struct ConstantTensorDescriptor
|
||||
|
||||
if(carry)
|
||||
{
|
||||
++new_multi_id[idim];
|
||||
++new_multi_id(idim);
|
||||
}
|
||||
|
||||
carry = false;
|
||||
|
||||
if(new_multi_id[idim] >= GetLength(IDim))
|
||||
{
|
||||
new_multi_id[idim] -= GetLength(IDim);
|
||||
new_multi_id(idim) -= GetLength(IDim);
|
||||
carry = true;
|
||||
}
|
||||
});
|
||||
@@ -288,14 +260,14 @@ struct ConstantTensorDescriptor
|
||||
|
||||
if(borrow)
|
||||
{
|
||||
--new_multi_id[idim];
|
||||
--new_multi_id(idim);
|
||||
}
|
||||
|
||||
borrow = false;
|
||||
|
||||
if(new_multi_id[idim] < GetLength(IDim))
|
||||
{
|
||||
new_multi_id[idim] += GetLength(IDim);
|
||||
new_multi_id(idim) += GetLength(IDim);
|
||||
borrow = true;
|
||||
}
|
||||
});
|
||||
@@ -382,15 +354,7 @@ struct ConstantTensorDescriptor
|
||||
return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
|
||||
}
|
||||
|
||||
template <index_t Threashold, index_t Delta>
|
||||
struct f_unfold_impl
|
||||
{
|
||||
__host__ __device__ constexpr index_t operator()(index_t x) const
|
||||
{
|
||||
return x > Threashold ? x - Delta : x;
|
||||
}
|
||||
};
|
||||
|
||||
// this function unfolds dimensions [FirstUnfoldDim, ..., LastUnfoldDim] into 1 dimension
|
||||
template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
|
||||
__host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
|
||||
{
|
||||
@@ -398,24 +362,6 @@ struct ConstantTensorDescriptor
|
||||
FirstUnfoldDim <= LastUnfoldDim,
|
||||
"wrong! should have FirstUnfoldDim <= LastUnfoldDim!");
|
||||
|
||||
#if 0 // cannot compile: compiler complain about constexpr
|
||||
// dimensions to be unfolded need to be in descending order (w.r.t. strides), and need to be
|
||||
// packed in memory, otherwise, unfolding is invalid
|
||||
static_for<FirstUnfoldDim, LastUnfoldDim, 1>{}([&](auto IDim_) {
|
||||
constexpr auto IDim = decltype(IDim_){};
|
||||
constexpr auto IDim_p1 = IDim + Number<1>{};
|
||||
|
||||
// check stride
|
||||
static_assert(
|
||||
GetStride(IDim) >= GetStride(IDim_p1),
|
||||
"wrong! dimensions to be unfolded need to be in descending order w.r.t strides");
|
||||
|
||||
// check if packed
|
||||
static_assert(GetStride(IDim_p1) * GetLength(IDim_p1) == GetStride(IDim),
|
||||
"wrong! dimensions to be unfolded need to be packed");
|
||||
});
|
||||
#endif
|
||||
|
||||
// left and right
|
||||
constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::SeqType{};
|
||||
constexpr auto middle =
|
||||
@@ -423,6 +369,9 @@ struct ConstantTensorDescriptor
|
||||
constexpr auto right =
|
||||
typename arithmetic_sequence_gen<LastUnfoldDim + 1, GetNumOfDimension(), 1>::SeqType{};
|
||||
|
||||
// dimensions to be unfolded need to be continuous
|
||||
static_assert(Type::Extract(middle).AreDimensionsContinuous(), "wrong! not unfoldable");
|
||||
|
||||
// unfolded length, stride
|
||||
constexpr index_t unfold_length = accumulate_on_sequence(
|
||||
GetLengths().Extract(middle), mod_conv::multiplies<index_t>{}, Number<1>{});
|
||||
@@ -446,16 +395,16 @@ struct ConstantTensorDescriptor
|
||||
template <class MapNew2Old>
|
||||
__host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old)
|
||||
{
|
||||
return ConstantTensorDescriptor<decltype(Lengths{}.ReorderGivenNew2Old(MapNew2Old{})),
|
||||
decltype(Strides{}.ReorderGivenNew2Old(MapNew2Old{}))>{};
|
||||
return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenNew2Old(MapNew2Old{})),
|
||||
decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{};
|
||||
}
|
||||
|
||||
#if 0 // require sequence_sort, which is not implemented yet
|
||||
template <class MapOld2New>
|
||||
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
|
||||
{
|
||||
return ConstantTensorDescriptor<decltype(Lengths{}.ReorderGivenOld2New(MapOld2New{})),
|
||||
decltype(Strides{}.ReorderGivenOld2New(MapOld2New{}))>{}
|
||||
return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenOld2New(MapOld2New{})),
|
||||
decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{}
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user