mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
adding implicit gemm v3
This commit is contained in:
@@ -65,7 +65,7 @@ struct ConstantTensorDescriptor
|
||||
static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetDimension() { return nDim; }
|
||||
// Returns nDim, the compile-time dimension count of this descriptor.
// Callable from host and device code and usable in constant expressions.
__host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
|
||||
|
||||
// Returns a value-constructed instance of the Lengths type this descriptor
// is parameterized with (the lengths live in the type, not in runtime state).
__host__ __device__ static constexpr Lengths GetLengths() { return Lengths{}; }
|
||||
|
||||
@@ -160,11 +160,51 @@ struct ConstantTensorDescriptor
|
||||
return multi_id;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto Condense()
|
||||
// Produce a descriptor with the same Lengths as this one, whose strides are
// whatever calculate_default_strides(Lengths{}) yields.
__host__ __device__ static constexpr auto Pack()
{
    constexpr auto packed_strides = calculate_default_strides(Lengths{});

    // Take the type from the constexpr local itself so the resulting
    // descriptor type is spelled exactly as before.
    using PackedStrides = decltype(packed_strides);

    return ConstantTensorDescriptor<Lengths, PackedStrides>{};
}
|
||||
|
||||
// Build a descriptor from a subset of this descriptor's dimensions.
//
// IDims  - compile-time dimension indices, passed as Number<> tags. For each
//          index, the matching entry of Lengths and of Strides is picked, in
//          the order the indices are given.
// Returns the result of make_ConstantTensorDescriptor applied to the picked
// lengths and strides.
//
// Fixes relative to the previous revision: the parameter pack is declared as
// `index_t... IDims` (the former `index_t IDims...` is ill-formed), and the
// return statement refers to `extracted_lengths` (was misspelled).
template <index_t... IDims>
__host__ __device__ static constexpr auto Extract(Number<IDims>... /*extracted_dims...*/)
{
    // cannot pick more dimensions than the descriptor has
    static_assert(sizeof...(IDims) <= GetNumOfDimension(), "wrong!");

    constexpr auto extracted_lengths = Sequence<Lengths{}.Get(Number<IDims>{})...>{};
    constexpr auto extracted_strides = Sequence<Strides{}.Get(Number<IDims>{})...>{};

    return make_ConstantTensorDescriptor(extracted_lengths, extracted_strides);
}
|
||||
|
||||
// Placeholder: takes a dimension tag and a slice-length tag but has no body
// yet, so the deduced return type is void.
// NOTE(review): presumably meant to return a descriptor whose dimension IDim
// is shortened to SliceLen - confirm when implementing.
template <index_t IDim, index_t SliceLen>
__host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
{
    // TODO: not implemented
}
|
||||
|
||||
// Placeholder: takes a dimension tag and a Sequence of fold lengths but has
// no body yet, so the deduced return type is void.
// NOTE(review): presumably meant to split dimension IDim into several
// dimensions according to FoldLengths - confirm when implementing.
//
// Fix relative to the previous revision: the execution-space qualifier was
// written `device__`; it is now `__device__`.
template <index_t IDim, index_t... FoldLengths>
__host__ __device__ static constexpr auto Fold(Number<IDim>, Sequence<FoldLengths...>)
{
    // not implemented
    // need to check that the length of the dimension to be folded is
    // divisible by FoldLengths
}
|
||||
|
||||
// Placeholder: takes a first and a last dimension tag but has no body yet,
// so the deduced return type is void.
// NOTE(review): presumably meant to merge dimensions FirstUnfoldDim through
// LastUnfoldDim into one - confirm when implementing.
template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
__host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
{
    // TODO: not implemented
    // TODO: verify the dimensions to be unfolded are packed; if they are not,
    // Unfold must not be permitted
}
|
||||
|
||||
// Build a descriptor whose dimensions are a reordering of this one's.
//
// IRs - the reorder map, passed as a Sequence; it must contain exactly one
//       entry per dimension of this descriptor. The same map is applied to
//       both Lengths and Strides through their ReorderGivenNew2Old members.
// Returns the result of make_ConstantTensorDescriptor applied to the
// reordered lengths and strides.
//
// Fix relative to the previous revision: the static_assert called
// GetNumberOfDimension(), but the accessor in this struct is named
// GetNumOfDimension().
template <index_t... IRs>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
{
    static_assert(sizeof...(IRs) == GetNumOfDimension(), "wrong! dimension is wrong");

    constexpr auto map_new2old = Sequence<IRs...>{};

    return make_ConstantTensorDescriptor(Lengths{}.ReorderGivenNew2Old(map_new2old),
                                         Strides{}.ReorderGivenNew2Old(map_new2old));
}
|
||||
};
|
||||
|
||||
template <class Lengths>
|
||||
@@ -191,7 +231,7 @@ template <class TDesc>
|
||||
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
{
|
||||
constexpr auto desc = TDesc{};
|
||||
constexpr index_t ndim = desc.GetDimension();
|
||||
constexpr index_t ndim = desc.GetNumOfDimension();
|
||||
|
||||
static_assert(ndim >= 2 && ndim <= 10, "wrong!");
|
||||
|
||||
@@ -202,7 +242,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
|
||||
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetStride(I0),
|
||||
@@ -216,7 +256,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
@@ -233,7 +273,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
@@ -253,7 +293,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
@@ -276,7 +316,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
@@ -302,7 +342,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
@@ -331,7 +371,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
@@ -364,7 +404,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
|
||||
"%u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
@@ -400,7 +440,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
|
||||
"%u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetNumOfDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
|
||||
Reference in New Issue
Block a user