adding implicit gemm v3

This commit is contained in:
Chao Liu
2019-05-22 19:39:56 -05:00
parent 2a48812edb
commit 8a4b59785b
26 changed files with 373 additions and 259 deletions

View File

@@ -11,8 +11,8 @@ struct ConstantMergedTensorDescriptor
{
static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};
static constexpr index_t nDim = std::tuple_size<mOriginalDimMergeSeqs>::value;
static constexpr index_t nOriginalDim = OriginalDesc::GetNumOfDimension();
static constexpr index_t nDim = sizeof...(OriginalDimMergeSeqs);
static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();
__host__ __device__ constexpr ConstantMergedTensorDescriptor()
{
@@ -21,25 +21,28 @@ struct ConstantMergedTensorDescriptor
// TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
// OriginalTensorDesc::nDim number of dimensions
// TODO: check there is no duplication in OriginalDimMergeSeqs
// TODO: check OriginalDimMergeSeqs contains all original dimensions
// TODO: check there is no duplication in OriginalDimMergeSeqs
}
__host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
__host__ __device__ static constexpr index_t GetNumOfOriginalDimension() { return nOriginalDim }
__host__ __device__ static constexpr index_t GetNumOfOriginalDimension()
{
return nOriginalDim;
}
template <index_t IDim>
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
{
return (std::Get<IDIM>(mOriginalDimMergeSeqs).GetSize() > 1);
return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
}
template <index_t IDim>
__host__ __device__ static constexpr index_t GetLength(Number<IDim>)
{
constexpr auto original_dims_partial = std::Get<IDim>(mOriginalDimMergeSeqs);
constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
}
@@ -50,14 +53,14 @@ struct ConstantMergedTensorDescriptor
static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
"wrong! stride of a merged dimension is undefined");
constexpr auto idim_original = std::Get<IDim>(mOriginalDimMergeSeqs).Front();
constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();
return OriginalTensorDesc::GetStride(Number<idim_original>{});
}
__host__ __device__ static constexpr auto GetLengths()
{
return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs).GetElementSize()...>{};
return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
}
__host__ __device__ static constexpr index_t GetElementSize()
@@ -75,17 +78,16 @@ struct ConstantMergedTensorDescriptor
constexpr auto original_dims_partial = std::get<idim>(mOriginalDimMergeSeqs);
// get partial original-multi-id corresponding to this merged dimension
constexpr auto original_multi_id_partial =
const auto original_multi_id_partial =
OriginalTensorDesc::Extract(original_dims_partial)
.GetMultiIndexFrom1dIndex(multi_id[idim]);
// make sure compiler unroll this loop and propagate all the constants
for(index_t i = 0; i < original_dims_partial.GetSize(); ++i)
{
index_t idim_original = original_dims_partial[i];
static_for<0, original_dims_partial.GetSize(), 1>{}([&](auto I_) {
constexpr auto I = decltype(I_){};
constexpr index_t idim_original = original_dims_partial.Get(I);
original_multi_id[idim_original] = original_multi_id_partial[i]
}
original_multi_id[idim_original] = original_multi_id_partial[I.Get()];
});
});
return original_multi_id;
@@ -95,10 +97,10 @@ struct ConstantMergedTensorDescriptor
{
const auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
return OriginalTensorDesc::GetOffsetFromMultiIndex(orginal_multi_id);
return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
}
template <index_t... Is>
template <class... Is>
__host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is)
{
return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
@@ -106,14 +108,15 @@ struct ConstantMergedTensorDescriptor
__host__ __device__ static Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
{
constexpr auto dummy_desc = make_packed_ConstantTensorDescriptor(GetLengths());
constexpr auto dummy_desc = make_ConstantTensorDescriptor_default_rank_packed(GetLengths());
return dummy_desc.GetMultiIndexFrom1dIndex(id);
}
};
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc, OriginalDimMergeSeqs...)
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
OriginalDimMergeSeqs...)
{
return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
}