mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
added tuple
This commit is contained in:
@@ -4,6 +4,8 @@
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_batched_gemm.hpp"
|
||||
@@ -45,6 +47,7 @@ template <index_t GridSize,
|
||||
index_t OutThreadCopyDataPerAccess_N>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
|
||||
{
|
||||
#if 0
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
@@ -478,6 +481,67 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
|
||||
#endif
|
||||
});
|
||||
}
|
||||
#else
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
#if 0
|
||||
constexpr auto tmp = std::tuple<bool>{};
|
||||
constexpr auto flag = std::get<0>(tmp);
|
||||
#else
|
||||
constexpr auto a = Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>{}, 99);
|
||||
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
printf("adsas %d\n", a.At(Number<0>{}));
|
||||
print_Sequence("seq", a.At(Number<1>{}));
|
||||
printf("adsas %lu\n", a.At(Number<2>{}));
|
||||
}
|
||||
|
||||
auto b = Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>{}, 99);
|
||||
|
||||
b.At(Number<0>{}) = false;
|
||||
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
printf("adsas %d\n", b.At(Number<0>{}));
|
||||
print_Sequence("seq", b.At(Number<1>{}));
|
||||
printf("adsas %lu\n", b.At(Number<2>{}));
|
||||
}
|
||||
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
printf("adsas %d\n",
|
||||
Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<0>{}));
|
||||
print_Sequence(
|
||||
"seq", Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<1>{}));
|
||||
printf("adsas %d\n",
|
||||
Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<2>{}));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
// create a native tensor descriptor
|
||||
constexpr auto in_n_c_h_w_global_desc =
|
||||
make_NativeTensorDescriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
|
||||
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_tensor_descriptor("in_n_c_h_w_global_desc", in_n_c_h_w_global_desc);
|
||||
}
|
||||
|
||||
// transform the tensor descriptor once
|
||||
//
|
||||
// calculate the offset of some entry
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -12,15 +12,17 @@ struct Dimension
|
||||
};
|
||||
|
||||
template <index_t Length, index_t Stride>
|
||||
struct NativeDimension : Dimension<Length>
|
||||
struct NativeDimension
|
||||
{
|
||||
__host__ __device__ static constexpr auto GetLength() { return Number<Length>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffset(index_t id) { return id * Stride; }
|
||||
__host__ __device__ static constexpr index_t GetOffset(index_t i) { return i * Stride; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffsetDiff(index_t id_diff)
|
||||
__host__ __device__ static constexpr index_t GetOffsetDiff(index_t i_diff)
|
||||
{
|
||||
return id_diff * Stride;
|
||||
return i_diff * Stride;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -8,25 +8,19 @@ namespace ck {
|
||||
template <index_t N>
|
||||
using MultiIndex = Array<index_t, N>;
|
||||
|
||||
// LowLengths: Sequence<...>
|
||||
template <class LowLengths>
|
||||
template <index_t Length>
|
||||
struct PassThrough
|
||||
{
|
||||
static constexpr index_t nDim = LowLengths::GetSize();
|
||||
using LowerIndex = MultiIndex<1>;
|
||||
using UpperIndex = MultiIndex<1>;
|
||||
|
||||
using LowerIndex = MultiIndex<nDim>;
|
||||
using UpperIndex = LowerIndex;
|
||||
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<1>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension()
|
||||
{
|
||||
return GetNumOfLowerDimension();
|
||||
}
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<Length>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return GetLowerLengths(); }
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) { return idx_up; }
|
||||
|
||||
@@ -35,7 +29,7 @@ struct PassThrough
|
||||
return idx_up_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; }
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
};
|
||||
|
||||
// LowLengths: Sequence<...>
|
||||
@@ -45,25 +39,22 @@ struct Pad
|
||||
static constexpr index_t nDim = LowLengths::GetSize();
|
||||
|
||||
using LowerIndex = MultiIndex<nDim>;
|
||||
using UpperIndex = LowerIndex;
|
||||
using UpperIndex = MultiIndex<nDim>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension()
|
||||
{
|
||||
return GetNumOfLowerDimension();
|
||||
}
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths()
|
||||
{
|
||||
return GetLowerLengths() + LeftPads + RightPads;
|
||||
return GetLowerLengths() + LeftPads{} + RightPads{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
|
||||
{
|
||||
return idx_up - LeftPads;
|
||||
return idx_up - LeftPads{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
|
||||
@@ -71,9 +62,10 @@ struct Pad
|
||||
return idx_up_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; }
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
};
|
||||
|
||||
#if 0
|
||||
// LowLengths: Sequence<...>
|
||||
template <class LowLengths>
|
||||
struct Merge
|
||||
@@ -116,8 +108,9 @@ struct Merge
|
||||
return idx_low_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsIndexTransformLinear() { return false; }
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
|
||||
};
|
||||
#endif
|
||||
|
||||
// UpLengths: Sequence<...>
|
||||
template <index_t LowLength, class UpLengths>
|
||||
@@ -126,6 +119,9 @@ struct Unmerge
|
||||
static constexpr index_t nDimLow = 1;
|
||||
static constexpr index_t nDimUp = UpLengths::GetSize();
|
||||
|
||||
using UpperIndex = MultiIndex<nDimUp>;
|
||||
using LowerIndex = MultiIndex<nDimLow>;
|
||||
|
||||
__host__ __device__ constexpr Unmerge()
|
||||
{
|
||||
static_assert(LowLength == accumulate_on_sequence(
|
||||
@@ -133,7 +129,7 @@ struct Unmerge
|
||||
"wrong! UpLengths need to be ");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension(){return Number<nDimUp>{}};
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
|
||||
|
||||
@@ -149,7 +145,7 @@ struct Unmerge
|
||||
|
||||
LowerIndex idx_low{0};
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim) { idx_low[0] += idx_up[idim] * scans[idim]; });
|
||||
static_for<0, nDimUp, 1>{}([&](auto idim) { idx_low(0) += idx_up[idim] * scans[idim]; });
|
||||
|
||||
return idx_low;
|
||||
}
|
||||
@@ -159,7 +155,7 @@ struct Unmerge
|
||||
return GetLowerIndex(idx_up_diff);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; }
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
};
|
||||
|
||||
// UpLengths: Sequence<...>
|
||||
@@ -171,7 +167,8 @@ struct Embed
|
||||
static constexpr index_t nDimLow = 1;
|
||||
static constexpr index_t nDimUp = UpLengths::GetSize();
|
||||
|
||||
static constexpr auto mCoefficients = Coefficients{};
|
||||
using LowerIndex = MultiIndex<nDimLow>;
|
||||
using UpperIndex = MultiIndex<nDimUp>;
|
||||
|
||||
__host__ __device__ constexpr Embed()
|
||||
{
|
||||
@@ -179,14 +176,14 @@ struct Embed
|
||||
"wrong! # of dimensions not consistent");
|
||||
|
||||
constexpr index_t low_id_max =
|
||||
Coefficents.Back() + accumulate_on_sequence(UpLengths{} * Coefficients::PopBack(),
|
||||
math::plus<index_t>{},
|
||||
Number<0>{});
|
||||
Coefficients::Back() + accumulate_on_sequence(UpLengths{} * Coefficients::PopBack(),
|
||||
math::plus<index_t>{},
|
||||
Number<0>{});
|
||||
|
||||
static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension(){return Number<nDimUp>{}};
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
|
||||
|
||||
@@ -196,10 +193,10 @@ struct Embed
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
|
||||
{
|
||||
LowerIndex idx_low{mCoefficients[nDimUp]};
|
||||
LowerIndex idx_low(Coefficients{}[nDimUp]);
|
||||
|
||||
static_for<0, nDimUp, 1>{}(
|
||||
[&](auto idim) { idx_low[0] += idx_up[idim] * mCoefficients[idim]; });
|
||||
[&](auto idim) { idx_low[0] += idx_up[idim] * Coefficients{}[idim]; });
|
||||
|
||||
return idx_low;
|
||||
}
|
||||
@@ -209,12 +206,12 @@ struct Embed
|
||||
LowerIndex idx_low_diff{0};
|
||||
|
||||
static_for<0, nDimUp, 1>{}(
|
||||
[&](auto idim) { idx_low_diff[0] += idx_up_diff[idim] * mCoefficients[idim]; });
|
||||
[&](auto idim) { idx_low_diff[0] += idx_up_diff[idim] * Coefficients{}[idim]; });
|
||||
|
||||
return idx_low_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsIndexTransformLinear() { return true; }
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -11,21 +11,39 @@ template <class... NativeDimensions>
|
||||
struct NativeTensorDescriptor
|
||||
{
|
||||
using type = NativeTensorDescriptor;
|
||||
static constexpr auto mDimensions = Tuple<NativeDimensions...>;
|
||||
static constexpr index_t nDim = mDimensions::GetSize();
|
||||
static constexpr auto mDimensions = Tuple<NativeDimensions...>{};
|
||||
static constexpr index_t nDim = mDimensions.GetSize();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
|
||||
|
||||
struct lambda_GetLength
|
||||
{
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr auto operator()(IDim) const
|
||||
{
|
||||
return GetLength(IDim{});
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths()
|
||||
{
|
||||
// not implemented
|
||||
return typename sequence_gen<nDim, lambda_GetLength>::type{};
|
||||
}
|
||||
|
||||
struct lambda_GetStride
|
||||
{
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr auto operator()(IDim) const
|
||||
{
|
||||
return GetStride(IDim{});
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ static constexpr auto GetStrides()
|
||||
{
|
||||
// not implemented
|
||||
return typename sequence_gen<nDim, lambda_GetStride>::type{};
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
@@ -59,20 +77,26 @@ struct NativeTensorDescriptor
|
||||
return offset_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto AreUpperIndex2OffsetTransformLinear();
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
|
||||
{
|
||||
// TODO: re-implement "Sequence", so that it can take other data-type (including bool) as
|
||||
// element
|
||||
return uniform_sequence_gen<nDim, 1>{};
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetIndependentDimensionGroups()
|
||||
__host__ __device__ static constexpr auto GetLinearDimensions()
|
||||
{
|
||||
// not implemented, should return Tuple<Sequence<0>, Sequence<1>, ...>
|
||||
return xxx;
|
||||
return typename arithmetic_sequence_gen<0, nDim, 1>::type{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNonLinearDimensions() { return Sequence<>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
|
||||
{
|
||||
return Tuple<>{};
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
// LowerTensorDescriptor
|
||||
// Transforms: std::tuple<DimensionTransforms...>
|
||||
// LowerDimensionIds: std::tuple<Sequence<...>>
|
||||
@@ -213,16 +237,45 @@ struct TransformedTensorDescriptor
|
||||
return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto AreUpperIndex2OffsetTransformLinear();
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>);
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetIndependentDimensionGroups()
|
||||
__host__ __device__ static constexpr auto GetLinearDimensions()
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNonLinearDimensions()
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <index_t... Lengths, index_t... Strides>
|
||||
__host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths...>,
|
||||
Sequence<Strides...>)
|
||||
{
|
||||
return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
|
||||
}
|
||||
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto make_NativeTensorDescriptor_packed(Lengths)
|
||||
{
|
||||
constexpr index_t strides = reverse_inclusive_scan_sequence(
|
||||
Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
|
||||
.PushBack(Number<1>{});
|
||||
|
||||
return make_NativeTensorDescriptor(Lengths{}, strides);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
|
||||
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class... NativeDimensions>
|
||||
__host__ __device__ void print_tensor_descriptor(const char* s,
|
||||
NativeTensorDescriptor<NativeDimensions...> desc)
|
||||
{
|
||||
print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides());
|
||||
}
|
||||
|
||||
template <index_t... Lengths, index_t... Strides>
|
||||
__host__ __device__ void
|
||||
print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strides...>)
|
||||
{
|
||||
constexpr index_t nDim = sizeof...(Lengths);
|
||||
|
||||
static_assert(nDim > 0 && nDim <= 12, "wrong!");
|
||||
|
||||
static_if<nDim == 1>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u}, strides {%u}\n", s, nDim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 2>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, nDim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 3>{}([&](auto) {
|
||||
printf(
|
||||
"%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, nDim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 4>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 5>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 6>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 7>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 8>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 9>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
|
||||
"%u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 10>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
|
||||
"%u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 11>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u "
|
||||
"%u %u "
|
||||
"%u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 12>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u "
|
||||
"%u %u %u %u "
|
||||
"%u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -85,6 +85,7 @@ struct TensorVisit
|
||||
{
|
||||
constexpr auto nonlinear_independent_dimensions_igroup =
|
||||
nonlinear_independent_dimension_groups.Get(igroup);
|
||||
|
||||
constexpr auto nonlinear_independent_lengths_igroup =
|
||||
lambda_HackLengths{}(lengths, nonlinear_independent_dimensions_igroup);
|
||||
|
||||
|
||||
@@ -82,9 +82,11 @@ struct Array
|
||||
// A: Array
|
||||
// Picks: Sequence<...>
|
||||
template <class Arr, class Picks>
|
||||
ArrayElementPicker
|
||||
struct ArrayElementPicker
|
||||
{
|
||||
__host__ __device__ constexpr ArrayElementPicker(Arr & array) : mData{array}
|
||||
using data_type = typename Arr::data_type;
|
||||
|
||||
__host__ __device__ constexpr ArrayElementPicker(Arr& array) : mData{array}
|
||||
{
|
||||
constexpr index_t imax =
|
||||
accumulate_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
|
||||
@@ -95,26 +97,26 @@ ArrayElementPicker
|
||||
__host__ __device__ static constexpr index_t GetSize() { return Picks::GetSize(); }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr TData operator[](Number<I>) const
|
||||
__host__ __device__ constexpr data_type operator[](Number<I>) const
|
||||
{
|
||||
constexpr auto IP = Picks::Get(Number<I>{});
|
||||
return mData[IP];
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr TData operator[](index_t i) const
|
||||
__host__ __device__ constexpr data_type operator[](index_t i) const
|
||||
{
|
||||
constexpr index_t ip = Picks{}[i];
|
||||
return mData[ip];
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ TData& operator()(Number<I>)
|
||||
__host__ __device__ data_type& operator()(Number<I>)
|
||||
{
|
||||
constexpr auto IP = Picks::Get(Number<I>{});
|
||||
return mData[IP];
|
||||
}
|
||||
|
||||
__host__ __device__ TData& operator()(index_t i)
|
||||
__host__ __device__ data_type& operator()(index_t i)
|
||||
{
|
||||
constexpr index_t ip = Picks{}[i];
|
||||
return mData[ip];
|
||||
|
||||
@@ -2,66 +2,99 @@
|
||||
#define CK_TUPLE_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "Sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class... Ts>
|
||||
struct tuple : public std::tuple<Ts...>
|
||||
{
|
||||
using type = tuple;
|
||||
namespace detail {
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSize() { return std::tuple_size(tuple{}); }
|
||||
template <index_t>
|
||||
struct TupleElementKey
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Key, typename Data>
|
||||
struct TupleElement
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ explicit constexpr TupleElement(T&& v) : mData(static_cast<T&&>(v))
|
||||
{
|
||||
}
|
||||
|
||||
Data mData;
|
||||
};
|
||||
|
||||
template <typename Key, typename Data>
|
||||
__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x)
|
||||
{
|
||||
return x.mData;
|
||||
}
|
||||
|
||||
template <typename Key, typename Data>
|
||||
__host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x)
|
||||
{
|
||||
return x.mData;
|
||||
}
|
||||
|
||||
template <typename Key, typename Data>
|
||||
__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
|
||||
{
|
||||
return static_cast<Data&&>(x.mData);
|
||||
}
|
||||
|
||||
template <typename Indices, typename... Xs>
|
||||
struct TupleImpl;
|
||||
|
||||
template <index_t... Is, typename... Xs>
|
||||
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
|
||||
{
|
||||
template <typename... Ys>
|
||||
__host__ __device__ explicit constexpr TupleImpl(Ys&&... ys)
|
||||
: TupleElement<TupleElementKey<Is>, Xs>(static_cast<Ys&&>(ys))...
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto Get(Number<I>) const
|
||||
__host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey<I>) const
|
||||
{
|
||||
return std::get<I>(*this);
|
||||
return get_tuple_element<TupleElementKey<I>>(*this);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto operator[](Number<I>) const
|
||||
__host__ __device__ constexpr auto& GetElementByKey(TupleElementKey<I>)
|
||||
{
|
||||
return Get(Number<I>{}) :
|
||||
return get_tuple_element<TupleElementKey<I>>(*this);
|
||||
}
|
||||
};
|
||||
|
||||
// merge tuple
|
||||
template <class... Tuples>
|
||||
__host__ __device__ constexpr auto merge_tuple(Tuples&&... xs)
|
||||
{
|
||||
return std::tuple_cat(xs...);
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
// generate sequence
|
||||
template <index_t IBegin, index_t NRemain, class F>
|
||||
struct tuple_gen_impl
|
||||
template <typename... Xs>
|
||||
struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>
|
||||
{
|
||||
static constexpr index_t NRemainLeft = NRemain / 2;
|
||||
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
|
||||
static constexpr index_t IMiddle = IBegin + NRemainLeft;
|
||||
using base =
|
||||
detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>;
|
||||
|
||||
using type =
|
||||
typename tuple_merge<typename tuple_gen_impl<IBegin, NRemainLeft, F>::type,
|
||||
typename tuple_gen_impl<IMiddle, NRemainRight, F>::type>::type;
|
||||
};
|
||||
template <typename... Ys>
|
||||
__host__ __device__ explicit constexpr Tuple(Ys&&... ys) : base(static_cast<Ys&&>(ys)...)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t I, class F>
|
||||
struct tuple_gen_impl<I, 1, F>
|
||||
{
|
||||
static constexpr auto x = F{}(Number<I>{});
|
||||
using type = tuple<Is>;
|
||||
};
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const auto& At(Number<I>) const
|
||||
{
|
||||
static_assert(I < base::Size(), "wrong! out of range");
|
||||
return GetElementByKey(detail::TupleElementKey<I>{});
|
||||
}
|
||||
|
||||
template <index_t I, class F>
|
||||
struct sequence_gen_impl<I, 0, F>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
template <index_t NSize, class F>
|
||||
struct sequence_gen
|
||||
{
|
||||
using type = typename sequence_gen_impl<0, NSize, F>::type;
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& At(Number<I>)
|
||||
{
|
||||
static_assert(I < base::Size(), "wrong! out of range");
|
||||
return GetElementByKey(detail::TupleElementKey<I>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -65,9 +65,6 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
|
||||
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
|
||||
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
|
||||
|
||||
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
|
||||
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
|
||||
|
||||
auto f = [&](auto n, auto k, auto ho, auto wo) {
|
||||
double v = 0;
|
||||
for(int c = 0; c < wei_kcyx.mDesc.GetLengths()[1]; ++c)
|
||||
@@ -125,9 +122,6 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
|
||||
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
|
||||
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
|
||||
|
||||
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
|
||||
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
|
||||
|
||||
std::size_t HiPerTile = HoPerTile + Y - 1;
|
||||
std::size_t WiPerTile = WoPerTile + X - 1;
|
||||
|
||||
|
||||
@@ -368,7 +368,7 @@ int main(int argc, char* argv[])
|
||||
#if 0
|
||||
device_convolution_direct_v2_nchw_kcyx_nkhw
|
||||
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(
|
||||
in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
#elif 1
|
||||
|
||||
Reference in New Issue
Block a user