adding implicit gemm v3

This commit is contained in:
Chao Liu
2019-05-23 22:10:40 -05:00
parent 8a4b59785b
commit 1cc683a3a3
16 changed files with 347 additions and 95 deletions

View File

@@ -396,31 +396,35 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor_default_rank_al
template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{
constexpr auto desc = TDesc{};
constexpr index_t ndim = desc.GetNumOfDimension();
constexpr index_t ndim = TDesc::GetNumOfDimension();
static_assert(ndim >= 2 && ndim <= 10, "wrong!");
if(ndim == 2)
{
static_if<ndim == 2>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n",
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u}, strides {%u %u}, ranks {%u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetStride(I0),
desc.GetStride(I1));
}
else if(ndim == 3)
{
desc.GetStride(I1),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1));
});
static_if<ndim == 3>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n",
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}, ranks {%u %u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
@@ -428,16 +432,21 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetLength(I2),
desc.GetStride(I0),
desc.GetStride(I1),
desc.GetStride(I2));
}
else if(ndim == 4)
{
desc.GetStride(I2),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2));
});
static_if<ndim == 4>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}, ranks {%u %u %u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
@@ -447,17 +456,24 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I0),
desc.GetStride(I1),
desc.GetStride(I2),
desc.GetStride(I3));
}
else if(ndim == 5)
{
desc.GetStride(I3),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3));
});
static_if<ndim == 5>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}, ranks {%u %u %u %u "
"%u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
@@ -469,10 +485,15 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I1),
desc.GetStride(I2),
desc.GetStride(I3),
desc.GetStride(I4));
}
else if(ndim == 6)
{
desc.GetStride(I4),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4));
});
static_if<ndim == 6>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
@@ -480,7 +501,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}, ranks {%u %u "
"%u %u %u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
@@ -494,10 +518,16 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I2),
desc.GetStride(I3),
desc.GetStride(I4),
desc.GetStride(I5));
}
else if(ndim == 7)
{
desc.GetStride(I5),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5));
});
static_if<ndim == 7>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
@@ -506,7 +536,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}, ranks "
"{%u %u %u %u %u %u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
@@ -522,10 +555,17 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I3),
desc.GetStride(I4),
desc.GetStride(I5),
desc.GetStride(I6));
}
else if(ndim == 8)
{
desc.GetStride(I6),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5),
desc.GetMemoryRank(I6));
});
static_if<ndim == 8>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
@@ -535,7 +575,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}, "
"ranks {%u %u %u %u %u %u %u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
@@ -553,10 +596,18 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I4),
desc.GetStride(I5),
desc.GetStride(I6),
desc.GetStride(I7));
}
else if(ndim == 9)
{
desc.GetStride(I7),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5),
desc.GetMemoryRank(I6),
desc.GetMemoryRank(I7));
});
static_if<ndim == 9>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
@@ -567,8 +618,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
"%u}\n",
"%u}, ranks {%u %u %u %u %u %u %u %u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
@@ -588,10 +641,19 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I5),
desc.GetStride(I6),
desc.GetStride(I7),
desc.GetStride(I8));
}
else if(ndim == 10)
{
desc.GetStride(I8),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5),
desc.GetMemoryRank(I6),
desc.GetMemoryRank(I7),
desc.GetMemoryRank(I8));
});
static_if<ndim == 10>{}([&](auto fwd) {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
@@ -603,8 +665,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
constexpr auto I8 = Number<8>{};
constexpr auto I9 = Number<9>{};
constexpr auto desc = fwd(TDesc{});
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
"%u %u %u}\n",
"%u %u %u}, ranks {%u %u %u %u %u %u %u %u %u %u}\n",
s,
desc.GetNumOfDimension(),
desc.GetLength(I0),
@@ -626,6 +690,16 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I6),
desc.GetStride(I7),
desc.GetStride(I8),
desc.GetStride(I9));
}
desc.GetStride(I9),
desc.GetMemoryRank(I0),
desc.GetMemoryRank(I1),
desc.GetMemoryRank(I2),
desc.GetMemoryRank(I3),
desc.GetMemoryRank(I4),
desc.GetMemoryRank(I5),
desc.GetMemoryRank(I6),
desc.GetMemoryRank(I7),
desc.GetMemoryRank(I8),
desc.GetMemoryRank(I9));
});
}