adding implicit gemm v3

This commit is contained in:
Chao Liu
2019-05-15 09:58:17 -05:00
parent 4957d5a399
commit b7d052459d
29 changed files with 977 additions and 296 deletions

View File

@@ -65,7 +65,7 @@ struct ConstantTensorDescriptor
static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
}
__host__ __device__ static constexpr index_t GetDimension() { return nDim; }
__host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
__host__ __device__ static constexpr Lengths GetLengths() { return Lengths{}; }
@@ -160,11 +160,51 @@ struct ConstantTensorDescriptor
return multi_id;
}
__host__ __device__ static constexpr auto Condense()
__host__ __device__ static constexpr auto Pack()
{
constexpr auto default_strides = calculate_default_strides(Lengths{});
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
}
template <index_t IDims...>
__host__ __device__ static constexpr auto Extract(Number<IDims>... /*extracted_dims...*/)
{
static_assert(sizeof...(IDims) <= GetNumOfDimension(), "wrong!");
constexpr auto extracted_lengths = Sequence<Lengths{}.Get(Number<IDims>{})...>{};
constexpr auto extracted_strides = Sequence<Strides{}.Get(Number<IDims>{})...>{};
return make_ConstantTensorDescriptor(extracted_lenghts, extracted_strides);
}
template <index_t IDim, index_t SliceLen>
__host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
{
// not implemented
}
template <index_t IDim, index_t... FoldLengths>
__host__ device__ static constexpr auto Fold(Number<IDim>, Sequence<FoldLengths...>)
{
// not implemented
// need to check the Length dimension to be folded is dividable by FoldLengths
}
template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
__host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
{
// not implemented
// need to check the dimensions to be unfold are packed, otherwise, Unfold is not permitted
}
template <index_t... IRs>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
{
static_assert(sizeof...(IRs) == GetNumberOfDimension(), "wrong! dimension is wrong");
constexpr auto map_new2old = Sequence<IRs...>{};
return make_ConstantTensorDescriptor(Lengths{}.ReorderGivenNew2Old(map_new2old),
Strides{}.ReorderGivenNew2Old(map_new2old));
}
};
template <class Lengths>
@@ -191,7 +231,7 @@ template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{
constexpr auto desc = TDesc{};
constexpr index_t ndim = desc.GetDimension();
constexpr index_t ndim = desc.GetNumOfDimension();
static_assert(ndim >= 2 && ndim <= 10, "wrong!");
@@ -202,7 +242,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n",
s,
desc.GetDimension(),
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetStride(I0),
@@ -216,7 +256,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n",
s,
desc.GetDimension(),
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
@@ -233,7 +273,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
s,
desc.GetDimension(),
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
@@ -253,7 +293,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
s,
desc.GetDimension(),
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
@@ -276,7 +316,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
s,
desc.GetDimension(),
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
@@ -302,7 +342,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
s,
desc.GetDimension(),
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
@@ -331,7 +371,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
s,
desc.GetDimension(),
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
@@ -364,7 +404,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
"%u}\n",
s,
desc.GetDimension(),
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
@@ -400,7 +440,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
"%u %u %u}\n",
s,
desc.GetDimension(),
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),