adding implicit gemm v4 (nchw, kcyx)

This commit is contained in:
Chao Liu
2019-05-30 17:50:49 -05:00
parent 0a2657312e
commit b2439ec9dd
11 changed files with 440 additions and 130 deletions

View File

@@ -4,7 +4,7 @@
// Builds the default stride sequence for a tensor whose per-dimension lengths are
// described by the compile-time sequence type `Lengths` (only the type is used; the
// argument itself is unnamed).
// The first length is dropped with PopFront(), the remaining lengths are passed to
// reverse_inclusive_scan_sequence together with a multiplication functor, and
// Number<1> is appended with PushBack(), so the last dimension ends up with stride 1.
// NOTE(review): presumably the scan yields the running products of the trailing
// lengths, i.e. strides of a densely packed tensor -- confirm against the definition
// of reverse_inclusive_scan_sequence.
// NOTE(review): this text is a rendered diff hunk; the two consecutive `return` lines
// below appear to be the removed (std::multiplies) and added (mod_conv::multiplies)
// versions of the same statement, not two statements of one function body.
template <class Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_default_rank_packed(Lengths)
{
return reverse_inclusive_scan_sequence(Lengths{}.PopFront(), std::multiplies<index_t>{})
return reverse_inclusive_scan_sequence(Lengths{}.PopFront(), mod_conv::multiplies<index_t>{})
.PushBack(Number<1>{});
}
@@ -95,7 +95,7 @@ struct ConstantTensorDescriptor
// Folds the `Lengths` sequence with a multiplication functor, starting from Number<1>,
// and returns the result as an index_t.
// NOTE(review): presumably this is the product of all dimension lengths, i.e. the
// number of logical elements in the tensor -- confirm against accumulate_on_sequence.
// NOTE(review): this text is a rendered diff hunk; the two `return` lines below appear
// to be the removed (std::multiplies) and added (mod_conv::multiplies) versions of the
// same statement.
__host__ __device__ static constexpr index_t GetElementSize()
{
return accumulate_on_sequence(Lengths{}, std::multiplies<index_t>{}, Number<1>{});
return accumulate_on_sequence(Lengths{}, mod_conv::multiplies<index_t>{}, Number<1>{});
}
// WRONG! ReorderGivenOld2new is broken
@@ -107,10 +107,10 @@ struct ConstantTensorDescriptor
constexpr auto strides_in_rank = GetStrides().ReorderGivenOld2new(MemoryRank{});
constexpr index_t element_space_unaligned = accumulate_on_sequence(
(lengths_in_rank - Number<1>{}) * strides_in_rank, std::plus<index_t>{}, Number<1>{});
(lengths_in_rank - Number<1>{}) * strides_in_rank, mod_conv::plus<index_t>{}, Number<1>{});
#else // WRONG! align should be applied to the last memory rank, not the last tensor dimension
constexpr index_t element_space_unaligned = accumulate_on_sequence(
(GetLengths() - Number<1>{}) * GetStrides(), std::plus<index_t>{}, Number<1>{});
(GetLengths() - Number<1>{}) * GetStrides(), mod_conv::plus<index_t>{}, Number<1>{});
#endif
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
@@ -144,7 +144,8 @@ struct ConstantTensorDescriptor
constexpr auto multi_id = Sequence<Is...>{};
return accumulate_on_sequence(multi_id * GetStrides(), std::plus<index_t>{}, Number<0>{});
return accumulate_on_sequence(
multi_id * GetStrides(), mod_conv::plus<index_t>{}, Number<0>{});
}
#if 0 // ReorderGivenOld2new is broken
@@ -197,32 +198,70 @@ struct ConstantTensorDescriptor
// This function doesn't do carry check on the highest dimension, for performance reasons.
// It is the user's responsibility to make sure the result "new_multi_id" is not out-of-bound
// on the highest dimension
template <bool PositiveDirection>
__host__ __device__ static Array<index_t, nDim>
UpdateMultiIndexGivenStepSizeOf1dIndex(Array<index_t, nDim> old_multi_id,
index_t step_size_of_1d_index)
index_t step_size_of_1d_index,
integral_constant<bool, PositiveDirection>)
{
auto new_multi_id = old_multi_id + GetMultiIndexFrom1dIndex(step_size_of_1d_index);
Array<index_t, nDim> new_multi_id;
bool carry = false;
const auto step_sizes = GetMultiIndexFrom1dIndex(step_size_of_1d_index);
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDim - 1, 1>{}([&](auto IDimReverse) {
constexpr index_t idim = nDim - 1 - IDimReverse.Get();
constexpr auto IDim = Number<idim>{};
static_if<PositiveDirection>{}([&](auto) {
new_multi_id = old_multi_id + step_sizes;
if(carry)
{
++new_multi_id[idim];
}
bool carry = false;
carry = false;
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDim - 1, 1>{}([&](auto IDimReverse) {
constexpr index_t idim = nDim - 1 - IDimReverse.Get();
constexpr auto IDim = Number<idim>{};
if(new_multi_id[idim] >= GetLength(IDim))
{
new_multi_id[idim] -= GetLength(IDim);
carry = true;
}
if(carry)
{
++new_multi_id[idim];
}
carry = false;
if(new_multi_id[idim] >= GetLength(IDim))
{
new_multi_id[idim] -= GetLength(IDim);
carry = true;
}
});
}).Else([&](auto) {
// shift up multi-id to avoid unsigned integer underflow during intermediate
// calculations. After the shift, should have new_multi_id[...] >= 1
new_multi_id = old_multi_id + (GetLengths() - step_sizes);
bool borrow = false;
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, nDim - 1, 1>{}([&](auto IDimReverse) {
constexpr index_t idim = nDim - 1 - IDimReverse.Get();
constexpr auto IDim = Number<idim>{};
if(borrow)
{
--new_multi_id[idim];
}
borrow = false;
if(new_multi_id[idim] < GetLength(IDim))
{
new_multi_id[idim] += GetLength(IDim);
borrow = true;
}
});
// shift back down multi-id
// here, should have new_multi_id[...] >= GetLengths()
new_multi_id = new_multi_id - GetLengths();
});
return new_multi_id;
@@ -255,7 +294,7 @@ struct ConstantTensorDescriptor
}
template <class... Ts>
__host__ __device__ static constexpr auto Inject(ConstantTensorDescriptor<Ts...>)
__host__ __device__ static constexpr auto Embed(ConstantTensorDescriptor<Ts...>)
{
using leaf_tensor = ConstantTensorDescriptor<Ts...>;
@@ -290,7 +329,7 @@ struct ConstantTensorDescriptor
constexpr auto fold_intervals = Sequence<FoldIntervals...>{};
constexpr index_t fold_intervals_product =
accumulate_on_sequence(fold_intervals, std::multiplies<index_t>{}, Number<1>{});
accumulate_on_sequence(fold_intervals, mod_conv::multiplies<index_t>{}, Number<1>{});
constexpr auto unfold_length = GetLength(Number<IDim>{});
constexpr auto unfold_stride = GetStride(Number<IDim>{});
@@ -309,7 +348,7 @@ struct ConstantTensorDescriptor
constexpr auto fold_strides =
Number<unfold_stride>{} *
reverse_inclusive_scan_sequence(fold_intervals.PushBack(Number<1>{}),
std::multiplies<index_t>{});
mod_conv::multiplies<index_t>{});
// folded_ranks
constexpr auto fold_ranks =
@@ -389,7 +428,7 @@ struct ConstantTensorDescriptor
// unfolded length, stride and rank
constexpr index_t unfold_length = accumulate_on_sequence(
GetLengths().Extract(middle), std::multiplies<index_t>{}, Number<1>{});
GetLengths().Extract(middle), mod_conv::multiplies<index_t>{}, Number<1>{});
constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});
@@ -472,7 +511,20 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{
constexpr index_t ndim = TDesc::GetNumOfDimension();
static_assert(ndim >= 2 && ndim <= 10, "wrong!");
static_assert(ndim >= 1 && ndim <= 10, "wrong!");
static_if<ndim == 1>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u}, strides {%u}, ranks {%u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetStride(I0),
desc.GetMemoryRank(I0));
});
static_if<ndim == 2>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};