mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
adding implicit gemm v3
This commit is contained in:
@@ -11,8 +11,8 @@ struct ConstantMergedTensorDescriptor
|
||||
{
|
||||
static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};
|
||||
|
||||
static constexpr index_t nDim = std::tuple_size<mOriginalDimMergeSeqs>::value;
|
||||
static constexpr index_t nOriginalDim = OriginalDesc::GetNumOfDimension();
|
||||
static constexpr index_t nDim = sizeof...(OriginalDimMergeSeqs);
|
||||
static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();
|
||||
|
||||
__host__ __device__ constexpr ConstantMergedTensorDescriptor()
|
||||
{
|
||||
@@ -21,25 +21,28 @@ struct ConstantMergedTensorDescriptor
|
||||
// TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
|
||||
// OriginalTensorDesc::nDim number of dimensions
|
||||
|
||||
// TODO: check there is no duplication in OriginalDimMergeSeqs
|
||||
|
||||
// TODO: check OriginalDimMergeSeqs contains all original dimensions
|
||||
|
||||
// TODO: check there is no duplication in OriginalDimMergeSeqs
|
||||
}
|
||||
|
||||
// Returns nDim, the number of dimensions of this merged descriptor.
__host__ __device__ static constexpr index_t GetNumOfDimension()
{
    return nDim;
}
|
||||
|
||||
// Returns nOriginalDim, the stored dimension count of the original tensor
// descriptor (as opposed to GetNumOfDimension(), which returns nDim).
__host__ __device__ static constexpr index_t GetNumOfOriginalDimension()
{
    return nOriginalDim;
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
|
||||
{
|
||||
return (std::Get<IDIM>(mOriginalDimMergeSeqs).GetSize() > 1);
|
||||
return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr index_t GetLength(Number<IDim>)
|
||||
{
|
||||
constexpr auto original_dims_partial = std::Get<IDim>(mOriginalDimMergeSeqs);
|
||||
constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
|
||||
|
||||
return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
|
||||
}
|
||||
@@ -50,14 +53,14 @@ struct ConstantMergedTensorDescriptor
|
||||
static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
|
||||
"wrong! stride of a merged dimension is undefined");
|
||||
|
||||
constexpr auto idim_original = std::Get<IDim>(mOriginalDimMergeSeqs).Front();
|
||||
constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();
|
||||
|
||||
return OriginalTensorDesc::GetStride(Number<idim_original>{});
|
||||
}
|
||||
|
||||
// Returns a Sequence holding the length of every merged dimension, in order.
// Each entry is the element size of the original sub-descriptor extracted for
// the corresponding merge sequence. The pack is expanded over value-initialized
// objects (OriginalDimMergeSeqs{}), because Extract() takes a value, not a type.
__host__ __device__ static constexpr auto GetLengths()
{
    return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetElementSize()
|
||||
@@ -75,17 +78,16 @@ struct ConstantMergedTensorDescriptor
|
||||
constexpr auto original_dims_partial = std::get<idim>(mOriginalDimMergeSeqs);
|
||||
|
||||
// get partial original-multi-id corresponding to this merged dimension
|
||||
constexpr auto original_multi_id_partial =
|
||||
const auto original_multi_id_partial =
|
||||
OriginalTensorDesc::Extract(original_dims_partial)
|
||||
.GetMultiIndexFrom1dIndex(multi_id[idim]);
|
||||
|
||||
// make sure compiler unroll this loop and propagate all the constants
|
||||
for(index_t i = 0; i < original_dims_partial.GetSize(); ++i)
|
||||
{
|
||||
index_t idim_original = original_dims_partial[i];
|
||||
static_for<0, original_dims_partial.GetSize(), 1>{}([&](auto I_) {
|
||||
constexpr auto I = decltype(I_){};
|
||||
constexpr index_t idim_original = original_dims_partial.Get(I);
|
||||
|
||||
original_multi_id[idim_original] = original_multi_id_partial[i]
|
||||
}
|
||||
original_multi_id[idim_original] = original_multi_id_partial[I.Get()];
|
||||
});
|
||||
});
|
||||
|
||||
return original_multi_id;
|
||||
@@ -95,10 +97,10 @@ struct ConstantMergedTensorDescriptor
|
||||
{
|
||||
const auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
|
||||
|
||||
return OriginalTensorDesc::GetOffsetFromMultiIndex(orginal_multi_id);
|
||||
return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
template <class... Is>
|
||||
__host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is)
|
||||
{
|
||||
return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
|
||||
@@ -106,14 +108,15 @@ struct ConstantMergedTensorDescriptor
|
||||
|
||||
// Converts a 1-d index into an nDim-entry multi-index over the merged
// dimensions. The conversion is delegated to a helper descriptor built from
// GetLengths() alone; strides of the original tensor play no role here.
// NOTE(review): the factory name follows the later of the two variants shown
// for this function; its definition is not visible here — confirm it exists.
__host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{
    constexpr auto dummy_desc = make_ConstantTensorDescriptor_default_rank_packed(GetLengths());

    return dummy_desc.GetMultiIndexFrom1dIndex(id);
}
|
||||
};
|
||||
|
||||
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
|
||||
constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc, OriginalDimMergeSeqs...)
|
||||
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
|
||||
OriginalDimMergeSeqs...)
|
||||
{
|
||||
return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user