mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
adding implicit gemm v4 (nchw, kcyx)
This commit is contained in:
@@ -40,6 +40,14 @@ struct ConstantTensorDescriptor
|
||||
#endif
|
||||
}
|
||||
|
||||
// Returns a default-constructed instance of Type (declared elsewhere in this struct).
__host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
{
    return Type{};
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
|
||||
{
|
||||
return Sequence<IDim>{};
|
||||
}
|
||||
|
||||
// Returns nDim, the number of dimensions of this descriptor.
__host__ __device__ static constexpr index_t GetNumOfDimension()
{
    return nDim;
}
|
||||
|
||||
// Returns a default-constructed Lengths object (the compile-time lengths of this descriptor).
__host__ __device__ static constexpr auto GetLengths()
{
    return Lengths{};
}
|
||||
@@ -66,6 +74,19 @@ struct ConstantTensorDescriptor
|
||||
return MemoryRanks{}.Get(Number<I>{});
|
||||
}
|
||||
|
||||
// Returns true when the stride of every dimension is greater than or equal to the
// stride of the following dimension, i.e. strides never increase when walking from
// dimension 0 to dimension nDim - 1. Returns true for a descriptor with a single
// dimension, because there is no adjacent pair to compare.
__host__ __device__ static constexpr bool AreStridesNonAscending()
{
    bool flag = true;

    // visit every adjacent pair of dimensions (IDim, IDim + 1)
    static_for<0, nDim - 1, 1>{}([&](auto IDim) {
        constexpr auto IDim_p1 = Number<IDim.Get() + 1>{};

        // The property being checked is about strides, so strides (not lengths)
        // are what must be compared here.
        flag = flag && (GetStrides().Get(IDim) >= GetStrides().Get(IDim_p1));
    });

    return flag;
}
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(T)
|
||||
{
|
||||
@@ -167,6 +188,46 @@ struct ConstantTensorDescriptor
|
||||
return multi_id;
|
||||
}
|
||||
|
||||
// Converts a multi-index of this descriptor into a multi-index of the original
// descriptor. For this descriptor the two coincide, so the argument is returned as is.
__host__ __device__ static auto
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
{
    return multi_id;
}
|
||||
|
||||
// This function doesn't do carry check on the highest dimension, for performance reason.
|
||||
// It is the user's responsibility to make sure the result "new_mutli_id" is not out-of-bound
|
||||
// on the highest dimension
|
||||
__host__ __device__ static Array<index_t, nDim>
|
||||
UpdateMultiIndexGivenStepSizeOf1dIndex(Array<index_t, nDim> old_multi_id,
|
||||
index_t step_size_of_1d_index)
|
||||
{
|
||||
auto new_multi_id = old_multi_id + GetMultiIndexFrom1dIndex(step_size_of_1d_index);
|
||||
|
||||
bool carry = false;
|
||||
|
||||
// do carry check in reversed order, starting from lowest dimension
|
||||
// don't check the highest dimension
|
||||
static_for<0, nDim - 1, 1>{}([&](auto IDimReverse) {
|
||||
constexpr index_t idim = nDim - 1 - IDimReverse.Get();
|
||||
constexpr auto IDim = Number<idim>{};
|
||||
|
||||
if(carry)
|
||||
{
|
||||
++new_multi_id[idim];
|
||||
}
|
||||
|
||||
carry = false;
|
||||
|
||||
if(new_multi_id[idim] >= GetLength(IDim))
|
||||
{
|
||||
new_multi_id[idim] -= GetLength(IDim);
|
||||
carry = true;
|
||||
}
|
||||
});
|
||||
|
||||
return new_multi_id;
|
||||
}
|
||||
|
||||
// WRONG! Ranks is broken
|
||||
template <index_t... IDims>
|
||||
__host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
|
||||
@@ -193,6 +254,19 @@ struct ConstantTensorDescriptor
|
||||
return Extract(Number<IDims>{}...);
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto Inject(ConstantTensorDescriptor<Ts...>)
|
||||
{
|
||||
using leaf_tensor = ConstantTensorDescriptor<Ts...>;
|
||||
|
||||
// memory rank is broken
|
||||
// TODO: remove memory rank info from tensor descritpor
|
||||
return ConstantTensorDescriptor<decltype(GetLengths().Append(leaf_tensor::GetLengths())),
|
||||
decltype(GetStrides().Append(leaf_tensor::GetStrides())),
|
||||
decltype(GetMemoryRanks().Append(
|
||||
leaf_tensor::GetMemoryRanks()))>{};
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t SliceLen>
|
||||
__host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user