mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
implicit gemm v1r2: adding support for nchw
This commit is contained in:
@@ -108,11 +108,11 @@ template <class Lengths, class Strides>
|
||||
struct ConstantTensorDescriptor
|
||||
{
|
||||
using Type = ConstantTensorDescriptor<Lengths, Strides>;
|
||||
static constexpr index_t nDim = Lengths::nDim;
|
||||
static constexpr index_t nDim = Lengths::GetSize();
|
||||
|
||||
__host__ __device__ constexpr ConstantTensorDescriptor()
|
||||
{
|
||||
static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
|
||||
static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetDimension() { return nDim; }
|
||||
@@ -157,12 +157,10 @@ struct ConstantTensorDescriptor
|
||||
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
|
||||
}
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ static index_t Get1dIndex(Is... is)
|
||||
template <index_t NSize>
|
||||
__host__ __device__ static index_t Get1dIndex(Array<index_t, NSize> multi_id)
|
||||
{
|
||||
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
|
||||
|
||||
const auto multi_id = Array<index_t, nDim>(is...);
|
||||
static_assert(NSize == nDim, "wrong! Dimension not consistent");
|
||||
|
||||
index_t id = 0;
|
||||
|
||||
@@ -178,6 +176,16 @@ struct ConstantTensorDescriptor
|
||||
return id;
|
||||
}
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ static index_t Get1dIndex(Is... is)
|
||||
{
|
||||
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
|
||||
|
||||
const auto multi_id = Array<index_t, nDim>(is...);
|
||||
|
||||
return Get1dIndex(multi_id);
|
||||
}
|
||||
|
||||
__host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)
|
||||
{
|
||||
Array<index_t, nDim> multi_id;
|
||||
|
||||
Reference in New Issue
Block a user