mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
adding implicit gemm v3
This commit is contained in:
@@ -396,31 +396,35 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor_default_rank_al
|
||||
template <class TDesc>
|
||||
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
{
|
||||
constexpr auto desc = TDesc{};
|
||||
constexpr index_t ndim = desc.GetNumOfDimension();
|
||||
constexpr index_t ndim = TDesc::GetNumOfDimension();
|
||||
|
||||
static_assert(ndim >= 2 && ndim <= 10, "wrong!");
|
||||
|
||||
if(ndim == 2)
|
||||
{
|
||||
static_if<ndim == 2>{}([&](auto fwd) {
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n",
|
||||
constexpr auto desc = fwd(TDesc{});
|
||||
|
||||
printf("%s dim %u, lengths {%u %u}, strides {%u %u}, ranks {%u %u}\n",
|
||||
s,
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetStride(I0),
|
||||
desc.GetStride(I1));
|
||||
}
|
||||
else if(ndim == 3)
|
||||
{
|
||||
desc.GetStride(I1),
|
||||
desc.GetMemoryRank(I0),
|
||||
desc.GetMemoryRank(I1));
|
||||
});
|
||||
|
||||
static_if<ndim == 3>{}([&](auto fwd) {
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n",
|
||||
constexpr auto desc = fwd(TDesc{});
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}, ranks {%u %u %u}\n",
|
||||
s,
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
@@ -428,16 +432,21 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
desc.GetLength(I2),
|
||||
desc.GetStride(I0),
|
||||
desc.GetStride(I1),
|
||||
desc.GetStride(I2));
|
||||
}
|
||||
else if(ndim == 4)
|
||||
{
|
||||
desc.GetStride(I2),
|
||||
desc.GetMemoryRank(I0),
|
||||
desc.GetMemoryRank(I1),
|
||||
desc.GetMemoryRank(I2));
|
||||
});
|
||||
|
||||
static_if<ndim == 4>{}([&](auto fwd) {
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
|
||||
constexpr auto desc = fwd(TDesc{});
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}, ranks {%u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
@@ -447,17 +456,24 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
desc.GetStride(I0),
|
||||
desc.GetStride(I1),
|
||||
desc.GetStride(I2),
|
||||
desc.GetStride(I3));
|
||||
}
|
||||
else if(ndim == 5)
|
||||
{
|
||||
desc.GetStride(I3),
|
||||
desc.GetMemoryRank(I0),
|
||||
desc.GetMemoryRank(I1),
|
||||
desc.GetMemoryRank(I2),
|
||||
desc.GetMemoryRank(I3));
|
||||
});
|
||||
|
||||
static_if<ndim == 5>{}([&](auto fwd) {
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
|
||||
constexpr auto desc = fwd(TDesc{});
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}, ranks {%u %u %u %u "
|
||||
"%u}\n",
|
||||
s,
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
@@ -469,10 +485,15 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
desc.GetStride(I1),
|
||||
desc.GetStride(I2),
|
||||
desc.GetStride(I3),
|
||||
desc.GetStride(I4));
|
||||
}
|
||||
else if(ndim == 6)
|
||||
{
|
||||
desc.GetStride(I4),
|
||||
desc.GetMemoryRank(I0),
|
||||
desc.GetMemoryRank(I1),
|
||||
desc.GetMemoryRank(I2),
|
||||
desc.GetMemoryRank(I3),
|
||||
desc.GetMemoryRank(I4));
|
||||
});
|
||||
|
||||
static_if<ndim == 6>{}([&](auto fwd) {
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
@@ -480,7 +501,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
|
||||
constexpr auto desc = fwd(TDesc{});
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}, ranks {%u %u "
|
||||
"%u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
@@ -494,10 +518,16 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
desc.GetStride(I2),
|
||||
desc.GetStride(I3),
|
||||
desc.GetStride(I4),
|
||||
desc.GetStride(I5));
|
||||
}
|
||||
else if(ndim == 7)
|
||||
{
|
||||
desc.GetStride(I5),
|
||||
desc.GetMemoryRank(I0),
|
||||
desc.GetMemoryRank(I1),
|
||||
desc.GetMemoryRank(I2),
|
||||
desc.GetMemoryRank(I3),
|
||||
desc.GetMemoryRank(I4),
|
||||
desc.GetMemoryRank(I5));
|
||||
});
|
||||
|
||||
static_if<ndim == 7>{}([&](auto fwd) {
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
@@ -506,7 +536,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
|
||||
constexpr auto desc = fwd(TDesc{});
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}, ranks "
|
||||
"{%u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
@@ -522,10 +555,17 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
desc.GetStride(I3),
|
||||
desc.GetStride(I4),
|
||||
desc.GetStride(I5),
|
||||
desc.GetStride(I6));
|
||||
}
|
||||
else if(ndim == 8)
|
||||
{
|
||||
desc.GetStride(I6),
|
||||
desc.GetMemoryRank(I0),
|
||||
desc.GetMemoryRank(I1),
|
||||
desc.GetMemoryRank(I2),
|
||||
desc.GetMemoryRank(I3),
|
||||
desc.GetMemoryRank(I4),
|
||||
desc.GetMemoryRank(I5),
|
||||
desc.GetMemoryRank(I6));
|
||||
});
|
||||
|
||||
static_if<ndim == 8>{}([&](auto fwd) {
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
@@ -535,7 +575,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
|
||||
constexpr auto desc = fwd(TDesc{});
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}, "
|
||||
"ranks {%u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
@@ -553,10 +596,18 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
desc.GetStride(I4),
|
||||
desc.GetStride(I5),
|
||||
desc.GetStride(I6),
|
||||
desc.GetStride(I7));
|
||||
}
|
||||
else if(ndim == 9)
|
||||
{
|
||||
desc.GetStride(I7),
|
||||
desc.GetMemoryRank(I0),
|
||||
desc.GetMemoryRank(I1),
|
||||
desc.GetMemoryRank(I2),
|
||||
desc.GetMemoryRank(I3),
|
||||
desc.GetMemoryRank(I4),
|
||||
desc.GetMemoryRank(I5),
|
||||
desc.GetMemoryRank(I6),
|
||||
desc.GetMemoryRank(I7));
|
||||
});
|
||||
|
||||
static_if<ndim == 9>{}([&](auto fwd) {
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
@@ -567,8 +618,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
constexpr auto desc = fwd(TDesc{});
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
|
||||
"%u}\n",
|
||||
"%u}, ranks {%u %u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
@@ -588,10 +641,19 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
desc.GetStride(I5),
|
||||
desc.GetStride(I6),
|
||||
desc.GetStride(I7),
|
||||
desc.GetStride(I8));
|
||||
}
|
||||
else if(ndim == 10)
|
||||
{
|
||||
desc.GetStride(I8),
|
||||
desc.GetMemoryRank(I0),
|
||||
desc.GetMemoryRank(I1),
|
||||
desc.GetMemoryRank(I2),
|
||||
desc.GetMemoryRank(I3),
|
||||
desc.GetMemoryRank(I4),
|
||||
desc.GetMemoryRank(I5),
|
||||
desc.GetMemoryRank(I6),
|
||||
desc.GetMemoryRank(I7),
|
||||
desc.GetMemoryRank(I8));
|
||||
});
|
||||
|
||||
static_if<ndim == 10>{}([&](auto fwd) {
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
@@ -603,8 +665,10 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
constexpr auto I8 = Number<8>{};
|
||||
constexpr auto I9 = Number<9>{};
|
||||
|
||||
constexpr auto desc = fwd(TDesc{});
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
|
||||
"%u %u %u}\n",
|
||||
"%u %u %u}, ranks {%u %u %u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
@@ -626,6 +690,16 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
desc.GetStride(I6),
|
||||
desc.GetStride(I7),
|
||||
desc.GetStride(I8),
|
||||
desc.GetStride(I9));
|
||||
}
|
||||
desc.GetStride(I9),
|
||||
desc.GetMemoryRank(I0),
|
||||
desc.GetMemoryRank(I1),
|
||||
desc.GetMemoryRank(I2),
|
||||
desc.GetMemoryRank(I3),
|
||||
desc.GetMemoryRank(I4),
|
||||
desc.GetMemoryRank(I5),
|
||||
desc.GetMemoryRank(I6),
|
||||
desc.GetMemoryRank(I7),
|
||||
desc.GetMemoryRank(I8),
|
||||
desc.GetMemoryRank(I9));
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user