added implicit gemm v1r3 lds_double_buffer NCHW * CYXK = KNHW, reworked static functionals

[ROCm/composable_kernel commit: 569ad66e2a]
This commit is contained in:
Chao Liu
2019-04-23 17:51:14 -05:00
parent fe17969c81
commit 21988c32b4
22 changed files with 2117 additions and 1107 deletions

View File

@@ -1,80 +1,30 @@
#pragma once
#include "common.hip.hpp"
// NOTE(review): this span appears to interleave two versions of the code (the hunk
// header above reads "-1,80 +1,30" and no +/- markers survive). The 2-d overload
// header and `return Sequence<L1, 1>{};` look like the removed text; the generic
// `calculate_default_strides_impl` looks like the added text. Confirm against the
// repository before relying on either reading.
// this is ugly, only for 2d
template <index_t L0, index_t L1>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
// Generic step of the stride computation: derives one more stride from the strides
// accumulated so far and the lengths not yet consumed, then calls itself again with
// the extended strides and the shortened lengths.
// Front/Back/PushFront/PopBack are members of the argument types defined elsewhere;
// they are described below by their names only -- TODO confirm their semantics.
template <class PreviousStrides, class RemainLengths>
__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, RemainLengths)
{
return Sequence<L1, 1>{};
// value reported by Front() on the strides accumulated so far
constexpr index_t previous_stride = PreviousStrides{}.Front();
// value reported by Back() on the lengths that remain
constexpr index_t current_length = RemainLengths{}.Back();
// the new stride is the product of the two values above
constexpr index_t current_stride = current_length * previous_stride;
// continue with the new stride pushed onto the strides and PopBack() applied to the lengths
return calculate_default_strides_impl(PreviousStrides{}.PushFront(Number<current_stride>{}),
RemainLengths{}.PopBack());
}
// NOTE(review): as with the span above, removed and added lines appear interleaved
// here. The 3-d overload header and `return Sequence<L1 * L2, L2, 1>{};` look like
// the removed text; the `Sequence<L0, L1>` overload of
// `calculate_default_strides_impl` looks like the added text -- confirm.
// this is ugly, only for 3d
template <index_t L0, index_t L1, index_t L2>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2>)
// Overload selected when exactly two lengths remain. It does not call
// `calculate_default_strides_impl` again: it pushes one final stride and returns.
// L0 does not take part in the computation below.
template <class PreviousStrides, index_t L0, index_t L1>
__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, Sequence<L0, L1>)
{
return Sequence<L1 * L2, L2, 1>{};
// value reported by Front() on the strides accumulated so far
constexpr index_t previous_stride = PreviousStrides{}.Front();
// final stride: second remaining length times the value above
constexpr index_t current_stride = L1 * previous_stride;
return PreviousStrides{}.PushFront(Number<current_stride>{});
}
// NOTE(review): removed and added lines appear interleaved here. The 4-d overload
// header and the `return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};` body look like
// the removed text. The generic `calculate_default_strides(Lengths)` header looks
// like the added text; its body is presumably the statement
// `return calculate_default_strides_impl(Sequence<1>{}, Lengths{});` that shows up
// further down in this listing -- confirm against the repository.
// this is ugly, only for 4d
template <index_t L0, index_t L1, index_t L2, index_t L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
// Entry point taking the lengths as a single type parameter.
template <class Lengths>
__host__ __device__ constexpr auto calculate_default_strides(Lengths)
{
// 4-d result: each stride is the product of the lengths to its right; the last is 1
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
}
// Fixed-rank overload for exactly six lengths.
// Returns a Sequence of six strides in which the last stride is 1 and every other
// stride is the product of all lengths to its right; L0 never enters the result.
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
{
    // build the strides from the innermost dimension outwards
    constexpr index_t stride4 = L5;
    constexpr index_t stride3 = L4 * stride4;
    constexpr index_t stride2 = L3 * stride3;
    constexpr index_t stride1 = L2 * stride2;
    constexpr index_t stride0 = L1 * stride1;

    return Sequence<stride0, stride1, stride2, stride3, stride4, 1>{};
}
// Fixed-rank overload for exactly eight lengths.
// Returns a Sequence of eight strides in which the last stride is 1 and every other
// stride is the product of all lengths to its right; L0 never enters the result.
template <index_t L0,
          index_t L1,
          index_t L2,
          index_t L3,
          index_t L4,
          index_t L5,
          index_t L6,
          index_t L7>
__host__ __device__ constexpr auto
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
{
    // build the strides from the innermost dimension outwards
    constexpr index_t stride6 = L7;
    constexpr index_t stride5 = L6 * stride6;
    constexpr index_t stride4 = L5 * stride5;
    constexpr index_t stride3 = L4 * stride4;
    constexpr index_t stride2 = L3 * stride3;
    constexpr index_t stride1 = L2 * stride2;
    constexpr index_t stride0 = L1 * stride1;

    return Sequence<stride0, stride1, stride2, stride3, stride4, stride5, stride6, 1>{};
}
// this is ugly, only for 10d
// Fixed-rank overload taking ten lengths (L0..L9). The returned Sequence holds ten
// strides: the last is 1 and every other one is the product of all lengths to its
// right, so L0 never appears in the result.
template <index_t L0,
index_t L1,
index_t L2,
index_t L3,
index_t L4,
index_t L5,
index_t L6,
index_t L7,
index_t L8,
index_t L9>
__host__ __device__ constexpr auto
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7, L8, L9>)
{
return Sequence<L1 * L2 * L3 * L4 * L5 * L6 * L7 * L8 * L9,
L2 * L3 * L4 * L5 * L6 * L7 * L8 * L9,
L3 * L4 * L5 * L6 * L7 * L8 * L9,
L4 * L5 * L6 * L7 * L8 * L9,
L5 * L6 * L7 * L8 * L9,
L6 * L7 * L8 * L9,
L7 * L8 * L9,
L8 * L9,
L9,
1>{};
// NOTE(review): `Lengths` is not a template parameter of this overload. The statement
// below presumably belongs to the generic `calculate_default_strides(Lengths)` whose
// header appears earlier in this listing (diff lines interleaved) -- confirm.
return calculate_default_strides_impl(Sequence<1>{}, Lengths{});
}
// this is ugly, only for 2d
@@ -186,6 +136,14 @@ struct ConstantTensorDescriptor
return Get1dIndex(multi_id);
}
// Overload accepting the multi-index as a compile-time Sequence.
// Checks at compile time that the number of indices equals nDim, then forwards the
// unpacked values Is... to the Get1dIndex overload that takes them as separate
// arguments. The `multi_id` object itself is not read; only its type is used.
template <index_t... Is>
__host__ __device__ static constexpr index_t Get1dIndex(Sequence<Is...> multi_id)
{
    constexpr auto num_indices = sizeof...(Is);
    static_assert(num_indices == nDim, "wrong! Dimension not consistent");

    return Get1dIndex(Is...);
}
__host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)
{
Array<index_t, nDim> multi_id;