mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
adding implicit gemm v3
This commit is contained in:
@@ -2,20 +2,20 @@
|
||||
#include "common.hip.hpp"
|
||||
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto calculate_packed_tensor_strides(Lengths)
|
||||
__host__ __device__ constexpr auto calculate_tensor_strides_default_rank_packed(Lengths)
|
||||
{
|
||||
return reverse_inclusive_scan_sequence(Lengths{}.PopFront(), std::multiplies<index_t>{})
|
||||
.PushBack(Number<1>{});
|
||||
}
|
||||
|
||||
template <class Lengths, index_t Align>
|
||||
__host__ __device__ constexpr auto
|
||||
calculate_rank_tensor_default_strides_with_alignment(Lengths, Number<Align>)
|
||||
__host__ __device__ constexpr auto calculate_tensor_strides_default_rank_aligned(Lengths,
|
||||
Number<Align>)
|
||||
{
|
||||
constexpr index_t L_back_align =
|
||||
Align * mod_conv::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);
|
||||
|
||||
return calculate_packed_tensor_strides(
|
||||
return calculate_tensor_strides_default_rank_packed(
|
||||
Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
|
||||
}
|
||||
|
||||
@@ -66,6 +66,12 @@ struct ConstantTensorDescriptor
|
||||
return MemoryRanks{}.Get(Number<I>{});
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(T)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetElementSize()
|
||||
{
|
||||
return accumulate_on_sequence(Lengths{}, std::multiplies<index_t>{}, Number<1>{});
|
||||
@@ -146,7 +152,7 @@ struct ConstantTensorDescriptor
|
||||
{
|
||||
Array<index_t, nDim> multi_id;
|
||||
|
||||
constexpr auto dummy_strides = calculate_packed_tensor_strides(GetLengths());
|
||||
constexpr auto dummy_strides = calculate_tensor_strides_default_rank_packed(GetLengths());
|
||||
|
||||
// calculate index in each of the dimensions in the order of their dimension (not rank)
|
||||
static_for<0, nDim - 1, 1>{}([&](auto IDim) {
|
||||
@@ -181,6 +187,12 @@ struct ConstantTensorDescriptor
|
||||
return ConstantTensorDescriptor<extract_lengths, extract_strides, new_ranks>{};
|
||||
}
|
||||
|
||||
template <index_t... IDims>
|
||||
__host__ __device__ static constexpr auto Extract(Sequence<IDims...>)
|
||||
{
|
||||
return Extract(Number<IDims>{}...);
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t SliceLen>
|
||||
__host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
|
||||
{
|
||||
@@ -271,9 +283,11 @@ struct ConstantTensorDescriptor
|
||||
FirstUnfoldDim <= LastUnfoldDim,
|
||||
"wrong! should have FirstUnfoldDim <= LastUnfoldDim!");
|
||||
|
||||
#if 0 // cannot compile: compiler complain about constexpr
|
||||
// dimensions to be unfold need to be in descending order (w.r.t. strides), and need to be
|
||||
// packed in memory, otherwise, unfolding is invalid
|
||||
static_for<FirstUnfoldDim, LastUnfoldDim, 1>{}([&](auto IDim) {
|
||||
static_for<FirstUnfoldDim, LastUnfoldDim, 1>{}([&](auto IDim_) {
|
||||
constexpr auto IDim = decltype(IDim_){};
|
||||
constexpr auto IDim_p1 = IDim + Number<1>{};
|
||||
|
||||
// check stride
|
||||
@@ -285,11 +299,12 @@ struct ConstantTensorDescriptor
|
||||
static_assert(GetStride(IDim_p1) * GetLength(IDim_p1) == GetStride(IDim),
|
||||
"wrong! dimensions to be unfolded need to be packed");
|
||||
|
||||
// checkt ranks
|
||||
// check ranks
|
||||
static_assert(GetMemoryRank(IDim_p1) == GetMemoryRank(IDim) + 1,
|
||||
"wrong! ranks of dimensions to be unfolded need to be in increasing and "
|
||||
"continuous ranks");
|
||||
});
|
||||
#endif
|
||||
|
||||
// left and right
|
||||
constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::SeqType{};
|
||||
@@ -308,9 +323,9 @@ struct ConstantTensorDescriptor
|
||||
|
||||
// decrease the ranks that are larger than the rank of LastUnfoldDim
|
||||
constexpr auto tmp_ranks =
|
||||
transform_sequences(GetMemoryRanks(),
|
||||
f_unfold_impl<GetMemoryRank(Number<LastUnfoldDim>{}),
|
||||
LastUnfoldDim - FirstUnfoldDim + 1>{});
|
||||
transform_sequences(f_unfold_impl<GetMemoryRank(Number<LastUnfoldDim>{}),
|
||||
LastUnfoldDim - FirstUnfoldDim + 1>{},
|
||||
GetMemoryRanks());
|
||||
|
||||
// new lengths, strides and ranks
|
||||
constexpr auto new_lengths = GetLengths()
|
||||
@@ -354,26 +369,26 @@ struct ConstantTensorDescriptor
|
||||
};
|
||||
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto make_packed_ConstantTensorDescriptor(Lengths)
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_default_rank_packed(Lengths)
|
||||
{
|
||||
using Strides = decltype(calculate_packed_tensor_strides(Lengths{}));
|
||||
using Strides = decltype(calculate_tensor_strides_default_rank_packed(Lengths{}));
|
||||
using MemoryRanks = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::SeqType;
|
||||
return ConstantTensorDescriptor<Lengths, Strides, MemoryRanks>{};
|
||||
}
|
||||
|
||||
template <class Lengths, class Strides>
|
||||
__host__ __device__ constexpr auto make_ranked_ConstantTensorDescriptor(Lengths, Strides)
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_default_rank(Lengths, Strides)
|
||||
{
|
||||
using MemoryRanks = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::SeqType;
|
||||
return ConstantTensorDescriptor<Lengths, Strides, MemoryRanks>{};
|
||||
}
|
||||
|
||||
template <class Lengths, index_t Align>
|
||||
__host__ __device__ constexpr auto
|
||||
make_ranked_ConstantTensorDescriptor_with_alignment(Lengths, Number<Align>)
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_default_rank_aligned(Lengths,
|
||||
Number<Align>)
|
||||
{
|
||||
using Strides =
|
||||
decltype(calculate_rank_tensor_default_strides_with_alignment(Lengths{}, Number<Align>{}));
|
||||
decltype(calculate_tensor_strides_default_rank_aligned(Lengths{}, Number<Align>{}));
|
||||
using MemoryRanks = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::SeqType;
|
||||
return ConstantTensorDescriptor<Lengths, Strides, MemoryRanks>{};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user