adding implicit gemm v3

[ROCm/composable_kernel commit: 33b5a8556b]
This commit is contained in:
Chao Liu
2019-05-16 22:23:18 -05:00
parent ffd172378a
commit dec8c3ebdd
13 changed files with 172 additions and 197 deletions

View File

@@ -88,32 +88,11 @@ struct ConstantTensorDescriptor
return accumulate_on_sequence(Lengths{}, std::multiplies<index_t>{}, Number<1>{});
}
#if 0
// C++14 has no constexpr lambdas, so a named functor is used in their place
struct f_GetElementSpace_impl
{
    // per-dimension term (length - 1) * stride, with length and stride read from Type
    template <class IDim>
    __host__ __device__ constexpr index_t operator()(IDim idim) const
    {
        const auto dim_length = Type{}.GetLength(idim);
        const auto dim_stride = Type{}.GetStride(idim);
        return (dim_length - 1) * dim_stride;
    }

    // same term for an explicitly supplied length and stride
    __host__ __device__ constexpr index_t operator()(index_t length, index_t stride) const
    {
        return (length - 1) * stride;
    }
};
#endif
// Computes the element space of this descriptor, rounded up to a multiple of `align`.
// The unaligned value folds (length - 1) * stride over all dimensions with std::plus,
// seeded with Number<1> (presumably: offset of the last element, plus one).
//   align: alignment as a Number<>-like object exposing Get(); defaults to Number<1>.
template <class Align = Number<1>>
__host__ __device__ static constexpr index_t GetElementSpace(Align align = Align{})
{
    constexpr index_t element_space_unaligned = accumulate_on_sequence(
        (GetLengths() - Number<1>{}) * GetStrides(), std::plus<index_t>{}, Number<1>{});

    // round up to the next multiple of the requested alignment
    return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
}
@@ -150,10 +129,7 @@ struct ConstantTensorDescriptor
constexpr auto multi_id = Sequence<Is...>{};
constexpr auto seq_tmp =
transform_sequences(std::multiplies<index_t>{}, multi_id, GetStrides());
return accumulate_on_sequence(seq_tmp, std::plus<index_t>{}, Number<0>{});
return accumulate_on_sequence(multi_id * GetStrides(), std::plus<index_t>{}, Number<0>{});
}
__host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)
@@ -177,14 +153,14 @@ struct ConstantTensorDescriptor
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
}
// Builds a new descriptor from the dimensions named by extract_dims, by forwarding them
// to Lengths::Extract and Strides::Extract.
//   extract_dims: one Number<IDim> per dimension to keep; at most GetNumOfDimension() of them.
template <index_t... IDims>
__host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
{
    static_assert(sizeof...(IDims) <= GetNumOfDimension(),
                  "wrong! too many number of dimensions to be extracted");

    // the pack must be expanded so each selected dimension is passed as its own argument
    return make_ConstantTensorDescriptor(Lengths{}.Extract(extract_dims...),
                                         Strides{}.Extract(extract_dims...));
}
template <index_t IDim, index_t SliceLen>
@@ -195,11 +171,11 @@ struct ConstantTensorDescriptor
}
template <index_t IDim, index_t... FoldIntervals>
__host__ device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
__host__ __device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
{
constexpr auto fold_intervals = Sequence<FoldIntervals...>{};
constexpr fold_intervals_product =
constexpr index_t fold_intervals_product =
accumulate_on_sequence(fold_intervals, std::multiplies<index_t>{}, Number<1>{});
constexpr auto unfold_length = GetLength(Number<IDim>{});
@@ -207,29 +183,31 @@ struct ConstantTensorDescriptor
// the length of the dimension to be folded must be divisible by fold_intervals_product;
// otherwise, folding is invalid
static_assert(unfold_length % fold_interval_product == 0,
static_assert(unfold_length % fold_intervals_product == 0,
"wrong! length on the dimension to be folded cannot be evenly divided!");
// folded lengths
constexpr auto fold_lengths =
Sequence<unfold_length / fold_interval_product>{}.Append(fold_intervals);
Sequence<unfold_length / fold_intervals_product>{}.Append(fold_intervals);
// folded strides
constexpr auto fold_strides = transform_sequences(mod_conv::scales<index_t, unfold_stride>{},
constexpr auto fold_strides =
Number<unfold_stride>{} *
reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies<index_t>{});
// left and right lengths
constexpr auto lengths_pair = GetLengths().Split(Number<I>{});
constexpr auto lengths_pair = GetLengths().Split(Number<IDim>{});
constexpr auto left_lengths = lengths_pair.first;
constexpr auto right_lengths = lengths_pair.second.PopFront();
// left and right strides
constexpr auto strides_pair = GetStrides().Split(Number<I>{});
constexpr auto strides_pair = GetStrides().Split(Number<IDim>{});
constexpr auto left_strides = strides_pair.first;
constexpr auto right_strides = strides_pair.second.PopFront();
return make_ConstantTensorDescriptor(left_lengths.Append(fold_lengths).Append(right_lengths),
left_strides.Append(fold_strides).Append(right_strides));
return make_ConstantTensorDescriptor(
left_lengths.Append(fold_lengths).Append(right_lengths),
left_strides.Append(fold_strides).Append(right_strides));
}
template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
@@ -264,8 +242,8 @@ struct ConstantTensorDescriptor
constexpr index_t unfold_length =
accumulate_on_sequence(fold_lengths, std::multiplies<index_t>{}, Number<1>{});
constexpr auto new_strides =
left_strides.PopBack(Number<unfold_strides>{}).Append(right_strides);
constexpr auto new_lengths =
left_lengths.PopBack(Number<unfold_length>{}).Append(right_lengths);
// strides
constexpr auto strides_pair1 = Strides{}.Split(Number<LastUnfoldDim + 1>{});
@@ -281,7 +259,7 @@ struct ConstantTensorDescriptor
constexpr index_t unfold_stride = fold_strides.Back();
constexpr auto new_strides =
left_strides.PushBack(Number<unfold_strides>{}).Append(right_strides);
left_strides.PushBack(Number<unfold_stride>{}).Append(right_strides);
return make_ConstantTensorDescriptor(new_lengths, new_strides);
}
@@ -289,7 +267,7 @@ struct ConstantTensorDescriptor
template <index_t... IRs>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
{
static_assert(sizeof...(IRs) == GetNumberOfDimension(), "wrong! dimension is wrong");
static_assert(sizeof...(IRs) == GetNumOfDimension(), "wrong! dimension is wrong");
constexpr auto map_new2old = Sequence<IRs...>{};
return make_ConstantTensorDescriptor(Lengths{}.ReorderGivenNew2Old(map_new2old),
Strides{}.ReorderGivenNew2Old(map_new2old));