mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
more utility code
This commit is contained in:
@@ -47,6 +47,19 @@ template <index_t GridSize,
|
||||
index_t OutThreadCopyDataPerAccess_N>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
static constexpr auto I8 = Number<8>{};
|
||||
static constexpr auto I9 = Number<9>{};
|
||||
static constexpr auto I10 = Number<10>{};
|
||||
static constexpr auto I11 = Number<11>{};
|
||||
|
||||
#if 0
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
@@ -60,11 +73,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
constexpr auto False = integral_constant<bool, false>{};
|
||||
|
||||
@@ -487,58 +495,67 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
#if 0
|
||||
constexpr auto tmp = std::tuple<bool>{};
|
||||
constexpr auto flag = std::get<0>(tmp);
|
||||
#else
|
||||
constexpr auto a = Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>{}, 99);
|
||||
constexpr auto a = make_tuple(true, Sequence<1>{}, index_t(99));
|
||||
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
printf("adsas %d\n", a.At(Number<0>{}));
|
||||
print_Sequence("seq", a.At(Number<1>{}));
|
||||
printf("adsas %lu\n", a.At(Number<2>{}));
|
||||
printf("[0] %d\n", a.At(I0));
|
||||
print_Sequence("[1]", a.At(I1));
|
||||
printf("[2] %lu\n", a.At(I2));
|
||||
}
|
||||
|
||||
auto b = Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>{}, 99);
|
||||
bool flag = true;
|
||||
|
||||
b.At(Number<0>{}) = false;
|
||||
auto b = make_tuple(flag, Sequence<1>{}, 99);
|
||||
|
||||
b.At(I0) = false;
|
||||
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
printf("adsas %d\n", b.At(Number<0>{}));
|
||||
print_Sequence("seq", b.At(Number<1>{}));
|
||||
printf("adsas %lu\n", b.At(Number<2>{}));
|
||||
printf("[0] %d\n", b.At(I0));
|
||||
print_Sequence("[1]", b.At(I1));
|
||||
printf("[2] %lu\n", b.At(I2));
|
||||
|
||||
printf("flag %d\n", flag);
|
||||
}
|
||||
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
printf("adsas %d\n",
|
||||
Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<0>{}));
|
||||
print_Sequence(
|
||||
"seq", Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<1>{}));
|
||||
printf("adsas %d\n",
|
||||
Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<2>{}));
|
||||
printf("[0] %d\n", make_tuple(true, Sequence<1>(), index_t(99)).At(I0));
|
||||
print_Sequence("[1]", make_tuple(true, Sequence<1>(), index_t(99)).At(I1));
|
||||
printf("[2] %d\n", make_tuple(true, Sequence<1>(), index_t(99)).At(I2));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
#elif 1
|
||||
// create a native tensor descriptor
|
||||
constexpr auto in_n_c_h_w_global_desc =
|
||||
constexpr auto in_c_h_w_n_global_desc =
|
||||
make_NativeTensorDescriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
|
||||
|
||||
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
|
||||
constexpr index_t Hi = in_c_h_w_n_global_desc.GetLength(I1);
|
||||
constexpr index_t Wi = in_c_h_w_n_global_desc.GetLength(I2);
|
||||
constexpr index_t N = in_c_h_w_n_global_desc.GetLength(I3);
|
||||
|
||||
constexpr auto pad_h_w = Pad<Sequence<Hi, Wi>, LowerPads, UpperPads>{};
|
||||
constexpr auto pass_c = PassThrough<C>{};
|
||||
constexpr auto pass_n = PassThrough<N>{};
|
||||
|
||||
constexpr auto trans = make_tuple(pass_c, pad_h_w, pass_n);
|
||||
constexpr auto lower_dim_groups =
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{});
|
||||
constexpr auto upper_dim_groups =
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{});
|
||||
|
||||
constexpr auto in_c_h_w_n_padded_global_desc = transform_tensor_descriptor(
|
||||
in_c_h_w_n_global_desc, trans, lower_dim_groups, upper_dim_groups);
|
||||
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_tensor_descriptor("in_n_c_h_w_global_desc", in_n_c_h_w_global_desc);
|
||||
}
|
||||
print_tensor_descriptor("in_c_h_w_n_global_desc", in_c_h_w_n_global_desc);
|
||||
|
||||
// transform the tensor descriptor once
|
||||
//
|
||||
// calculate the offset of some entry
|
||||
printf("offset: %lu\n", in_c_h_w_n_global_desc.GetOffset({1, 2, 3, 4}));
|
||||
|
||||
printf("padded offset: %lu\n", in_c_h_w_n_padded_global_desc.GetOffset({1, 4, 5, 4}));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -178,7 +178,7 @@ struct ConstantTensorDescriptor
|
||||
{
|
||||
constexpr auto IDim = IDim_{};
|
||||
constexpr index_t stride = PackedStrides::Get(IDim);
|
||||
multi_id.Set(IDim, id / stride);
|
||||
multi_id(IDim) = id / stride;
|
||||
id -= multi_id[IDim] * stride;
|
||||
}
|
||||
};
|
||||
@@ -192,7 +192,7 @@ struct ConstantTensorDescriptor
|
||||
// calculate index in each of the dimensions in the order of their dimension
|
||||
static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id));
|
||||
|
||||
multi_id.Set(Number<nDim - 1>{}, id / PackedStrides::Get(Number<nDim - 1>{}));
|
||||
multi_id(Number<nDim - 1>{}) = id / PackedStrides::Get(Number<nDim - 1>{});
|
||||
|
||||
return multi_id;
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ struct PassThrough
|
||||
};
|
||||
|
||||
// LowLengths: Sequence<...>
|
||||
template <class LowLengths, class LeftPads, class RightPads>
|
||||
template <typename LowLengths, typename LeftPads, typename RightPads>
|
||||
struct Pad
|
||||
{
|
||||
static constexpr index_t nDim = LowLengths::GetSize();
|
||||
@@ -67,7 +67,7 @@ struct Pad
|
||||
|
||||
#if 0
|
||||
// LowLengths: Sequence<...>
|
||||
template <class LowLengths>
|
||||
template <typename LowLengths>
|
||||
struct Merge
|
||||
{
|
||||
static constexpr index_t nDimLow = LowLengths::GetSize();
|
||||
@@ -113,7 +113,7 @@ struct Merge
|
||||
#endif
|
||||
|
||||
// UpLengths: Sequence<...>
|
||||
template <index_t LowLength, class UpLengths>
|
||||
template <index_t LowLength, typename UpLengths>
|
||||
struct Unmerge
|
||||
{
|
||||
static constexpr index_t nDimLow = 1;
|
||||
@@ -161,7 +161,7 @@ struct Unmerge
|
||||
// UpLengths: Sequence<...>
|
||||
// Coefficients: Sequence<...>
|
||||
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
|
||||
template <index_t LowLength, class UpLengths, class Coefficients>
|
||||
template <index_t LowLength, typename UpLengths, typename Coefficients>
|
||||
struct Embed
|
||||
{
|
||||
static constexpr index_t nDimLow = 1;
|
||||
|
||||
@@ -7,12 +7,12 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class... NativeDimensions>
|
||||
template <typename... NativeDimensions>
|
||||
struct NativeTensorDescriptor
|
||||
{
|
||||
using type = NativeTensorDescriptor;
|
||||
static constexpr auto mDimensions = Tuple<NativeDimensions...>{};
|
||||
static constexpr index_t nDim = mDimensions.GetSize();
|
||||
static constexpr index_t nDim = sizeof...(NativeDimensions);
|
||||
static constexpr auto mDimensions = make_tuple(NativeDimensions{}...);
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
@@ -20,7 +20,7 @@ struct NativeTensorDescriptor
|
||||
|
||||
struct lambda_GetLength
|
||||
{
|
||||
template <class IDim>
|
||||
template <typename IDim>
|
||||
__host__ __device__ constexpr auto operator()(IDim) const
|
||||
{
|
||||
return GetLength(IDim{});
|
||||
@@ -34,7 +34,7 @@ struct NativeTensorDescriptor
|
||||
|
||||
struct lambda_GetStride
|
||||
{
|
||||
template <class IDim>
|
||||
template <typename IDim>
|
||||
__host__ __device__ constexpr auto operator()(IDim) const
|
||||
{
|
||||
return GetStride(IDim{});
|
||||
@@ -49,16 +49,16 @@ struct NativeTensorDescriptor
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
|
||||
{
|
||||
return mDimensions.Get(Number<IDim>{}).GetLength();
|
||||
return mDimensions.At(Number<IDim>{}).GetLength();
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
|
||||
{
|
||||
return mDimensions.Get(Number<IDim>{}).GetStride();
|
||||
return mDimensions.At(Number<IDim>{}).GetStride();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffset(Index idx)
|
||||
__host__ __device__ static constexpr index_t GetOffset(const Index& idx)
|
||||
{
|
||||
index_t offset = 0;
|
||||
|
||||
@@ -67,7 +67,7 @@ struct NativeTensorDescriptor
|
||||
return offset;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffsetDiff(Index idx_diff)
|
||||
__host__ __device__ static constexpr index_t GetOffsetDiff(const Index& idx_diff)
|
||||
{
|
||||
index_t offset_diff = 0;
|
||||
|
||||
@@ -96,28 +96,65 @@ struct NativeTensorDescriptor
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
// LowerTensorDescriptor
|
||||
// Transforms: std::tuple<DimensionTransforms...>
|
||||
// LowerDimensionIds: std::tuple<Sequence<...>>
|
||||
// UpperDimensionIds: std::tuple<Sequence<...>>
|
||||
template <class LowTensorDescriptor, class Transforms, class LowDimensionIds, class UpDimensionIds>
|
||||
// Transforms: Tuple<DimensionTransforms...>
|
||||
// LowerDimensionIds: Tuple<Sequence<...>>
|
||||
// UpperDimensionIds: Tuple<Sequence<...>>
|
||||
template <typename LowTensorDescriptor,
|
||||
typename Transforms,
|
||||
typename LowDimensionIds,
|
||||
typename UpDimensionIds>
|
||||
struct TransformedTensorDescriptor
|
||||
{
|
||||
using type = TransformedTensorDescriptor;
|
||||
static constexpr index_t nDimUp = GetUpperNumOfDimension();
|
||||
static constexpr index_t nDimLow = GetLowerNumOfDimension();
|
||||
using type = TransformedTensorDescriptor;
|
||||
static constexpr index_t nTransform = Transforms::Size();
|
||||
|
||||
static constexpr index_t nTransform = Transforms::GetSize();
|
||||
struct lambda_merge_sequences
|
||||
{
|
||||
template <typename... Seqs>
|
||||
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
|
||||
{
|
||||
return merge_sequences(seqs...);
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfLowerDimension()
|
||||
{
|
||||
// Here, we assume all lower-dimensions are active
|
||||
// TODO: sanity-check all lower-dimension are indeed active
|
||||
|
||||
using duplicated_low_active_dims =
|
||||
decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));
|
||||
|
||||
using low_active_dims = typename sequence_unique_sort<duplicated_low_active_dims,
|
||||
math::less<index_t>,
|
||||
math::equal<index_t>>::type;
|
||||
|
||||
return low_active_dims::Size();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension()
|
||||
{
|
||||
using duplicated_up_active_dims =
|
||||
decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));
|
||||
|
||||
using up_active_dims = typename sequence_unique_sort<duplicated_up_active_dims,
|
||||
math::less<index_t>,
|
||||
math::equal<index_t>>::type;
|
||||
|
||||
return up_active_dims::Size();
|
||||
}
|
||||
|
||||
static constexpr index_t nDimUp = GetNumOfUpperDimension();
|
||||
static constexpr index_t nDimLow = GetNumOfLowerDimension();
|
||||
|
||||
using UpperIndex = MultiIndex<nDimUp>;
|
||||
using LowerIndex = MultiIndex<nDimLow>;
|
||||
|
||||
__host__ __device__ static constexpr TransformedTensorDescriptor()
|
||||
__host__ __device__ constexpr TransformedTensorDescriptor()
|
||||
{
|
||||
static_assert(nTransform == Transforms::GetSize() &&
|
||||
nTransform == LowDimensionIds::GetSize() &&
|
||||
nTransform == UpDimensionIds::GetSize(),
|
||||
static_assert(nTransform == Transforms::Size() && nTransform == LowDimensionIds::Size() &&
|
||||
nTransform == UpDimensionIds::Size(),
|
||||
"wrong! # of transformations not the same");
|
||||
|
||||
// TODO: sanity check: LowDimensionIds should include all low-dimensions,
|
||||
@@ -128,33 +165,17 @@ struct TransformedTensorDescriptor
|
||||
// a low-dimension should be associated with only one transformation
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfLowerDimension()
|
||||
{
|
||||
// Here, we assume all lower-dimensions are active
|
||||
// TODO: sanity-check all lower-dimension are indeed active
|
||||
constexpr auto low_active_dims = unique_sort_sequence(
|
||||
merge_tuple_of_sequences(LowDimensionIds{}), math::less<index_t>{});
|
||||
|
||||
return low_active_dims.GetSize();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension()
|
||||
{
|
||||
constexpr auto up_active_dims =
|
||||
unique_sort_sequence(merge_tuple_of_sequences(UpDimensionIds{}), math::less<index_t>{});
|
||||
return up_active_dims.GetSize();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfDimension()
|
||||
{
|
||||
return GetNumOfUpperDimension();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths()
|
||||
#if 0
|
||||
__host__ __device__ static constexpr auto GetUpperLengths()
|
||||
{
|
||||
struct lambda_get_upper_lengths
|
||||
{
|
||||
template <class Transform>
|
||||
template <typename Transform>
|
||||
__host__ __device__ constexpr auto operator()(Transform tran) const
|
||||
{
|
||||
return tran.GetUpperLengths();
|
||||
@@ -173,6 +194,7 @@ struct TransformedTensorDescriptor
|
||||
|
||||
using sort_dimension_ids =
|
||||
sequence_unique_sort<decltype(all_upper_dimension_ids), math::less<index_t>>;
|
||||
|
||||
constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type;
|
||||
constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type;
|
||||
|
||||
@@ -182,46 +204,48 @@ struct TransformedTensorDescriptor
|
||||
return sorted_upper_lengths;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerTensorDescriptor()
|
||||
{
|
||||
return LowTensorDescriptor{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetLowerIndex(UpperIndex idx_up)
|
||||
__host__ __device__ static constexpr LowerIndex GetLowerIndex(const UpperIndex& idx_up)
|
||||
{
|
||||
LowerIndex idx_low;
|
||||
|
||||
static_for<0, nTransform, 1>{}([&](auto itran) {
|
||||
constexpr auto tran = Transforms::Get(itran);
|
||||
constexpr auto tran = Transforms{}.At(itran);
|
||||
|
||||
constexpr auto idx_low_part = pick_array_element(idx_low, LowDimensionIds::Get(itran));
|
||||
constexpr auto idx_up_part = pick_array_element(idx_up, UpDimensionIds::Get(itran));
|
||||
auto idx_low_part = pick_array_element(idx_low, LowDimensionIds{}.At(itran));
|
||||
const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));
|
||||
|
||||
// this assume each lower (single) index is only assocaited with one transformation,
|
||||
// which is required for index transformation, and has been checked during constructor
|
||||
// of TransformedTensorDescriptor
|
||||
idx_low_part = tran.GetLowerIndex(idx_up_part);
|
||||
idx_low_part = tran.GetLowerIndex(to_array(idx_up_part));
|
||||
});
|
||||
|
||||
return idx_low;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetLowerIndexDiff(UpperIndex idx_up_diff,
|
||||
LowerIndex idx_low_old)
|
||||
__host__ __device__ static constexpr LowerIndex GetLowerIndexDiff(const UpperIndex& idx_up_diff,
|
||||
const LowerIndex& idx_low_old)
|
||||
{
|
||||
LowerIndex idx_low_diff;
|
||||
|
||||
static_for<0, nTransform, 1>{}([&](auto itran) {
|
||||
constexpr auto tran = Transforms::Get(itran);
|
||||
constexpr auto tran = Transforms::At(itran);
|
||||
|
||||
constexpr auto idx_up_diff_part =
|
||||
pick_array_element(idx_up_diff, UpDimensionIds::Get(itran));
|
||||
const auto idx_up_diff_part =
|
||||
pick_array_element(idx_up_diff, UpDimensionIds::At(itran));
|
||||
|
||||
constexpr auto idx_low_diff_part =
|
||||
pick_array_element(idx_low_diff, LowDimensionIds::Get(itran));
|
||||
auto idx_low_diff_part = pick_array_element(idx_low_diff, LowDimensionIds::At(itran));
|
||||
|
||||
constexpr auto idx_low_old_part =
|
||||
pick_array_element(idx_low_old, LowDimensionIds::Get(itran));
|
||||
const auto idx_low_old_part =
|
||||
pick_array_element(idx_low_old, LowDimensionIds::At(itran));
|
||||
|
||||
// this assume each lower (single) index is associated with only one transformation,
|
||||
// which is required for index transformation, and has been checked during constructor
|
||||
@@ -232,13 +256,14 @@ struct TransformedTensorDescriptor
|
||||
return idx_low_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffset(UpperIndex idx_up)
|
||||
__host__ __device__ static constexpr index_t GetOffset(const UpperIndex& idx_up)
|
||||
{
|
||||
return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up));
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>);
|
||||
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
@@ -257,8 +282,8 @@ struct TransformedTensorDescriptor
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
};
|
||||
#endif
|
||||
};
|
||||
|
||||
template <index_t... Lengths, index_t... Strides>
|
||||
__host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths...>,
|
||||
@@ -267,15 +292,28 @@ __host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths.
|
||||
return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
|
||||
}
|
||||
|
||||
template <class Lengths>
|
||||
template <typename Lengths>
|
||||
__host__ __device__ constexpr auto make_NativeTensorDescriptor_packed(Lengths)
|
||||
{
|
||||
constexpr index_t strides = reverse_inclusive_scan_sequence(
|
||||
Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
|
||||
.PushBack(Number<1>{});
|
||||
constexpr auto strides = reverse_inclusive_scan_sequence(
|
||||
Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
|
||||
.PushBack(Number<1>{});
|
||||
|
||||
return make_NativeTensorDescriptor(Lengths{}, strides);
|
||||
}
|
||||
|
||||
template <typename LowTensorDescriptor,
|
||||
typename Transforms,
|
||||
typename LowDimensionIds,
|
||||
typename UpDimensionIds>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_tensor_descriptor(LowTensorDescriptor, Transforms, LowDimensionIds, UpDimensionIds)
|
||||
{
|
||||
return TransformedTensorDescriptor<LowTensorDescriptor,
|
||||
Transforms,
|
||||
LowDimensionIds,
|
||||
UpDimensionIds>{};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class... NativeDimensions>
|
||||
template <typename... NativeDimensions>
|
||||
__host__ __device__ void print_tensor_descriptor(const char* s,
|
||||
NativeTensorDescriptor<NativeDimensions...> desc)
|
||||
{
|
||||
|
||||
@@ -6,48 +6,78 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class TData, index_t NSize>
|
||||
template <typename TData, index_t NSize>
|
||||
struct Array
|
||||
{
|
||||
using Type = Array<TData, NSize>;
|
||||
using type = Array<TData, NSize>;
|
||||
using data_type = TData;
|
||||
|
||||
static constexpr index_t nSize = NSize;
|
||||
index_t mData[NSize];
|
||||
|
||||
index_t mData[nSize];
|
||||
__host__ __device__ explicit constexpr Array() {}
|
||||
|
||||
template <class... Xs>
|
||||
__host__ __device__ constexpr Array(Xs... xs) : mData{static_cast<TData>(xs)...}
|
||||
template <typename X, typename... Xs>
|
||||
__host__ __device__ explicit constexpr Array(X x, Xs... xs)
|
||||
: mData{static_cast<TData>(x), static_cast<TData>(xs)...}
|
||||
{
|
||||
static_assert(sizeof...(Xs) + 1 == NSize, "wrong! size");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSize() { return NSize; }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr TData operator[](Number<I>) const
|
||||
#if 0
|
||||
template <typename T>
|
||||
__host__ __device__ explicit constexpr Array(const T& x)
|
||||
{
|
||||
return mData[I];
|
||||
}
|
||||
static_assert(T::Size() == NSize, "wrong! size");
|
||||
|
||||
__host__ __device__ constexpr TData operator[](index_t i) const { return mData[i]; }
|
||||
static_for<0, NSize, 1>{}([&](auto i){
|
||||
mData[i] = x.At(i);
|
||||
})
|
||||
}
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return NSize; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSize() { return Size(); }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ TData& operator()(Number<I>)
|
||||
{
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
__host__ __device__ TData& operator()(index_t i) { return mData[i]; }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr void Set(Number<I>, TData x)
|
||||
__host__ __device__ constexpr const TData& At(Number<I>) const
|
||||
{
|
||||
static_assert(I < NSize, "wrong!");
|
||||
|
||||
mData[I] = x;
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void Set(index_t I, TData x) { mData[I] = x; }
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr TData& At(Number<I>)
|
||||
{
|
||||
static_assert(I < NSize, "wrong!");
|
||||
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr const TData& At(index_t i) const { return mData[i]; }
|
||||
|
||||
__host__ __device__ constexpr TData& At(index_t i) { return mData[i]; }
|
||||
|
||||
template <typename I>
|
||||
__host__ __device__ constexpr const TData& operator[](I i) const
|
||||
{
|
||||
return At(i);
|
||||
}
|
||||
|
||||
template <typename I>
|
||||
__host__ __device__ constexpr TData& operator()(I i)
|
||||
{
|
||||
return At(i);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr type& operator=(const T& x)
|
||||
{
|
||||
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = x[i]; });
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
struct lambda_PushBack // emulate constexpr lambda
|
||||
{
|
||||
@@ -63,7 +93,7 @@ struct Array
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr void operator()(Number<I>) const
|
||||
{
|
||||
new_array.Set(Number<I>{}, old_array[I]);
|
||||
new_array(Number<I>{}) = old_array[I];
|
||||
}
|
||||
};
|
||||
|
||||
@@ -73,71 +103,98 @@ struct Array
|
||||
|
||||
static_for<0, NSize, 1>{}(lambda_PushBack(*this, new_array));
|
||||
|
||||
new_array.Set(Number<NSize>{}, x);
|
||||
new_array(Number<NSize>{}) = x;
|
||||
|
||||
return new_array;
|
||||
}
|
||||
};
|
||||
|
||||
// A: Array
|
||||
// Arr: Array
|
||||
// Picks: Sequence<...>
|
||||
template <class Arr, class Picks>
|
||||
template <typename Arr, typename Picks>
|
||||
struct ArrayElementPicker
|
||||
{
|
||||
using type = ArrayElementPicker;
|
||||
using data_type = typename Arr::data_type;
|
||||
|
||||
__host__ __device__ constexpr ArrayElementPicker(Arr& array) : mData{array}
|
||||
__host__ __device__ constexpr ArrayElementPicker() = delete;
|
||||
|
||||
__host__ __device__ explicit constexpr ArrayElementPicker(Arr& array) : mArray{array}
|
||||
{
|
||||
constexpr index_t imax =
|
||||
accumulate_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
|
||||
|
||||
static_assert(imax < Picks::GetSize(), "wrong! exceeding max id");
|
||||
static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSize() { return Picks::GetSize(); }
|
||||
__host__ __device__ static constexpr auto Size() { return Picks::Size(); }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr data_type operator[](Number<I>) const
|
||||
__host__ __device__ constexpr const data_type& At(Number<I>) const
|
||||
{
|
||||
constexpr auto IP = Picks::Get(Number<I>{});
|
||||
return mData[IP];
|
||||
}
|
||||
static_assert(I < Size(), "wrong!");
|
||||
|
||||
__host__ __device__ constexpr data_type operator[](index_t i) const
|
||||
{
|
||||
constexpr index_t ip = Picks{}[i];
|
||||
return mData[ip];
|
||||
constexpr auto IP = Picks{}[I];
|
||||
return mArray[IP];
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ data_type& operator()(Number<I>)
|
||||
__host__ __device__ constexpr data_type& At(Number<I>)
|
||||
{
|
||||
constexpr auto IP = Picks::Get(Number<I>{});
|
||||
return mData[IP];
|
||||
static_assert(I < Size(), "wrong!");
|
||||
|
||||
constexpr auto IP = Picks{}[I];
|
||||
return mArray(IP);
|
||||
}
|
||||
|
||||
__host__ __device__ data_type& operator()(index_t i)
|
||||
template <typename I>
|
||||
__host__ __device__ constexpr const data_type& operator[](I i) const
|
||||
{
|
||||
constexpr index_t ip = Picks{}[i];
|
||||
return mData[ip];
|
||||
return At(i);
|
||||
}
|
||||
|
||||
Arr& mData;
|
||||
template <typename I>
|
||||
__host__ __device__ constexpr data_type& operator()(I i)
|
||||
{
|
||||
return At(i);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr type& operator=(const T& a)
|
||||
{
|
||||
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
Arr& mArray;
|
||||
};
|
||||
|
||||
template <class Arr, class Picks>
|
||||
template <typename Arr, typename Picks>
|
||||
__host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
|
||||
{
|
||||
return ArrayElementPicker<Arr, Picks>(a);
|
||||
}
|
||||
|
||||
#if 1
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr auto to_array(const T& x)
|
||||
{
|
||||
Array<typename T::data_type, T::Size()> y;
|
||||
|
||||
static_for<0, T::Size(), 1>{}([&](auto i) { y.At(i) = x.At(i); });
|
||||
|
||||
return y;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
|
||||
{
|
||||
return Array<index_t, sizeof...(Is)>{Is...};
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize>
|
||||
template <typename TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto make_zero_array()
|
||||
{
|
||||
constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::type{};
|
||||
@@ -145,7 +202,7 @@ __host__ __device__ constexpr auto make_zero_array()
|
||||
return zero_array;
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize, index_t... IRs>
|
||||
template <typename TData, index_t NSize, index_t... IRs>
|
||||
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
|
||||
Sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
@@ -156,7 +213,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
|
||||
return Array<TData, NSize>{old_array[IRs]...};
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize, class MapOld2New>
|
||||
template <typename TData, index_t NSize, typename MapOld2New>
|
||||
struct lambda_reorder_array_given_old2new
|
||||
{
|
||||
const Array<TData, NSize>& old_array;
|
||||
@@ -173,13 +230,13 @@ struct lambda_reorder_array_given_old2new
|
||||
{
|
||||
TData old_data = old_array[IOldDim];
|
||||
|
||||
constexpr index_t INewDim = MapOld2New::Get(Number<IOldDim>{});
|
||||
constexpr index_t INewDim = MapOld2New::At(Number<IOldDim>{});
|
||||
|
||||
new_array.Set(Number<INewDim>{}, old_data);
|
||||
new_array(Number<INewDim>{}) = old_data;
|
||||
}
|
||||
};
|
||||
|
||||
template <class TData, index_t NSize, index_t... IRs>
|
||||
template <typename TData, index_t NSize, index_t... IRs>
|
||||
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
|
||||
Sequence<IRs...> /*old2new*/)
|
||||
{
|
||||
@@ -195,7 +252,7 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
|
||||
return new_array;
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize, class ExtractSeq>
|
||||
template <typename TData, index_t NSize, typename ExtractSeq>
|
||||
__host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
|
||||
{
|
||||
Array<TData, ExtractSeq::GetSize()> new_array;
|
||||
@@ -204,12 +261,13 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
|
||||
|
||||
static_assert(new_size <= NSize, "wrong! too many extract");
|
||||
|
||||
static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::Get(I)]; });
|
||||
static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::At(I)]; });
|
||||
|
||||
return new_array;
|
||||
}
|
||||
|
||||
template <class F, class X, class Y, class Z> // emulate constepxr lambda for array math
|
||||
template <typename F, typename X, typename Y, typename Z> // emulate constepxr lambda for array
|
||||
// math
|
||||
struct lambda_array_math
|
||||
{
|
||||
const F& f;
|
||||
@@ -226,13 +284,12 @@ struct lambda_array_math
|
||||
__host__ __device__ constexpr void operator()(Number<IDim_>) const
|
||||
{
|
||||
constexpr auto IDim = Number<IDim_>{};
|
||||
|
||||
z.Set(IDim, f(x[IDim], y[IDim]));
|
||||
z(IDim) = f(x[IDim], y[IDim]);
|
||||
}
|
||||
};
|
||||
|
||||
// Array = Array + Array
|
||||
template <class TData, index_t NSize>
|
||||
template <typename TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
|
||||
{
|
||||
Array<TData, NSize> result;
|
||||
@@ -247,7 +304,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData,
|
||||
}
|
||||
|
||||
// Array = Array - Array
|
||||
template <class TData, index_t NSize>
|
||||
template <typename TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData, NSize> b)
|
||||
{
|
||||
Array<TData, NSize> result;
|
||||
@@ -262,7 +319,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
|
||||
}
|
||||
|
||||
// Array += Array
|
||||
template <class TData, index_t NSize>
|
||||
template <typename TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TData, NSize> b)
|
||||
{
|
||||
a = a + b;
|
||||
@@ -270,14 +327,14 @@ __host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TDat
|
||||
}
|
||||
|
||||
// Array -= Array
|
||||
template <class TData, index_t NSize>
|
||||
template <typename TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto operator-=(Array<TData, NSize>& a, Array<TData, NSize> b)
|
||||
{
|
||||
a = a - b;
|
||||
return a;
|
||||
}
|
||||
// Array = Array + Sequence
|
||||
template <class TData, index_t NSize, index_t... Is>
|
||||
template <typename TData, index_t NSize, index_t... Is>
|
||||
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is...> b)
|
||||
{
|
||||
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
|
||||
@@ -294,7 +351,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is.
|
||||
}
|
||||
|
||||
// Array = Array - Sequence
|
||||
template <class TData, index_t NSize, index_t... Is>
|
||||
template <typename TData, index_t NSize, index_t... Is>
|
||||
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is...> b)
|
||||
{
|
||||
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
|
||||
@@ -311,7 +368,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is.
|
||||
}
|
||||
|
||||
// Array = Array * Sequence
|
||||
template <class TData, index_t NSize, index_t... Is>
|
||||
template <typename TData, index_t NSize, index_t... Is>
|
||||
__host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is...> b)
|
||||
{
|
||||
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
|
||||
@@ -328,7 +385,7 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
|
||||
}
|
||||
|
||||
// Array = Sequence - Array
|
||||
template <class TData, index_t NSize, index_t... Is>
|
||||
template <typename TData, index_t NSize, index_t... Is>
|
||||
__host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSize> b)
|
||||
{
|
||||
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
|
||||
@@ -344,7 +401,7 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize, class Reduce>
|
||||
template <typename TData, index_t NSize, typename Reduce>
|
||||
__host__ __device__ constexpr TData
|
||||
accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
|
||||
{
|
||||
@@ -357,89 +414,5 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T, index_t NSize>
|
||||
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
|
||||
{
|
||||
constexpr index_t nsize = a.GetSize();
|
||||
|
||||
static_assert(nsize > 0 && nsize <= 10, "wrong!");
|
||||
|
||||
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); });
|
||||
|
||||
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); });
|
||||
|
||||
static_if<nsize == 3>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); });
|
||||
|
||||
static_if<nsize == 4>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); });
|
||||
|
||||
static_if<nsize == 5>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
|
||||
});
|
||||
|
||||
static_if<nsize == 6>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
|
||||
});
|
||||
|
||||
static_if<nsize == 7>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6]);
|
||||
});
|
||||
|
||||
static_if<nsize == 8>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6],
|
||||
a[7]);
|
||||
});
|
||||
|
||||
static_if<nsize == 9>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6],
|
||||
a[7],
|
||||
a[8]);
|
||||
});
|
||||
|
||||
static_if<nsize == 10>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6],
|
||||
a[7],
|
||||
a[8],
|
||||
a[9]);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -12,22 +12,22 @@ struct static_for;
|
||||
template <index_t...>
|
||||
struct Sequence;
|
||||
|
||||
template <class Seq, index_t I>
|
||||
template <typename Seq, index_t I>
|
||||
struct sequence_split;
|
||||
|
||||
template <class>
|
||||
template <typename>
|
||||
struct sequence_reverse;
|
||||
|
||||
template <class>
|
||||
template <typename>
|
||||
struct sequence_map_inverse;
|
||||
|
||||
template <class>
|
||||
template <typename>
|
||||
struct is_valid_sequence_map;
|
||||
|
||||
template <index_t I, index_t... Is>
|
||||
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);
|
||||
|
||||
template <class Seq>
|
||||
template <typename Seq>
|
||||
__host__ __device__ constexpr auto sequence_pop_back(Seq);
|
||||
|
||||
template <index_t... Is>
|
||||
@@ -38,9 +38,11 @@ struct Sequence
|
||||
|
||||
static constexpr index_t mSize = sizeof...(Is);
|
||||
|
||||
__host__ __device__ static constexpr auto GetSize() { return Number<mSize>{}; }
|
||||
__host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetImpl(index_t I)
|
||||
__host__ __device__ static constexpr auto GetSize() { return Size(); }
|
||||
|
||||
__host__ __device__ static constexpr index_t At(index_t I)
|
||||
{
|
||||
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
|
||||
const index_t mData[mSize + 1] = {Is..., 0};
|
||||
@@ -48,23 +50,24 @@ struct Sequence
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static constexpr auto Get(Number<I>)
|
||||
__host__ __device__ static constexpr auto At(Number<I>)
|
||||
{
|
||||
static_assert(I < mSize, "wrong! I too large");
|
||||
|
||||
return Number<GetImpl(Number<I>{})>{};
|
||||
return Number<At(I)>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto Get(index_t I) { return GetImpl(I); }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto operator[](Number<I>) const
|
||||
__host__ __device__ static constexpr auto Get(Number<I>)
|
||||
{
|
||||
return Get(Number<I>{});
|
||||
return At(Number<I>{});
|
||||
}
|
||||
|
||||
// make sure I is constepxr if you want a constexpr return type
|
||||
__host__ __device__ constexpr index_t operator[](index_t I) const { return GetImpl(I); }
|
||||
template <typename I>
|
||||
__host__ __device__ constexpr auto operator[](I i) const
|
||||
{
|
||||
return At(i);
|
||||
}
|
||||
|
||||
template <index_t... IRs>
|
||||
__host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
|
||||
@@ -74,14 +77,14 @@ struct Sequence
|
||||
|
||||
static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
|
||||
return Sequence<Type::Get(Number<IRs>{})...>{};
|
||||
return Sequence<Type::At(Number<IRs>{})...>{};
|
||||
}
|
||||
|
||||
// MapOld2New is Sequence<...>
|
||||
template <class MapOld2New>
|
||||
template <typename MapOld2New>
|
||||
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
|
||||
{
|
||||
static_assert(MapOld2New::GetSize() == GetSize(),
|
||||
static_assert(MapOld2New::Size() == Size(),
|
||||
"wrong! reorder map should have the same size as Sequence to be rerodered");
|
||||
|
||||
static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
|
||||
@@ -97,13 +100,13 @@ struct Sequence
|
||||
__host__ __device__ static constexpr auto Front()
|
||||
{
|
||||
static_assert(mSize > 0, "wrong!");
|
||||
return Get(Number<0>{});
|
||||
return At(Number<0>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto Back()
|
||||
{
|
||||
static_assert(mSize > 0, "wrong!");
|
||||
return Get(Number<mSize - 1>{});
|
||||
return At(Number<mSize - 1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
|
||||
@@ -137,19 +140,19 @@ struct Sequence
|
||||
template <index_t... Ns>
|
||||
__host__ __device__ static constexpr auto Extract(Number<Ns>...)
|
||||
{
|
||||
return Sequence<Type::Get(Number<Ns>{})...>{};
|
||||
return Sequence<Type::At(Number<Ns>{})...>{};
|
||||
}
|
||||
|
||||
template <index_t... Ns>
|
||||
__host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
|
||||
{
|
||||
return Sequence<Type::Get(Number<Ns>{})...>{};
|
||||
return Sequence<Type::At(Number<Ns>{})...>{};
|
||||
}
|
||||
|
||||
template <index_t I, index_t X>
|
||||
__host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
|
||||
{
|
||||
static_assert(I < GetSize(), "wrong!");
|
||||
static_assert(I < Size(), "wrong!");
|
||||
|
||||
using seq_split = sequence_split<Type, I>;
|
||||
constexpr auto seq_left = typename seq_split::SeqType0{};
|
||||
@@ -158,7 +161,7 @@ struct Sequence
|
||||
return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
|
||||
}
|
||||
|
||||
template <class F>
|
||||
template <typename F>
|
||||
__host__ __device__ static constexpr auto Transform(F f)
|
||||
{
|
||||
return Sequence<f(Is)...>{};
|
||||
@@ -166,8 +169,11 @@ struct Sequence
|
||||
};
|
||||
|
||||
// merge sequence
|
||||
template <class, class>
|
||||
struct sequence_merge;
|
||||
template <typename Seq, typename... Seqs>
|
||||
struct sequence_merge
|
||||
{
|
||||
using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
|
||||
};
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
|
||||
@@ -175,8 +181,14 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
|
||||
using type = Sequence<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
template <typename Seq>
|
||||
struct sequence_merge<Seq>
|
||||
{
|
||||
using type = Seq;
|
||||
};
|
||||
|
||||
// generate sequence
|
||||
template <index_t IBegin, index_t NRemain, class F>
|
||||
template <index_t IBegin, index_t NRemain, typename F>
|
||||
struct sequence_gen_impl
|
||||
{
|
||||
static constexpr index_t NRemainLeft = NRemain / 2;
|
||||
@@ -188,20 +200,20 @@ struct sequence_gen_impl
|
||||
typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type;
|
||||
};
|
||||
|
||||
template <index_t I, class F>
|
||||
template <index_t I, typename F>
|
||||
struct sequence_gen_impl<I, 1, F>
|
||||
{
|
||||
static constexpr index_t Is = F{}(Number<I>{});
|
||||
using type = Sequence<Is>;
|
||||
};
|
||||
|
||||
template <index_t I, class F>
|
||||
template <index_t I, typename F>
|
||||
struct sequence_gen_impl<I, 0, F>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
template <index_t NSize, class F>
|
||||
template <index_t NSize, typename F>
|
||||
struct sequence_gen
|
||||
{
|
||||
using type = typename sequence_gen_impl<0, NSize, F>::type;
|
||||
@@ -235,10 +247,10 @@ struct uniform_sequence_gen
|
||||
};
|
||||
|
||||
// reverse inclusive scan (with init) sequence
|
||||
template <class, class, index_t>
|
||||
template <typename, typename, index_t>
|
||||
struct sequence_reverse_inclusive_scan;
|
||||
|
||||
template <index_t I, index_t... Is, class Reduce, index_t Init>
|
||||
template <index_t I, index_t... Is, typename Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
|
||||
{
|
||||
using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
|
||||
@@ -248,23 +260,23 @@ struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
|
||||
using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
|
||||
};
|
||||
|
||||
template <index_t I, class Reduce, index_t Init>
|
||||
template <index_t I, typename Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
|
||||
{
|
||||
using type = Sequence<Reduce{}(I, Init)>;
|
||||
};
|
||||
|
||||
template <class Reduce, index_t Init>
|
||||
template <typename Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
// split sequence
|
||||
template <class Seq, index_t I>
|
||||
template <typename Seq, index_t I>
|
||||
struct sequence_split
|
||||
{
|
||||
static constexpr index_t NSize = Seq{}.GetSize();
|
||||
static constexpr index_t NSize = Seq{}.Size();
|
||||
|
||||
using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
|
||||
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
|
||||
@@ -274,10 +286,10 @@ struct sequence_split
|
||||
};
|
||||
|
||||
// reverse sequence
|
||||
template <class Seq>
|
||||
template <typename Seq>
|
||||
struct sequence_reverse
|
||||
{
|
||||
static constexpr index_t NSize = Seq{}.GetSize();
|
||||
static constexpr index_t NSize = Seq{}.Size();
|
||||
|
||||
using seq_split = sequence_split<Seq, NSize / 2>;
|
||||
using type = typename sequence_merge<
|
||||
@@ -297,19 +309,102 @@ struct sequence_reverse<Sequence<I0, I1>>
|
||||
using type = Sequence<I1, I0>;
|
||||
};
|
||||
|
||||
template <class Seq, class Compare>
|
||||
template <typename Seq, typename Compare>
|
||||
struct sequence_sort
|
||||
{
|
||||
// not implemented
|
||||
template <typename SeqLeft, typename SeqRight, typename MergedSeq, typename Comp>
|
||||
struct sorted_sequence_merge_impl
|
||||
{
|
||||
static constexpr bool pick_left = SeqLeft::Front() < SeqRight::Front();
|
||||
static constexpr index_t next_value = pick_left ? SeqLeft::Front() : SeqRight::Front();
|
||||
|
||||
using new_merged_seq = decltype(MergedSeq::PushBack(Number<next_value>{}));
|
||||
|
||||
using new_left_seq =
|
||||
typename conditional<pick_left, decltype(SeqLeft::PopFront()), SeqLeft>::type;
|
||||
using new_right_seq =
|
||||
typename conditional<pick_left, SeqRight, decltype(SeqRight::PopFront())>::type;
|
||||
|
||||
using type =
|
||||
typename sorted_sequence_merge_impl<new_left_seq, new_right_seq, new_merged_seq, Comp>::
|
||||
type;
|
||||
};
|
||||
|
||||
template <typename SeqLeft, typename MergedSeq, typename Comp>
|
||||
struct sorted_sequence_merge_impl<SeqLeft, Sequence<>, MergedSeq, Comp>
|
||||
{
|
||||
using type = typename sequence_merge<MergedSeq, SeqLeft>::type;
|
||||
};
|
||||
|
||||
template <typename SeqRight, typename MergedSeq, typename Comp>
|
||||
struct sorted_sequence_merge_impl<Sequence<>, SeqRight, MergedSeq, Comp>
|
||||
{
|
||||
using type = typename sequence_merge<MergedSeq, SeqRight>::type;
|
||||
};
|
||||
|
||||
template <typename Seq0, typename Seq1, typename Comp>
|
||||
struct sorted_sequence_merge
|
||||
{
|
||||
using type = typename sorted_sequence_merge_impl<Seq0, Seq1, Sequence<>, Comp>::type;
|
||||
};
|
||||
|
||||
using split = sequence_split<Seq, Seq::Size() / 2>;
|
||||
using unsorted_left = typename split::SeqType0;
|
||||
using unsorted_right = typename split::SeqType1;
|
||||
|
||||
using sorted_left = typename sequence_sort<unsorted_left, Compare>::type;
|
||||
using sorted_right = typename sequence_sort<unsorted_right, Compare>::type;
|
||||
|
||||
using type = typename sorted_sequence_merge<sorted_left, sorted_right, Compare>::type;
|
||||
};
|
||||
|
||||
template <class Seq, class Compare>
|
||||
template <index_t X, index_t Y, typename Compare>
|
||||
struct sequence_sort<Sequence<X, Y>, Compare>
|
||||
{
|
||||
static constexpr bool x_first = Compare{}(X, Y);
|
||||
|
||||
using type = typename conditional<x_first, Sequence<X, Y>, Sequence<Y, X>>::type;
|
||||
};
|
||||
|
||||
template <index_t X, typename Compare>
|
||||
struct sequence_sort<Sequence<X>, Compare>
|
||||
{
|
||||
using type = Sequence<X>;
|
||||
};
|
||||
|
||||
template <typename Seq, typename Less, typename Equal>
|
||||
struct sequence_unique_sort
|
||||
{
|
||||
// not implemented
|
||||
template <typename WorkInputSeq, typename WorkOutputSeq, typename Eq>
|
||||
struct sorted_sequence_uniquify_impl
|
||||
{
|
||||
static constexpr index_t new_value = WorkInputSeq::Front();
|
||||
using new_work_input_seq = decltype(WorkInputSeq::PopFront());
|
||||
|
||||
using new_working_output_seq =
|
||||
typename conditional<new_value == WorkOutputSeq::Back(),
|
||||
WorkOutputSeq,
|
||||
decltype(WorkOutputSeq::PopBack(Number<new_value>{}))>::type;
|
||||
};
|
||||
|
||||
template <typename WorkInputSeq, typename Eq>
|
||||
struct sorted_sequence_uniquify_impl<WorkInputSeq, Sequence<>, Eq>
|
||||
{
|
||||
using type = WorkInputSeq;
|
||||
};
|
||||
|
||||
template <typename SortedSeq, typename Eq>
|
||||
struct sorted_sequence_uniquify
|
||||
{
|
||||
using type = typename sorted_sequence_uniquify_impl<SortedSeq, Sequence<>, Eq>::type;
|
||||
};
|
||||
|
||||
using sorted_seq = typename sequence_sort<Seq, Less>::type;
|
||||
|
||||
using type = typename sorted_sequence_uniquify<sorted_seq, Equal>::type;
|
||||
};
|
||||
|
||||
template <class Seq>
|
||||
template <typename Seq>
|
||||
struct is_valid_sequence_map
|
||||
{
|
||||
// not implemented yet, always return true
|
||||
@@ -317,36 +412,35 @@ struct is_valid_sequence_map
|
||||
|
||||
// TODO: add proper check for is_valid, something like:
|
||||
// static constexpr bool value =
|
||||
// is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::type,
|
||||
// is_same<typename arithmetic_sequence_gen<0, Seq::Size(), 1>::type,
|
||||
// typename sequence_sort<Seq>::SortedSeqType>{};
|
||||
};
|
||||
|
||||
template <class X2Y, class WorkingY2X, index_t XBegin, index_t XRemain>
|
||||
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
|
||||
struct sequence_map_inverse_impl
|
||||
{
|
||||
private:
|
||||
static constexpr auto new_y2x =
|
||||
WorkingY2X::Modify(X2Y::Get(Number<XBegin>{}), Number<XBegin>{});
|
||||
static constexpr auto new_y2x = WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
|
||||
|
||||
public:
|
||||
using type =
|
||||
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
|
||||
};
|
||||
|
||||
template <class X2Y, class WorkingY2X, index_t XBegin>
|
||||
template <typename X2Y, typename WorkingY2X, index_t XBegin>
|
||||
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
|
||||
{
|
||||
using type = WorkingY2X;
|
||||
};
|
||||
|
||||
template <class X2Y>
|
||||
template <typename X2Y>
|
||||
struct sequence_map_inverse
|
||||
{
|
||||
using type =
|
||||
typename sequence_map_inverse_impl<X2Y,
|
||||
typename uniform_sequence_gen<X2Y::GetSize(), 0>::type,
|
||||
typename uniform_sequence_gen<X2Y::Size(), 0>::type,
|
||||
0,
|
||||
X2Y::GetSize()>::type;
|
||||
X2Y::Size()>::type;
|
||||
};
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
@@ -457,20 +551,26 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
|
||||
return Sequence<Is...>{};
|
||||
}
|
||||
|
||||
template <class Seq>
|
||||
template <typename Seq>
|
||||
__host__ __device__ constexpr auto sequence_pop_back(Seq)
|
||||
{
|
||||
static_assert(Seq::GetSize() > 0, "wrong! cannot pop an empty Sequence!");
|
||||
static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
|
||||
return sequence_pop_front(Seq::Reverse()).Reverse();
|
||||
}
|
||||
|
||||
template <class F, index_t... Xs>
|
||||
template <typename F, index_t... Xs>
|
||||
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<f(Xs)...>{};
|
||||
}
|
||||
|
||||
template <class F, index_t... Xs, index_t... Ys>
|
||||
template <typename... Seqs>
|
||||
__host__ __device__ constexpr auto merge_sequences(Seqs...)
|
||||
{
|
||||
return typename sequence_merge<Seqs...>::type{};
|
||||
}
|
||||
|
||||
template <typename F, index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
|
||||
{
|
||||
static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
|
||||
@@ -478,7 +578,7 @@ __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Seq
|
||||
return Sequence<f(Xs, Ys)...>{};
|
||||
}
|
||||
|
||||
template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
|
||||
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
|
||||
{
|
||||
@@ -489,19 +589,19 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
|
||||
return Sequence<f(Xs, Ys, Zs)...>{};
|
||||
}
|
||||
|
||||
template <class Seq, class Reduce, index_t Init>
|
||||
template <typename Seq, typename Reduce, index_t Init>
|
||||
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
|
||||
{
|
||||
return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
|
||||
}
|
||||
|
||||
template <class Seq, class Reduce, index_t Init>
|
||||
template <typename Seq, typename Reduce, index_t Init>
|
||||
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
|
||||
{
|
||||
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
|
||||
}
|
||||
|
||||
template <class Seq, class Reduce>
|
||||
template <typename Seq, typename Reduce>
|
||||
struct lambda_accumulate_on_sequence
|
||||
{
|
||||
const Reduce& f;
|
||||
@@ -512,14 +612,14 @@ struct lambda_accumulate_on_sequence
|
||||
{
|
||||
}
|
||||
|
||||
template <class IDim>
|
||||
template <typename IDim>
|
||||
__host__ __device__ constexpr index_t operator()(IDim) const
|
||||
{
|
||||
return result = f(result, Seq::Get(IDim{}));
|
||||
return result = f(result, Seq::At(IDim{}));
|
||||
}
|
||||
};
|
||||
|
||||
template <class Seq, class Reduce, index_t Init>
|
||||
template <typename Seq, typename Reduce, index_t Init>
|
||||
__host__ __device__ constexpr index_t
|
||||
accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
|
||||
{
|
||||
@@ -530,41 +630,5 @@ accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
|
||||
return result;
|
||||
}
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
|
||||
{
|
||||
constexpr index_t nsize = Sequence<Xs...>::GetSize();
|
||||
|
||||
static_assert(nsize <= 10, "wrong!");
|
||||
|
||||
static_if<nsize == 0>{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 3>{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 4>{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 5>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 6>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 7>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 8>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 9>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 10>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
93
composable_kernel/include/utility/array_helper.hpp
Normal file
93
composable_kernel/include/utility/array_helper.hpp
Normal file
@@ -0,0 +1,93 @@
|
||||
#ifndef CK_ARRAY_HELPER_HPP
|
||||
#define CK_ARRAY_HELPER_HPP
|
||||
|
||||
#include "Array.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename T, index_t NSize>
|
||||
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
|
||||
{
|
||||
constexpr index_t nsize = a.GetSize();
|
||||
|
||||
static_assert(nsize > 0 && nsize <= 10, "wrong!");
|
||||
|
||||
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); });
|
||||
|
||||
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); });
|
||||
|
||||
static_if<nsize == 3>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); });
|
||||
|
||||
static_if<nsize == 4>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); });
|
||||
|
||||
static_if<nsize == 5>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
|
||||
});
|
||||
|
||||
static_if<nsize == 6>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
|
||||
});
|
||||
|
||||
static_if<nsize == 7>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6]);
|
||||
});
|
||||
|
||||
static_if<nsize == 8>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6],
|
||||
a[7]);
|
||||
});
|
||||
|
||||
static_if<nsize == 9>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6],
|
||||
a[7],
|
||||
a[8]);
|
||||
});
|
||||
|
||||
static_if<nsize == 10>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6],
|
||||
a[7],
|
||||
a[8],
|
||||
a[9]);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -4,14 +4,19 @@
|
||||
#include "config.hpp"
|
||||
#include "utility.hpp"
|
||||
#include "integral_constant.hpp"
|
||||
#include "number.hpp"
|
||||
#include "type.hpp"
|
||||
#include "tuple.hpp"
|
||||
#include "math.hpp"
|
||||
#include "vector_type.hpp"
|
||||
#include "Sequence.hpp"
|
||||
#include "sequence_helper.hpp"
|
||||
#include "Array.hpp"
|
||||
#include "array_helper.hpp"
|
||||
#include "functional.hpp"
|
||||
#include "functional2.hpp"
|
||||
#include "functional3.hpp"
|
||||
#include "functional4.hpp"
|
||||
|
||||
#if CK_USE_AMD_INLINE_ASM
|
||||
#include "amd_inline_asm.hpp"
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "Sequence.hpp"
|
||||
#include "type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// TODO: right? wrong?
|
||||
struct forwarder
|
||||
{
|
||||
template <typename T>
|
||||
@@ -17,7 +19,7 @@ struct forwarder
|
||||
|
||||
struct swallow
|
||||
{
|
||||
template <class... Ts>
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr swallow(Ts&&...)
|
||||
{
|
||||
}
|
||||
@@ -32,7 +34,7 @@ struct static_if<true>
|
||||
{
|
||||
using Type = static_if<true>;
|
||||
|
||||
template <class F>
|
||||
template <typename F>
|
||||
__host__ __device__ constexpr auto operator()(F f) const
|
||||
{
|
||||
// This is a trick for compiler:
|
||||
@@ -43,7 +45,7 @@ struct static_if<true>
|
||||
return Type{};
|
||||
}
|
||||
|
||||
template <class F>
|
||||
template <typename F>
|
||||
__host__ __device__ static constexpr auto Else(F)
|
||||
{
|
||||
return Type{};
|
||||
@@ -55,13 +57,13 @@ struct static_if<false>
|
||||
{
|
||||
using Type = static_if<false>;
|
||||
|
||||
template <class F>
|
||||
template <typename F>
|
||||
__host__ __device__ constexpr auto operator()(F) const
|
||||
{
|
||||
return Type{};
|
||||
}
|
||||
|
||||
template <class F>
|
||||
template <typename F>
|
||||
__host__ __device__ static constexpr auto Else(F f)
|
||||
{
|
||||
// This is a trick for compiler:
|
||||
@@ -73,5 +75,23 @@ struct static_if<false>
|
||||
}
|
||||
};
|
||||
|
||||
template <bool predicate, class X, class Y>
|
||||
struct conditional;
|
||||
|
||||
template <class X, class Y>
|
||||
struct conditional<true, X, Y>
|
||||
{
|
||||
using type = X;
|
||||
};
|
||||
|
||||
template <class X, class Y>
|
||||
struct conditional<false, X, Y>
|
||||
{
|
||||
using type = Y;
|
||||
};
|
||||
|
||||
template <bool predicate, class X, class Y>
|
||||
using conditional_t = typename conditional<predicate, X, Y>::type;
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class>
|
||||
struct static_for_impl;
|
||||
|
||||
@@ -19,6 +21,8 @@ struct static_for_impl<Sequence<Is...>>
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// F signature: F(Number<Iter>)
|
||||
template <index_t NBegin, index_t NEnd, index_t Increment>
|
||||
struct static_for
|
||||
@@ -33,7 +37,8 @@ struct static_for
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(f);
|
||||
detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
|
||||
f);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -8,20 +8,7 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class>
|
||||
struct is_static : integral_constant<bool, false>
|
||||
{
|
||||
};
|
||||
|
||||
template <class T, T X>
|
||||
struct is_static<integral_constant<T, X>> : integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t... Is>
|
||||
struct is_static<Sequence<Is...>> : integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
namespace detail {
|
||||
|
||||
// RemainLengths: Sequence<...>
|
||||
// Orders: Sequence<...>
|
||||
@@ -58,29 +45,6 @@ struct static_ford_impl<Sequence<>, Orders>
|
||||
}
|
||||
};
|
||||
|
||||
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
|
||||
// Orders is Sequence<...>, it is the order of dimension in which static_ford will loop over each
|
||||
// dimension
|
||||
template <class Lengths,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
|
||||
struct static_ford
|
||||
{
|
||||
__host__ __device__ constexpr static_ford()
|
||||
{
|
||||
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
|
||||
}
|
||||
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
|
||||
static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
|
||||
}
|
||||
};
|
||||
|
||||
// RemainLengths: Sequence<...>
|
||||
// Orders: Sequence<...>
|
||||
template <class RemainLengths, class Orders>
|
||||
@@ -117,6 +81,31 @@ struct ford_impl<Sequence<>, Orders>
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
|
||||
// Orders is Sequence<...>, it is the order of dimension in which static_ford will loop over each
|
||||
// dimension
|
||||
template <class Lengths,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
|
||||
struct static_ford
|
||||
{
|
||||
__host__ __device__ constexpr static_ford()
|
||||
{
|
||||
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
|
||||
}
|
||||
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
|
||||
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
|
||||
}
|
||||
};
|
||||
|
||||
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
|
||||
// Orders is Sequence<...>, it is the order of dimension in which ford will loop over each
|
||||
// dimension
|
||||
@@ -139,7 +128,8 @@ struct ford
|
||||
|
||||
for(index_t i = 0; i < ordered_lengths.Front(); ++i)
|
||||
{
|
||||
ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f, Array<index_t, 1>{i});
|
||||
detail::ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f,
|
||||
Array<index_t, 1>{i});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
34
composable_kernel/include/utility/functional4.hpp
Normal file
34
composable_kernel/include/utility/functional4.hpp
Normal file
@@ -0,0 +1,34 @@
|
||||
#ifndef CK_FUNCTIONAL4_HPP
|
||||
#define CK_FUNCTIONAL4_HPP
|
||||
|
||||
#include "Sequence.hpp"
|
||||
#include "tuple.hpp"
|
||||
#include "Array.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename Indices>
|
||||
struct unpack_impl;
|
||||
|
||||
template <index_t... Is>
|
||||
struct unpack_impl<Sequence<Is...>>
|
||||
{
|
||||
template <typename F, typename X>
|
||||
__host__ __device__ constexpr auto operator()(F f, const X& x) const
|
||||
{
|
||||
return f(x.At(Number<Is>{})...);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename F, typename X>
|
||||
__host__ __device__ constexpr auto unpack(F f, const X& x)
|
||||
{
|
||||
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X::Size(), 1>::type>{}(f, x);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -13,54 +13,5 @@ struct integral_constant
|
||||
__host__ __device__ constexpr value_type operator()() const noexcept { return value; }
|
||||
};
|
||||
|
||||
template <class X, class Y>
|
||||
struct is_same : public integral_constant<bool, false>
|
||||
{
|
||||
};
|
||||
|
||||
template <class X>
|
||||
struct is_same<X, X> : public integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
|
||||
template <class T>
|
||||
using remove_cv_t = typename std::remove_cv<T>::type;
|
||||
|
||||
template <index_t N>
|
||||
using Number = integral_constant<index_t, N>;
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
|
||||
{
|
||||
return Number<X + Y>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
|
||||
{
|
||||
static_assert(Y <= X, "wrong!");
|
||||
return Number<X - Y>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
|
||||
{
|
||||
return Number<X * Y>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
|
||||
{
|
||||
static_assert(Y > 0, "wrong!");
|
||||
return Number<X / Y>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
|
||||
{
|
||||
static_assert(Y > 0, "wrong!");
|
||||
return Number<X % Y>{};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -104,6 +104,18 @@ __host__ __device__ constexpr T lcm(T x, Ts... xs)
|
||||
return max(x, xs...);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct equal
|
||||
{
|
||||
__host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct less
|
||||
{
|
||||
__host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
|
||||
};
|
||||
|
||||
} // namespace math
|
||||
} // namspace ck
|
||||
|
||||
|
||||
44
composable_kernel/include/utility/number.hpp
Normal file
44
composable_kernel/include/utility/number.hpp
Normal file
@@ -0,0 +1,44 @@
|
||||
#ifndef CK_NUMBER_HPP
|
||||
#define CK_NUMBER_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t N>
|
||||
using Number = integral_constant<index_t, N>;
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
|
||||
{
|
||||
return Number<X + Y>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
|
||||
{
|
||||
static_assert(Y <= X, "wrong!");
|
||||
return Number<X - Y>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
|
||||
{
|
||||
return Number<X * Y>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
|
||||
{
|
||||
static_assert(Y > 0, "wrong!");
|
||||
return Number<X / Y>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
|
||||
{
|
||||
static_assert(Y > 0, "wrong!");
|
||||
return Number<X % Y>{};
|
||||
}
|
||||
} // namespace ck
|
||||
#endif
|
||||
46
composable_kernel/include/utility/sequence_helper.hpp
Normal file
46
composable_kernel/include/utility/sequence_helper.hpp
Normal file
@@ -0,0 +1,46 @@
|
||||
#ifndef CK_SEQUENCE_HELPER_HPP
|
||||
#define CK_SEQUENCE_HELPER_HPP
|
||||
|
||||
#include "Sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
|
||||
{
|
||||
constexpr index_t nsize = Sequence<Xs...>::Size();
|
||||
|
||||
static_assert(nsize <= 10, "wrong!");
|
||||
|
||||
static_if<nsize == 0>{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 3>{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 4>{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 5>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 6>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 7>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 8>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 9>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 10>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -2,6 +2,7 @@
|
||||
#define CK_TUPLE_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "type.hpp"
|
||||
#include "Sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -16,6 +17,8 @@ struct TupleElementKey
|
||||
template <typename Key, typename Data>
|
||||
struct TupleElement
|
||||
{
|
||||
__host__ __device__ explicit constexpr TupleElement() : mData() {}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ explicit constexpr TupleElement(T&& v) : mData(static_cast<T&&>(v))
|
||||
{
|
||||
@@ -48,6 +51,12 @@ struct TupleImpl;
|
||||
template <index_t... Is, typename... Xs>
|
||||
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
|
||||
{
|
||||
#if 1
|
||||
__host__ __device__ explicit constexpr TupleImpl() : TupleElement<TupleElementKey<Is>, Xs>()...
|
||||
{
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename... Ys>
|
||||
__host__ __device__ explicit constexpr TupleImpl(Ys&&... ys)
|
||||
: TupleElement<TupleElementKey<Is>, Xs>(static_cast<Ys&&>(ys))...
|
||||
@@ -97,5 +106,28 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... Xs>
|
||||
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
|
||||
{
|
||||
return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename X, typename F, index_t... Is>
|
||||
__host__ __device__ constexpr auto transpose_tuple_impl(X& x, F f, Sequence<Is...>)
|
||||
{
|
||||
return make_tuple(f(x.At(Number<Is>{}))...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename X, typename F>
|
||||
__host__ __device__ constexpr auto transpose_tuple(X& x, F f)
|
||||
{
|
||||
return detail::transpose_tuple_impl(
|
||||
x, f, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
41
composable_kernel/include/utility/type.hpp
Normal file
41
composable_kernel/include/utility/type.hpp
Normal file
@@ -0,0 +1,41 @@
|
||||
#ifndef CK_TYPE_HPP
|
||||
#define CK_TYPE_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "Sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct is_same : public integral_constant<bool, false>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename X>
|
||||
struct is_same<X, X> : public integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename>
|
||||
struct is_static : integral_constant<bool, false>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T, T X>
|
||||
struct is_static<integral_constant<T, X>> : integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t... Is>
|
||||
struct is_static<Sequence<Is...>> : integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using remove_reference_t = typename std::remove_reference<T>::type;
|
||||
|
||||
template <typename T>
|
||||
using remove_cv_t = typename std::remove_cv<T>::type;
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -115,8 +115,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
|
||||
constexpr index_t OutThreadCopyDataPerAccess_N = 4;
|
||||
#endif
|
||||
|
||||
#if 0 // debug
|
||||
constexpr index_t GridSize =
|
||||
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
|
||||
#else
|
||||
constexpr index_t GridSize = 1;
|
||||
#endif
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
|
||||
@@ -73,19 +73,19 @@ int main(int argc, char* argv[])
|
||||
using namespace ck;
|
||||
|
||||
#if 1
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 1536;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t N = 10;
|
||||
constexpr index_t C = 10;
|
||||
constexpr index_t HI = 10;
|
||||
constexpr index_t WI = 10;
|
||||
constexpr index_t K = 10;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
constexpr index_t HPad = 2;
|
||||
constexpr index_t WPad = 2;
|
||||
#elif 1
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
|
||||
Reference in New Issue
Block a user