mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
try using more constexpr
This commit is contained in:
@@ -4,7 +4,8 @@
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
|
||||
{
|
||||
return reverse_inclusive_scan_sequence(Lengths{}.PopFront(), mod_conv::multiplies<index_t>{})
|
||||
return reverse_inclusive_scan_sequence(
|
||||
Lengths{}.PopFront(), mod_conv::multiplies<index_t>{}, Number<1>{})
|
||||
.PushBack(Number<1>{});
|
||||
}
|
||||
|
||||
@@ -91,8 +92,10 @@ struct ConstantTensorDescriptor
|
||||
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <index_t NSize>
|
||||
__host__ __device__ static index_t GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
|
||||
__host__ __device__ static constexpr index_t
|
||||
GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
|
||||
{
|
||||
static_assert(NSize == nDim, "wrong! Dimension not consistent");
|
||||
|
||||
@@ -105,9 +108,43 @@ struct ConstantTensorDescriptor
|
||||
|
||||
return offset;
|
||||
}
|
||||
#else
|
||||
template <index_t NSize>
|
||||
struct GetOffsetFromMultiIndex_impl
|
||||
{
|
||||
Array<index_t, NSize>& multi_id_ref;
|
||||
index_t& offset_ref;
|
||||
|
||||
__host__ __device__ constexpr GetOffsetFromMultiIndex_impl(Array<index_t, NSize>& multi_id,
|
||||
index_t& offset)
|
||||
: multi_id_ref(multi_id), offset_ref(offset)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr bool operator()(Number<IDim>) const
|
||||
{
|
||||
offset_ref += multi_id_ref.Get(Number<IDim>{}) * Type::GetStride(Number<IDim>{});
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t NSize>
|
||||
__host__ __device__ static constexpr index_t
|
||||
GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
|
||||
{
|
||||
static_assert(NSize == nDim, "wrong! Dimension not consistent");
|
||||
|
||||
index_t offset = 0;
|
||||
|
||||
static_for<0, nDim, 1>{}(GetOffsetFromMultiIndex_impl<NSize>(multi_id, offset));
|
||||
|
||||
return offset;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is)
|
||||
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
|
||||
{
|
||||
return GetOffsetFromMultiIndex(Array<index_t, sizeof...(Is)>{is...});
|
||||
}
|
||||
@@ -123,7 +160,8 @@ struct ConstantTensorDescriptor
|
||||
multi_id * GetStrides(), mod_conv::plus<index_t>{}, Number<0>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
|
||||
#if 0
|
||||
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
|
||||
{
|
||||
Array<index_t, nDim> multi_id;
|
||||
|
||||
@@ -141,8 +179,58 @@ struct ConstantTensorDescriptor
|
||||
|
||||
return multi_id;
|
||||
}
|
||||
#else
|
||||
struct GetMultiIndexFrom1dIndex_impl
|
||||
{
|
||||
using DummyStrides = decltype(calculate_tensor_strides_packed(GetLengths()));
|
||||
|
||||
__host__ __device__ static auto
|
||||
index_t& id_ref;
|
||||
Array<index_t, nDim>& multi_id_ref;
|
||||
|
||||
__host__ __device__ constexpr GetMultiIndexFrom1dIndex_impl(index_t& id,
|
||||
Array<index_t, nDim>& multi_id)
|
||||
: id_ref(id), multi_id_ref(multi_id)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr bool operator()(Number<IDim>) const
|
||||
{
|
||||
constexpr index_t stride = DummyStrides::Get(Number<IDim>{});
|
||||
multi_id_ref.Set(Number<IDim>{}, id_ref / stride);
|
||||
id_ref -= multi_id_ref.Get(Number<IDim>{}) * stride;
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// Converts a 1d index into a multi-index, using the packed strides computed
// from GetLengths().
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{
    Array<index_t, nDim> multi_id;

    constexpr auto dummy_strides = calculate_tensor_strides_packed(GetLengths());

    // calculate index in each of the dimensions in the order of their dimension
    // The functor holds `id` by reference and reduces it after each dimension.
    // NOTE(review): static_for<0, nDim - 1, 1> presumably covers dimensions
    // 0 .. nDim - 2 -- confirm against static_for.
    static_for<0, nDim - 1, 1>{}(GetMultiIndexFrom1dIndex_impl(id, multi_id));

    // The last dimension is handled here, from what is left of `id`.
    index_t itmp = id / dummy_strides.Get(Number<nDim - 1>{});

    multi_id.Set(Number<nDim - 1>{}, itmp);

    return multi_id;
}
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
// return type is Sequence<...>
|
||||
template<index_t Id>
|
||||
__host__ __device__ static constexpr auto GetMultiIndexFrom1dIndex(Number<Id>)
|
||||
{
|
||||
return inclusive_scan_sequence(f_impl, GetStrides(), Number<Id>{});
|
||||
}
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
{
|
||||
return multi_id;
|
||||
@@ -278,8 +366,8 @@ struct ConstantTensorDescriptor
|
||||
// folded strides
|
||||
constexpr auto fold_strides =
|
||||
Number<unfold_stride>{} *
|
||||
reverse_inclusive_scan_sequence(fold_intervals.PushBack(Number<1>{}),
|
||||
mod_conv::multiplies<index_t>{});
|
||||
reverse_inclusive_scan_sequence(
|
||||
fold_intervals.PushBack(Number<1>{}), mod_conv::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
// left and right
|
||||
constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::SeqType{};
|
||||
|
||||
Reference in New Issue
Block a user