mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
adding merge transform
This commit is contained in:
@@ -528,34 +528,64 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
|
||||
#elif 1
|
||||
// create a native tensor descriptor
|
||||
constexpr auto in_c_h_w_n_global_desc =
|
||||
make_NativeTensorDescriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
|
||||
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
|
||||
|
||||
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
|
||||
constexpr index_t Hi = in_c_h_w_n_global_desc.GetLength(I1);
|
||||
constexpr index_t Wi = in_c_h_w_n_global_desc.GetLength(I2);
|
||||
constexpr index_t N = in_c_h_w_n_global_desc.GetLength(I3);
|
||||
|
||||
constexpr auto pad_h_w = Pad<Sequence<Hi, Wi>, LowerPads, UpperPads>{};
|
||||
constexpr auto pass_c = PassThrough<C>{};
|
||||
constexpr auto pass_n = PassThrough<N>{};
|
||||
// transformation: {c, h, w, n} --> {n, c, hp, wp}
|
||||
// {h, w} --> {hp, wp}, {c} --> {c}, {n} --> {n}
|
||||
constexpr auto in_n_c_hp_wp_global_desc = transform_tensor_descriptor(
|
||||
in_c_h_w_n_global_desc,
|
||||
make_tuple(
|
||||
Pad<Sequence<Hi, Wi>, LowerPads, UpperPads>{}, PassThrough<C>{}, PassThrough<N>{}),
|
||||
make_tuple(Sequence<1, 2>{}, Sequence<0>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<2, 3>{}, Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
constexpr auto trans = make_tuple(pass_c, pad_h_w, pass_n);
|
||||
constexpr auto lower_dim_groups =
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{});
|
||||
constexpr auto upper_dim_groups =
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{});
|
||||
|
||||
constexpr auto in_c_h_w_n_padded_global_desc = transform_tensor_descriptor(
|
||||
in_c_h_w_n_global_desc, trans, lower_dim_groups, upper_dim_groups);
|
||||
#if 1
|
||||
// transformation: {n, c, hp, wp} --> {c, b}
|
||||
// {n, hp, wp} --> {b}, {c} --> {c}
|
||||
constexpr auto in_c_b_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hp_wp_global_desc,
|
||||
make_tuple(Merge<decltype(in_n_c_hp_wp_global_desc.GetLengths(I0, I2, I3))>{},
|
||||
PassThrough<in_n_c_hp_wp_global_desc.GetLength(I1)>{}),
|
||||
make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
// 0
|
||||
print_tensor_descriptor("in_c_h_w_n_global_desc", in_c_h_w_n_global_desc);
|
||||
|
||||
printf("offset: %lu\n", in_c_h_w_n_global_desc.GetOffset({1, 2, 3, 4}));
|
||||
// 1
|
||||
print_tensor_descriptor("in_n_c_hp_wp_global_desc", in_n_c_hp_wp_global_desc);
|
||||
|
||||
printf("padded offset: %lu\n", in_c_h_w_n_padded_global_desc.GetOffset({1, 4, 5, 4}));
|
||||
// 2
|
||||
print_tensor_descriptor("in_c_b_global_desc", in_c_b_global_desc);
|
||||
|
||||
constexpr auto idx2 = MultiIndex<2>{1, 4 * (16 * 16) + 5 * 16 + 6};
|
||||
auto idx1 = in_c_b_global_desc.CalculateLowerIndex(idx2);
|
||||
auto idx0 = in_c_b_global_desc.GetLowerTensorDescriptor().CalculateLowerIndex(idx1);
|
||||
|
||||
print_array("idx2: ", idx2);
|
||||
print_array("idx1: ", idx1);
|
||||
print_array("idx0: ", idx0);
|
||||
|
||||
printf("in_c_b_global_desc offset: %lu\n", in_c_b_global_desc.CalculateOffset(idx2));
|
||||
}
|
||||
#else
|
||||
{
|
||||
index_t c = static_cast<index_t>(threadIdx.x);
|
||||
index_t h = static_cast<index_t>(threadIdx.y);
|
||||
index_t w = static_cast<index_t>(threadIdx.z);
|
||||
|
||||
p_out_global[0] = in_n_c_h_w_padded_global_desc.CalculateOffset({1, c, h, w});
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -18,9 +18,9 @@ struct NativeDimension
|
||||
|
||||
__host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffset(index_t i) { return i * Stride; }
|
||||
__host__ __device__ static constexpr index_t CalculateOffset(index_t i) { return i * Stride; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffsetDiff(index_t i_diff)
|
||||
__host__ __device__ static constexpr index_t CalculateOffsetDiff(index_t i_diff)
|
||||
{
|
||||
return i_diff * Stride;
|
||||
}
|
||||
|
||||
@@ -22,9 +22,12 @@ struct PassThrough
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) { return idx_up; }
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndex(UpperIndex idx_up)
|
||||
{
|
||||
return idx_up;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndexDiff(UpperIndex idx_up_diff)
|
||||
{
|
||||
return idx_up_diff;
|
||||
}
|
||||
@@ -36,7 +39,7 @@ struct PassThrough
|
||||
template <typename LowLengths, typename LeftPads, typename RightPads>
|
||||
struct Pad
|
||||
{
|
||||
static constexpr index_t nDim = LowLengths::GetSize();
|
||||
static constexpr index_t nDim = LowLengths::Size();
|
||||
|
||||
using LowerIndex = MultiIndex<nDim>;
|
||||
using UpperIndex = MultiIndex<nDim>;
|
||||
@@ -52,12 +55,12 @@ struct Pad
|
||||
return GetLowerLengths() + LeftPads{} + RightPads{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndex(UpperIndex idx_up)
|
||||
{
|
||||
return idx_up - LeftPads{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndexDiff(UpperIndex idx_up_diff)
|
||||
{
|
||||
return idx_up_diff;
|
||||
}
|
||||
@@ -65,21 +68,20 @@ struct Pad
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
};
|
||||
|
||||
#if 0
|
||||
// LowLengths: Sequence<...>
|
||||
template <typename LowLengths>
|
||||
struct Merge
|
||||
{
|
||||
static constexpr index_t nDimLow = LowLengths::GetSize();
|
||||
static constexpr index_t nDimLow = LowLengths::Size();
|
||||
static constexpr index_t nDimUp = 1;
|
||||
|
||||
using LowerIndex = MultiIndex<nDimLow>;
|
||||
using UpperIndex = MultiIndex<nDimUp>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension(){return Number<nDimUp>{}};
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths()
|
||||
@@ -88,18 +90,56 @@ struct Merge
|
||||
GetLowerLengths(), math::multiplies<index_t>{}, Number<1>{})>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
|
||||
// emulate constexpr lambda
|
||||
template <typename PseudoLowStrides>
|
||||
struct lambda_CalculateLowerIndex
|
||||
{
|
||||
index_t& itmp;
|
||||
LowerIndex& idx_low;
|
||||
|
||||
__host__ __device__ explicit constexpr lambda_CalculateLowerIndex(index_t& itmp_,
|
||||
LowerIndex& idx_low_)
|
||||
: itmp(itmp_), idx_low(idx_low_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename IDim>
|
||||
__host__ __device__ constexpr void operator()(IDim idim) const
|
||||
{
|
||||
constexpr index_t stride = PseudoLowStrides::At(idim);
|
||||
idx_low(idim) = itmp / stride;
|
||||
itmp -= idx_low[idim] * stride;
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
|
||||
{
|
||||
LowerIndex idx_low;
|
||||
|
||||
// not implemeneted
|
||||
index_t itmp = idx_up[0];
|
||||
|
||||
constexpr auto pseudo_low_strides =
|
||||
reverse_inclusive_scan_sequence(
|
||||
GetLowerLengths().PopFront(), math::multiplies<index_t>{}, Number<1>{})
|
||||
.PushBack(Number<1>{});
|
||||
|
||||
// calculate index in each of the dimensions in the order of their dimension
|
||||
#if 1
|
||||
static_for<0, nDimLow - 1, 1>{}(
|
||||
lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));
|
||||
|
||||
idx_low(nDimLow - 1) = itmp / pseudo_low_strides[nDimLow - 1];
|
||||
#else
|
||||
static_for<0, nDimLow, 1>{}(
|
||||
lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));
|
||||
#endif
|
||||
|
||||
return idx_low;
|
||||
}
|
||||
|
||||
// idx_low_diff depends on idx_low_old, so idx_low need to be up-to-date
|
||||
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff,
|
||||
LowerIndex idx_low_old)
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
|
||||
const LowerIndex& idx_low_old)
|
||||
{
|
||||
LowerIndex idx_low_diff;
|
||||
|
||||
@@ -110,49 +150,48 @@ struct Merge
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
|
||||
};
|
||||
#endif
|
||||
|
||||
// UpLengths: Sequence<...>
|
||||
template <index_t LowLength, typename UpLengths>
|
||||
template <typename UpLengths>
|
||||
struct Unmerge
|
||||
{
|
||||
static constexpr index_t nDimLow = 1;
|
||||
static constexpr index_t nDimUp = UpLengths::GetSize();
|
||||
static constexpr index_t nDimUp = UpLengths::Size();
|
||||
|
||||
using UpperIndex = MultiIndex<nDimUp>;
|
||||
using LowerIndex = MultiIndex<nDimLow>;
|
||||
|
||||
__host__ __device__ constexpr Unmerge()
|
||||
{
|
||||
static_assert(LowLength == accumulate_on_sequence(
|
||||
UpLengths{}, math::multiplies<index_t>{}, Number<1>{}),
|
||||
"wrong! UpLengths need to be ");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
|
||||
using UpperIndex = MultiIndex<nDimUp>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths()
|
||||
{
|
||||
constexpr index_t low_length =
|
||||
accumulate_on_sequence(UpLengths{}, math::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
return Sequence<low_length>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
|
||||
{
|
||||
constexpr auto scans = typename sequence_reverse_inclusive_scan<UpLengths,
|
||||
math::multiplies<index_t>,
|
||||
1>::type{};
|
||||
|
||||
LowerIndex idx_low{0};
|
||||
|
||||
static_for<0, nDimUp, 1>{}([&](auto idim) { idx_low(0) += idx_up[idim] * scans[idim]; });
|
||||
constexpr auto pseudo_up_strides =
|
||||
typename sequence_reverse_inclusive_scan<UpLengths, math::multiplies<index_t>, 1>::
|
||||
type{};
|
||||
|
||||
static_for<0, nDimUp, 1>{}(
|
||||
[&](auto idim) { idx_low(0) += idx_up[idim] * pseudo_up_strides[idim]; });
|
||||
|
||||
return idx_low;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff)
|
||||
{
|
||||
return GetLowerIndex(idx_up_diff);
|
||||
return CalculateLowerIndex(idx_up_diff);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
@@ -165,12 +204,12 @@ template <index_t LowLength, typename UpLengths, typename Coefficients>
|
||||
struct Embed
|
||||
{
|
||||
static constexpr index_t nDimLow = 1;
|
||||
static constexpr index_t nDimUp = UpLengths::GetSize();
|
||||
static constexpr index_t nDimUp = UpLengths::Size();
|
||||
|
||||
using LowerIndex = MultiIndex<nDimLow>;
|
||||
using UpperIndex = MultiIndex<nDimUp>;
|
||||
|
||||
__host__ __device__ constexpr Embed()
|
||||
__host__ __device__ explicit constexpr Embed()
|
||||
{
|
||||
static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
|
||||
"wrong! # of dimensions not consistent");
|
||||
@@ -191,7 +230,7 @@ struct Embed
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
|
||||
{
|
||||
LowerIndex idx_low(Coefficients{}[nDimUp]);
|
||||
|
||||
@@ -201,7 +240,7 @@ struct Embed
|
||||
return idx_low;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff)
|
||||
{
|
||||
LowerIndex idx_low_diff{0};
|
||||
|
||||
|
||||
@@ -18,34 +18,6 @@ struct NativeTensorDescriptor
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
|
||||
|
||||
struct lambda_GetLength
|
||||
{
|
||||
template <typename IDim>
|
||||
__host__ __device__ constexpr auto operator()(IDim) const
|
||||
{
|
||||
return GetLength(IDim{});
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths()
|
||||
{
|
||||
return typename sequence_gen<nDim, lambda_GetLength>::type{};
|
||||
}
|
||||
|
||||
struct lambda_GetStride
|
||||
{
|
||||
template <typename IDim>
|
||||
__host__ __device__ constexpr auto operator()(IDim) const
|
||||
{
|
||||
return GetStride(IDim{});
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ static constexpr auto GetStrides()
|
||||
{
|
||||
return typename sequence_gen<nDim, lambda_GetStride>::type{};
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
|
||||
{
|
||||
@@ -58,7 +30,41 @@ struct NativeTensorDescriptor
|
||||
return mDimensions.At(Number<IDim>{}).GetStride();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffset(const Index& idx)
|
||||
template <index_t... IDims>
|
||||
__host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
|
||||
{
|
||||
return Sequence<GetLength(Number<IDims>{})...>{};
|
||||
}
|
||||
|
||||
template <index_t... IDims>
|
||||
__host__ __device__ static constexpr auto GetStrides(Sequence<IDims...>)
|
||||
{
|
||||
return Sequence<GetStride(Number<IDims>{})...>{};
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t... IDims>
|
||||
__host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
|
||||
{
|
||||
return GetLengths(Sequence<IDim, IDims...>{});
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t... IDims>
|
||||
__host__ __device__ static constexpr auto GetStrides(Number<IDim>, Number<IDims>...)
|
||||
{
|
||||
return GetStrides(Sequence<IDim, IDims...>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths()
|
||||
{
|
||||
return GetLengths(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetStrides()
|
||||
{
|
||||
return GetStrides(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t CalculateOffset(const Index& idx)
|
||||
{
|
||||
index_t offset = 0;
|
||||
|
||||
@@ -67,7 +73,7 @@ struct NativeTensorDescriptor
|
||||
return offset;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffsetDiff(const Index& idx_diff)
|
||||
__host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
|
||||
{
|
||||
index_t offset_diff = 0;
|
||||
|
||||
@@ -161,8 +167,10 @@ struct TransformedTensorDescriptor
|
||||
// UpDimensionIds should include all up-dimensions
|
||||
|
||||
// TODO: sanity check: while a up-dimension could be associated with multille
|
||||
// transformation,
|
||||
// a low-dimension should be associated with only one transformation
|
||||
// transformation, a low-dimension should be associated with only one transformation
|
||||
|
||||
// TODO: sanity-check: GetLowerLengths of each transform should be consistent with lengths
|
||||
// of lower-tensor-descriptor
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfDimension()
|
||||
@@ -170,49 +178,78 @@ struct TransformedTensorDescriptor
|
||||
return GetNumOfUpperDimension();
|
||||
}
|
||||
|
||||
#if 0
|
||||
__host__ __device__ static constexpr auto GetUpperLengths()
|
||||
{
|
||||
struct lambda_get_upper_lengths
|
||||
{
|
||||
template <typename Transform>
|
||||
__host__ __device__ constexpr auto operator()(Transform tran) const
|
||||
{
|
||||
return tran.GetUpperLengths();
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto tuple_of_upper_lengths =
|
||||
transform_tuple(Transforms, lambda_get_upper_lengths{});
|
||||
|
||||
constexpr auto all_upper_lengths = merge_tuple_of_sequences(tuple_of_upper_lengths);
|
||||
|
||||
constexpr auto all_upper_dimension_ids = merge_tuple_of_sequences(UpDimensionIds{});
|
||||
|
||||
// TODO: sanity-check all_upper_dimension_ids contain all upper-dimensions
|
||||
// TODO: sanity-check all_upper_lengths have no conflicting upper-length
|
||||
|
||||
using sort_dimension_ids =
|
||||
sequence_unique_sort<decltype(all_upper_dimension_ids), math::less<index_t>>;
|
||||
|
||||
constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type;
|
||||
constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type;
|
||||
|
||||
constexpr auto sorted_upper_lengths =
|
||||
sequence_element_pick(all_upper_lengths, sorted2unsorted_map);
|
||||
|
||||
return sorted_upper_lengths;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerTensorDescriptor()
|
||||
{
|
||||
return LowTensorDescriptor{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr LowerIndex GetLowerIndex(const UpperIndex& idx_up)
|
||||
__host__ __device__ static constexpr auto GetLowerLengths()
|
||||
{
|
||||
return GetLowerTensorDescriptor().GetLengths();
|
||||
}
|
||||
|
||||
struct lambda_GetUpperLengths
|
||||
{
|
||||
template <typename Transform>
|
||||
__host__ __device__ constexpr auto operator()(const Transform& tran) const
|
||||
{
|
||||
return tran.GetUpperLengths();
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths()
|
||||
{
|
||||
constexpr auto tuple_of_up_lengths =
|
||||
transform_tuple(lambda_GetUpperLengths{}, Transforms{});
|
||||
|
||||
constexpr auto mingled_up_lengths = unpack(lambda_merge_sequences{}, tuple_of_up_lengths);
|
||||
|
||||
constexpr auto mingled_up_dimension_ids =
|
||||
unpack(lambda_merge_sequences{}, UpDimensionIds{});
|
||||
|
||||
// TODO: sanity-check mingled_up_dimension_ids contain all upper-dimensions
|
||||
// TODO: sanity-check mingled_up_lengths have no conflicting upper-length
|
||||
|
||||
// sort by upper-dimension-ids
|
||||
using sort_up_dimension_ids = sequence_unique_sort<decltype(mingled_up_dimension_ids),
|
||||
math::less<index_t>,
|
||||
math::equal<index_t>>;
|
||||
|
||||
// sanity-check sorted-upper-dimension-ids should be Sequence<0, 1, ... nDimUp-1>
|
||||
static_assert(is_same<typename sort_up_dimension_ids::type,
|
||||
typename arithmetic_sequence_gen<0, nDimUp, 1>::type>{},
|
||||
"wrong! UpDimensionIds is not configured correctly");
|
||||
|
||||
constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{};
|
||||
|
||||
constexpr auto sorted_up_lengths =
|
||||
pick_sequence_elements(mingled_up_lengths, sorted2unsorted_map);
|
||||
|
||||
return sorted_up_lengths;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
|
||||
{
|
||||
return GetLengths()[IDim];
|
||||
}
|
||||
|
||||
template <index_t... IDims>
|
||||
__host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
|
||||
{
|
||||
return Sequence<GetLength(Number<IDims>{})...>{};
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t... IDims>
|
||||
__host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
|
||||
{
|
||||
return GetLengths(Sequence<IDim, IDims...>{});
|
||||
}
|
||||
|
||||
// TODO: right now return value is constexpr because use of non-constepxr lambda
|
||||
__host__ __device__ static constexpr LowerIndex CalculateLowerIndex(const UpperIndex& idx_up)
|
||||
{
|
||||
LowerIndex idx_low;
|
||||
|
||||
@@ -225,14 +262,15 @@ struct TransformedTensorDescriptor
|
||||
// this assume each lower (single) index is only assocaited with one transformation,
|
||||
// which is required for index transformation, and has been checked during constructor
|
||||
// of TransformedTensorDescriptor
|
||||
idx_low_part = tran.GetLowerIndex(to_array(idx_up_part));
|
||||
idx_low_part = tran.CalculateLowerIndex(to_array(idx_up_part));
|
||||
});
|
||||
|
||||
return idx_low;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr LowerIndex GetLowerIndexDiff(const UpperIndex& idx_up_diff,
|
||||
const LowerIndex& idx_low_old)
|
||||
// TODO: right now return value is constexpr because use of non-constepxr lambda
|
||||
__host__ __device__ static constexpr LowerIndex
|
||||
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff, const LowerIndex& idx_low_old)
|
||||
{
|
||||
LowerIndex idx_low_diff;
|
||||
|
||||
@@ -250,15 +288,15 @@ struct TransformedTensorDescriptor
|
||||
// this assume each lower (single) index is associated with only one transformation,
|
||||
// which is required for index transformation, and has been checked during constructor
|
||||
// of TransformedTensorDescriptor
|
||||
idx_low_diff_part = tran.GetLowerIndex(idx_up_diff_part, idx_low_old_part);
|
||||
idx_low_diff_part = tran.CalculateLowerIndex(idx_up_diff_part, idx_low_old_part);
|
||||
});
|
||||
|
||||
return idx_low_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffset(const UpperIndex& idx_up)
|
||||
__host__ __device__ static constexpr index_t CalculateOffset(const UpperIndex& idx_up)
|
||||
{
|
||||
return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up));
|
||||
return GetLowerTensorDescriptor().CalculateOffset(CalculateLowerIndex(idx_up));
|
||||
}
|
||||
|
||||
#if 0
|
||||
@@ -286,14 +324,14 @@ struct TransformedTensorDescriptor
|
||||
};
|
||||
|
||||
template <index_t... Lengths, index_t... Strides>
|
||||
__host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths...>,
|
||||
Sequence<Strides...>)
|
||||
__host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence<Lengths...>,
|
||||
Sequence<Strides...>)
|
||||
{
|
||||
return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
|
||||
}
|
||||
|
||||
template <typename Lengths>
|
||||
__host__ __device__ constexpr auto make_NativeTensorDescriptor_packed(Lengths)
|
||||
__host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths)
|
||||
{
|
||||
constexpr auto strides = reverse_inclusive_scan_sequence(
|
||||
Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
|
||||
|
||||
@@ -7,12 +7,19 @@
|
||||
namespace ck {
|
||||
|
||||
template <typename... NativeDimensions>
|
||||
__host__ __device__ void print_tensor_descriptor(const char* s,
|
||||
NativeTensorDescriptor<NativeDimensions...> desc)
|
||||
__host__ __device__ void
|
||||
print_tensor_descriptor(const char* s, const NativeTensorDescriptor<NativeDimensions...>& desc)
|
||||
{
|
||||
print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides());
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
__host__ __device__ void print_tensor_descriptor(const char* s,
|
||||
const TransformedTensorDescriptor<Ts...>& desc)
|
||||
{
|
||||
print_tensor_descriptor_impl(s, desc.GetLengths());
|
||||
}
|
||||
|
||||
template <index_t... Lengths, index_t... Strides>
|
||||
__host__ __device__ void
|
||||
print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strides...>)
|
||||
@@ -113,5 +120,53 @@ print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strid
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t... Lengths>
|
||||
__host__ __device__ void print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>)
|
||||
{
|
||||
constexpr index_t nDim = sizeof...(Lengths);
|
||||
|
||||
static_assert(nDim > 0 && nDim <= 12, "wrong!");
|
||||
|
||||
static_if<nDim == 1>{}([&](auto) { printf("%s dim %u, lengths {%u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 2>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 3>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 4>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 5>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 6>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u}, \n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 7>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 8>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
|
||||
static_if<nDim == 9>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
|
||||
static_if<nDim == 10>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
|
||||
static_if<nDim == 11>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
|
||||
static_if<nDim == 12>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#ifndef CK_ARRAY_HPP
|
||||
#define CK_ARRAY_HPP
|
||||
|
||||
#include "Sequence.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "functional2.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -17,7 +17,7 @@ struct Array
|
||||
__host__ __device__ explicit constexpr Array() {}
|
||||
|
||||
template <typename X, typename... Xs>
|
||||
__host__ __device__ explicit constexpr Array(X x, Xs... xs)
|
||||
__host__ __device__ constexpr Array(X x, Xs... xs)
|
||||
: mData{static_cast<TData>(x), static_cast<TData>(xs)...}
|
||||
{
|
||||
static_assert(sizeof...(Xs) + 1 == NSize, "wrong! size");
|
||||
@@ -176,7 +176,6 @@ __host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
|
||||
return ArrayElementPicker<Arr, Picks>(a);
|
||||
}
|
||||
|
||||
#if 1
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr auto to_array(const T& x)
|
||||
{
|
||||
@@ -186,8 +185,8 @@ __host__ __device__ constexpr auto to_array(const T& x)
|
||||
|
||||
return y;
|
||||
}
|
||||
#endif
|
||||
|
||||
// TODO: remove this
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
|
||||
{
|
||||
@@ -1,12 +1,12 @@
|
||||
#ifndef CK_ARRAY_HELPER_HPP
|
||||
#define CK_ARRAY_HELPER_HPP
|
||||
|
||||
#include "Array.hpp"
|
||||
#include "array.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename T, index_t NSize>
|
||||
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
|
||||
__host__ __device__ void print_array(const char* s, Array<T, NSize> a)
|
||||
{
|
||||
constexpr index_t nsize = a.GetSize();
|
||||
|
||||
@@ -90,4 +90,4 @@ __host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -9,9 +9,9 @@
|
||||
#include "tuple.hpp"
|
||||
#include "math.hpp"
|
||||
#include "vector_type.hpp"
|
||||
#include "Sequence.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "sequence_helper.hpp"
|
||||
#include "Array.hpp"
|
||||
#include "array.hpp"
|
||||
#include "array_helper.hpp"
|
||||
#include "functional.hpp"
|
||||
#include "functional2.hpp"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
#define CK_FUNCTIONAL_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "Sequence.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
#define CK_FUNCTIONAL2_HPP
|
||||
|
||||
#include "functional.hpp"
|
||||
#include "Sequence.hpp"
|
||||
#include "sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#include "functional.hpp"
|
||||
#include "functional2.hpp"
|
||||
#include "Sequence.hpp"
|
||||
#include "Array.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "array.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#ifndef CK_FUNCTIONAL4_HPP
|
||||
#define CK_FUNCTIONAL4_HPP
|
||||
|
||||
#include "Sequence.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tuple.hpp"
|
||||
#include "Array.hpp"
|
||||
#include "array.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "config.hpp"
|
||||
#include "integral_constant.hpp"
|
||||
#include "type.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace math {
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
#define CK_SEQUENCE_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "type.hpp"
|
||||
#include "functional.hpp"
|
||||
#include "math.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -155,8 +157,8 @@ struct Sequence
|
||||
static_assert(I < Size(), "wrong!");
|
||||
|
||||
using seq_split = sequence_split<Type, I>;
|
||||
constexpr auto seq_left = typename seq_split::SeqType0{};
|
||||
constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();
|
||||
constexpr auto seq_left = typename seq_split::left_type{};
|
||||
constexpr auto seq_right = typename seq_split::right_type{}.PopFront();
|
||||
|
||||
return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
|
||||
}
|
||||
@@ -188,34 +190,34 @@ struct sequence_merge<Seq>
|
||||
};
|
||||
|
||||
// generate sequence
|
||||
template <index_t IBegin, index_t NRemain, typename F>
|
||||
struct sequence_gen_impl
|
||||
{
|
||||
static constexpr index_t NRemainLeft = NRemain / 2;
|
||||
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
|
||||
static constexpr index_t IMiddle = IBegin + NRemainLeft;
|
||||
|
||||
using type =
|
||||
typename sequence_merge<typename sequence_gen_impl<IBegin, NRemainLeft, F>::type,
|
||||
typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type;
|
||||
};
|
||||
|
||||
template <index_t I, typename F>
|
||||
struct sequence_gen_impl<I, 1, F>
|
||||
{
|
||||
static constexpr index_t Is = F{}(Number<I>{});
|
||||
using type = Sequence<Is>;
|
||||
};
|
||||
|
||||
template <index_t I, typename F>
|
||||
struct sequence_gen_impl<I, 0, F>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
template <index_t NSize, typename F>
|
||||
struct sequence_gen
|
||||
{
|
||||
template <index_t IBegin, index_t NRemain, typename G>
|
||||
struct sequence_gen_impl
|
||||
{
|
||||
static constexpr index_t NRemainLeft = NRemain / 2;
|
||||
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
|
||||
static constexpr index_t IMiddle = IBegin + NRemainLeft;
|
||||
|
||||
using type = typename sequence_merge<
|
||||
typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
|
||||
typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
|
||||
};
|
||||
|
||||
template <index_t I, typename G>
|
||||
struct sequence_gen_impl<I, 1, G>
|
||||
{
|
||||
static constexpr index_t Is = G{}(Number<I>{});
|
||||
using type = Sequence<Is>;
|
||||
};
|
||||
|
||||
template <index_t I, typename G>
|
||||
struct sequence_gen_impl<I, 0, G>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
using type = typename sequence_gen_impl<0, NSize, F>::type;
|
||||
};
|
||||
|
||||
@@ -281,8 +283,8 @@ struct sequence_split
|
||||
using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
|
||||
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
|
||||
|
||||
using SeqType0 = decltype(Seq::Extract(range0{}));
|
||||
using SeqType1 = decltype(Seq::Extract(range1{}));
|
||||
using left_type = decltype(Seq::Extract(range0{}));
|
||||
using right_type = decltype(Seq::Extract(range1{}));
|
||||
};
|
||||
|
||||
// reverse sequence
|
||||
@@ -293,8 +295,8 @@ struct sequence_reverse
|
||||
|
||||
using seq_split = sequence_split<Seq, NSize / 2>;
|
||||
using type = typename sequence_merge<
|
||||
typename sequence_reverse<typename seq_split::SeqType1>::type,
|
||||
typename sequence_reverse<typename seq_split::SeqType0>::type>::type;
|
||||
typename sequence_reverse<typename seq_split::right_type>::type,
|
||||
typename sequence_reverse<typename seq_split::left_type>::type>::type;
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
@@ -309,138 +311,264 @@ struct sequence_reverse<Sequence<I0, I1>>
|
||||
using type = Sequence<I1, I0>;
|
||||
};
|
||||
|
||||
template <typename Seq, typename Compare>
|
||||
struct sequence_sort
|
||||
template <typename Values, typename Ids, typename Compare>
|
||||
struct sequence_sort_impl
|
||||
{
|
||||
template <typename SeqLeft, typename SeqRight, typename MergedSeq, typename Comp>
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename RightValues,
|
||||
typename RightIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl
|
||||
{
|
||||
static constexpr bool pick_left = SeqLeft::Front() < SeqRight::Front();
|
||||
static constexpr index_t next_value = pick_left ? SeqLeft::Front() : SeqRight::Front();
|
||||
static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();
|
||||
|
||||
using new_merged_seq = decltype(MergedSeq::PushBack(Number<next_value>{}));
|
||||
static constexpr index_t chosen_value =
|
||||
choose_left ? LeftValues::Front() : RightValues::Front();
|
||||
static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();
|
||||
|
||||
using new_left_seq =
|
||||
typename conditional<pick_left, decltype(SeqLeft::PopFront()), SeqLeft>::type;
|
||||
using new_right_seq =
|
||||
typename conditional<pick_left, SeqRight, decltype(SeqRight::PopFront())>::type;
|
||||
using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
|
||||
using new_merged_ids = decltype(MergedIds::PushBack(Number<chosen_id>{}));
|
||||
|
||||
using new_left_values =
|
||||
typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
|
||||
using new_left_ids =
|
||||
typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;
|
||||
|
||||
using new_right_values =
|
||||
typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
|
||||
using new_right_ids =
|
||||
typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;
|
||||
|
||||
using merge = sorted_sequence_merge_impl<new_left_values,
|
||||
new_left_ids,
|
||||
new_right_values,
|
||||
new_right_ids,
|
||||
new_merged_values,
|
||||
new_merged_ids,
|
||||
Comp>;
|
||||
// this is output
|
||||
using merged_values = typename merge::merged_values;
|
||||
using merged_ids = typename merge::merged_ids;
|
||||
};
|
||||
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl<LeftValues,
|
||||
LeftIds,
|
||||
Sequence<>,
|
||||
Sequence<>,
|
||||
MergedValues,
|
||||
MergedIds,
|
||||
Comp>
|
||||
{
|
||||
using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
|
||||
using merged_ids = typename sequence_merge<MergedIds, LeftIds>::type;
|
||||
};
|
||||
|
||||
template <typename RightValues,
|
||||
typename RightIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl<Sequence<>,
|
||||
Sequence<>,
|
||||
RightValues,
|
||||
RightIds,
|
||||
MergedValues,
|
||||
MergedIds,
|
||||
Comp>
|
||||
{
|
||||
using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
|
||||
using merged_ids = typename sequence_merge<MergedIds, RightIds>::type;
|
||||
};
|
||||
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename RightValues,
|
||||
typename RightIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge
|
||||
{
|
||||
using merge = sorted_sequence_merge_impl<LeftValues,
|
||||
LeftIds,
|
||||
RightValues,
|
||||
RightIds,
|
||||
Sequence<>,
|
||||
Sequence<>,
|
||||
Comp>;
|
||||
|
||||
using merged_values = typename merge::merged_values;
|
||||
using merged_ids = typename merge::merged_ids;
|
||||
};
|
||||
|
||||
static constexpr index_t nsize = Values::Size();
|
||||
|
||||
using split_unsorted_values = sequence_split<Values, nsize / 2>;
|
||||
using split_unsorted_ids = sequence_split<Ids, nsize / 2>;
|
||||
|
||||
using left_unsorted_values = typename split_unsorted_values::left_type;
|
||||
using left_unsorted_ids = typename split_unsorted_ids::left_type;
|
||||
using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
|
||||
using left_sorted_values = typename left_sort::sorted_values;
|
||||
using left_sorted_ids = typename left_sort::sorted_ids;
|
||||
|
||||
using right_unsorted_values = typename split_unsorted_values::right_type;
|
||||
using right_unsorted_ids = typename split_unsorted_ids::right_type;
|
||||
using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
|
||||
using right_sorted_values = typename right_sort::sorted_values;
|
||||
using right_sorted_ids = typename right_sort::sorted_ids;
|
||||
|
||||
using merged_sorted = sorted_sequence_merge<left_sorted_values,
|
||||
left_sorted_ids,
|
||||
right_sorted_values,
|
||||
right_sorted_ids,
|
||||
Compare>;
|
||||
|
||||
using sorted_values = typename merged_sorted::merged_values;
|
||||
using sorted_ids = typename merged_sorted::merged_ids;
|
||||
};
|
||||
|
||||
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
|
||||
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
|
||||
{
|
||||
static constexpr bool choose_x = Compare{}(ValueX, ValueY);
|
||||
|
||||
using sorted_values =
|
||||
typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
|
||||
using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
|
||||
};
|
||||
|
||||
template <index_t Value, index_t Id, typename Compare>
|
||||
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
|
||||
{
|
||||
using sorted_values = Sequence<Value>;
|
||||
using sorted_ids = Sequence<Id>;
|
||||
};
|
||||
|
||||
template <typename Values, typename Compare>
|
||||
struct sequence_sort
|
||||
{
|
||||
using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
|
||||
using sort = sequence_sort_impl<Values, unsorted_ids, Compare>;
|
||||
|
||||
// this is output
|
||||
using type = typename sort::sorted_values;
|
||||
using sorted2unsorted_map = typename sort::sorted_ids;
|
||||
};
|
||||
|
||||
template <typename Values, typename Less, typename Equal>
|
||||
struct sequence_unique_sort
|
||||
{
|
||||
template <typename RemainValues,
|
||||
typename RemainIds,
|
||||
typename UniquifiedValues,
|
||||
typename UniquifiedIds,
|
||||
typename Eq>
|
||||
struct sorted_sequence_uniquify_impl
|
||||
{
|
||||
static constexpr index_t current_value = RemainValues::Front();
|
||||
static constexpr index_t current_id = RemainIds::Front();
|
||||
|
||||
static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());
|
||||
|
||||
using new_remain_values = decltype(RemainValues::PopFront());
|
||||
using new_remain_ids = decltype(RemainIds::PopFront());
|
||||
|
||||
using new_uniquified_values =
|
||||
typename conditional<is_unique_value,
|
||||
decltype(UniquifiedValues::PushBack(Number<current_value>{})),
|
||||
UniquifiedValues>::type;
|
||||
|
||||
using new_uniquified_ids =
|
||||
typename conditional<is_unique_value,
|
||||
decltype(UniquifiedIds::PushBack(Number<current_id>{})),
|
||||
UniquifiedIds>::type;
|
||||
|
||||
using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
|
||||
new_remain_ids,
|
||||
new_uniquified_values,
|
||||
new_uniquified_ids,
|
||||
Eq>;
|
||||
|
||||
// this is output
|
||||
using uniquified_values = typename uniquify::uniquified_values;
|
||||
using uniquified_ids = typename uniquify::uniquified_ids;
|
||||
};
|
||||
|
||||
template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
|
||||
struct sorted_sequence_uniquify_impl<Sequence<>,
|
||||
Sequence<>,
|
||||
UniquifiedValues,
|
||||
UniquifiedIds,
|
||||
Eq>
|
||||
{
|
||||
using uniquified_values = UniquifiedValues;
|
||||
using uniquified_ids = UniquifiedIds;
|
||||
};
|
||||
|
||||
template <typename SortedValues, typename SortedIds, typename Eq>
|
||||
struct sorted_sequence_uniquify
|
||||
{
|
||||
using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
|
||||
decltype(SortedIds::PopFront()),
|
||||
Sequence<SortedValues::Front()>,
|
||||
Sequence<SortedIds::Front()>,
|
||||
Eq>;
|
||||
|
||||
using uniquified_values = typename uniquify::uniquified_values;
|
||||
using uniquified_ids = typename uniquify::uniquified_ids;
|
||||
};
|
||||
|
||||
using sort = sequence_sort<Values, Less>;
|
||||
using sorted_values = typename sort::type;
|
||||
using sorted_ids = typename sort::sorted2unsorted_map;
|
||||
|
||||
using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;
|
||||
|
||||
// this is output
|
||||
using type = typename uniquify::uniquified_values;
|
||||
using sorted2unsorted_map = typename uniquify::uniquified_ids;
|
||||
};
|
||||
|
||||
template <typename SeqMap>
|
||||
struct is_valid_sequence_map
|
||||
{
|
||||
static constexpr bool value =
|
||||
is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
|
||||
typename sequence_sort<SeqMap, math::less<index_t>>::type>{};
|
||||
};
|
||||
|
||||
template <typename SeqMap>
|
||||
struct sequence_map_inverse
|
||||
{
|
||||
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
|
||||
struct sequence_map_inverse_impl
|
||||
{
|
||||
static constexpr auto new_y2x =
|
||||
WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
|
||||
|
||||
using type =
|
||||
typename sorted_sequence_merge_impl<new_left_seq, new_right_seq, new_merged_seq, Comp>::
|
||||
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
|
||||
type;
|
||||
};
|
||||
|
||||
template <typename SeqLeft, typename MergedSeq, typename Comp>
|
||||
struct sorted_sequence_merge_impl<SeqLeft, Sequence<>, MergedSeq, Comp>
|
||||
template <typename X2Y, typename WorkingY2X, index_t XBegin>
|
||||
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
|
||||
{
|
||||
using type = typename sequence_merge<MergedSeq, SeqLeft>::type;
|
||||
using type = WorkingY2X;
|
||||
};
|
||||
|
||||
template <typename SeqRight, typename MergedSeq, typename Comp>
|
||||
struct sorted_sequence_merge_impl<Sequence<>, SeqRight, MergedSeq, Comp>
|
||||
{
|
||||
using type = typename sequence_merge<MergedSeq, SeqRight>::type;
|
||||
};
|
||||
|
||||
template <typename Seq0, typename Seq1, typename Comp>
|
||||
struct sorted_sequence_merge
|
||||
{
|
||||
using type = typename sorted_sequence_merge_impl<Seq0, Seq1, Sequence<>, Comp>::type;
|
||||
};
|
||||
|
||||
using split = sequence_split<Seq, Seq::Size() / 2>;
|
||||
using unsorted_left = typename split::SeqType0;
|
||||
using unsorted_right = typename split::SeqType1;
|
||||
|
||||
using sorted_left = typename sequence_sort<unsorted_left, Compare>::type;
|
||||
using sorted_right = typename sequence_sort<unsorted_right, Compare>::type;
|
||||
|
||||
using type = typename sorted_sequence_merge<sorted_left, sorted_right, Compare>::type;
|
||||
};
|
||||
|
||||
template <index_t X, index_t Y, typename Compare>
|
||||
struct sequence_sort<Sequence<X, Y>, Compare>
|
||||
{
|
||||
static constexpr bool x_first = Compare{}(X, Y);
|
||||
|
||||
using type = typename conditional<x_first, Sequence<X, Y>, Sequence<Y, X>>::type;
|
||||
};
|
||||
|
||||
template <index_t X, typename Compare>
|
||||
struct sequence_sort<Sequence<X>, Compare>
|
||||
{
|
||||
using type = Sequence<X>;
|
||||
};
|
||||
|
||||
template <typename Seq, typename Less, typename Equal>
|
||||
struct sequence_unique_sort
|
||||
{
|
||||
template <typename WorkInputSeq, typename WorkOutputSeq, typename Eq>
|
||||
struct sorted_sequence_uniquify_impl
|
||||
{
|
||||
static constexpr index_t new_value = WorkInputSeq::Front();
|
||||
using new_work_input_seq = decltype(WorkInputSeq::PopFront());
|
||||
|
||||
using new_working_output_seq =
|
||||
typename conditional<new_value == WorkOutputSeq::Back(),
|
||||
WorkOutputSeq,
|
||||
decltype(WorkOutputSeq::PopBack(Number<new_value>{}))>::type;
|
||||
};
|
||||
|
||||
template <typename WorkInputSeq, typename Eq>
|
||||
struct sorted_sequence_uniquify_impl<WorkInputSeq, Sequence<>, Eq>
|
||||
{
|
||||
using type = WorkInputSeq;
|
||||
};
|
||||
|
||||
template <typename SortedSeq, typename Eq>
|
||||
struct sorted_sequence_uniquify
|
||||
{
|
||||
using type = typename sorted_sequence_uniquify_impl<SortedSeq, Sequence<>, Eq>::type;
|
||||
};
|
||||
|
||||
using sorted_seq = typename sequence_sort<Seq, Less>::type;
|
||||
|
||||
using type = typename sorted_sequence_uniquify<sorted_seq, Equal>::type;
|
||||
};
|
||||
|
||||
template <typename Seq>
|
||||
struct is_valid_sequence_map
|
||||
{
|
||||
// not implemented yet, always return true
|
||||
static constexpr integral_constant<bool, true> value = integral_constant<bool, true>{};
|
||||
|
||||
// TODO: add proper check for is_valid, something like:
|
||||
// static constexpr bool value =
|
||||
// is_same<typename arithmetic_sequence_gen<0, Seq::Size(), 1>::type,
|
||||
// typename sequence_sort<Seq>::SortedSeqType>{};
|
||||
};
|
||||
|
||||
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
|
||||
struct sequence_map_inverse_impl
|
||||
{
|
||||
private:
|
||||
static constexpr auto new_y2x = WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
|
||||
|
||||
public:
|
||||
using type =
|
||||
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
|
||||
};
|
||||
|
||||
template <typename X2Y, typename WorkingY2X, index_t XBegin>
|
||||
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
|
||||
{
|
||||
using type = WorkingY2X;
|
||||
};
|
||||
|
||||
template <typename X2Y>
|
||||
struct sequence_map_inverse
|
||||
{
|
||||
using type =
|
||||
typename sequence_map_inverse_impl<X2Y,
|
||||
typename uniform_sequence_gen<X2Y::Size(), 0>::type,
|
||||
typename sequence_map_inverse_impl<SeqMap,
|
||||
typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
|
||||
0,
|
||||
X2Y::Size()>::type;
|
||||
SeqMap::Size()>::type;
|
||||
};
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
@@ -601,6 +729,12 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
|
||||
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
|
||||
}
|
||||
|
||||
template <typename Seq, index_t... Is>
|
||||
__host__ __device__ constexpr auto pick_sequence_elements(Seq, Sequence<Is...>)
|
||||
{
|
||||
return Sequence<Seq::At(Number<Is>{})...>{};
|
||||
}
|
||||
|
||||
template <typename Seq, typename Reduce>
|
||||
struct lambda_accumulate_on_sequence
|
||||
{
|
||||
@@ -1,12 +1,12 @@
|
||||
#ifndef CK_SEQUENCE_HELPER_HPP
|
||||
#define CK_SEQUENCE_HELPER_HPP
|
||||
|
||||
#include "Sequence.hpp"
|
||||
#include "sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
|
||||
__host__ __device__ void print_sequence(const char* s, Sequence<Xs...>)
|
||||
{
|
||||
constexpr index_t nsize = Sequence<Xs...>::Size();
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "type.hpp"
|
||||
#include "Sequence.hpp"
|
||||
#include "sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -114,19 +114,19 @@ __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename X, typename F, index_t... Is>
|
||||
__host__ __device__ constexpr auto transpose_tuple_impl(X& x, F f, Sequence<Is...>)
|
||||
template <typename F, typename X, index_t... Is>
|
||||
__host__ __device__ constexpr auto transform_tuple_impl(F f, const X& x, Sequence<Is...>)
|
||||
{
|
||||
return make_tuple(f(x.At(Number<Is>{}))...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename X, typename F>
|
||||
__host__ __device__ constexpr auto transpose_tuple(X& x, F f)
|
||||
template <typename F, typename X>
|
||||
__host__ __device__ constexpr auto transform_tuple(F f, const X& x)
|
||||
{
|
||||
return detail::transpose_tuple_impl(
|
||||
x, f, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
|
||||
return detail::transform_tuple_impl(
|
||||
f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
#define CK_TYPE_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "Sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t... Is>
|
||||
struct Sequence;
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct is_same : public integral_constant<bool, false>
|
||||
{
|
||||
|
||||
@@ -84,8 +84,8 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
constexpr index_t HPad = 2;
|
||||
constexpr index_t WPad = 2;
|
||||
constexpr index_t HPad = 3;
|
||||
constexpr index_t WPad = 3;
|
||||
#elif 1
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
|
||||
Reference in New Issue
Block a user