Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-11 17:00:18 +00:00
reimplement threadwise copy
@@ -157,7 +157,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
         auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
             BlockSize,
-            Float,
             decltype(in_c_h_w_n_global_desc),
             decltype(in_c_h_w_n_block_desc),
             NormalTensorCoordinate<decltype(in_c_h_w_n_global_desc)>,
@@ -176,7 +176,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
 #else
         auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
             BlockSize,
-            Float,
             decltype(in_e_n1_b_n2_global_merged_desc),
             decltype(in_e_n1_b_n2_block_desc),
             MergedTensorCoordinate<decltype(in_e_n1_b_n2_global_merged_desc)>,
@@ -219,7 +218,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
 #else
         auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
             BlockSize,
-            Float,
             decltype(wei_e_k_global_desc),
             decltype(wei_e_k_block_desc),
             NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
@@ -373,7 +371,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
             Number<1>{});
 #else
         ThreadwiseGenericTensorSliceCopy_v2<
-            Float,
             decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
             decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
             NormalTensorCoordinate<decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc)>,
@@ -131,7 +131,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
         // this copy operator already has blockwise offset built-in
         auto blockwise_in_copy =
             BlockwiseGenericTensorSliceCopy_v2<BlockSize,
-                                               Float,
                                                decltype(in_e_b_global_desc),
                                                decltype(in_e_b_block_desc),
                                                MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
@@ -158,7 +157,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
         // this copy operator already have blockwise offset built-in
         auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
             BlockSize,
-            Float,
             decltype(wei_e_k_global_desc),
             decltype(wei_e_k_block_desc),
             NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
@@ -288,7 +286,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
             Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;

         auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2<
-            Float,
             decltype(out_k0_k1_b_thread_desc),
             decltype(out_k0_k1_b_global_desc),
             NormalTensorCoordinate<decltype(out_k0_k1_b_thread_desc)>,
@@ -131,7 +131,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         // this copy operator already has blockwise offset built-in
         auto blockwise_in_copy =
             BlockwiseGenericTensorSliceCopy_v2<BlockSize,
-                                               Float,
                                                decltype(in_e_b_global_desc),
                                                decltype(in_e_b_block_desc),
                                                MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
@@ -158,7 +157,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         // this copy operator already have blockwise offset built-in
         auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
             BlockSize,
-            Float,
             decltype(wei_e_k_global_desc),
             decltype(wei_e_k_block_desc),
             NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
@@ -352,7 +350,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
             Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;

         auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2<
-            Float,
             decltype(out_k0_k1_b_thread_desc),
             decltype(out_k0_k1_b_global_desc),
             NormalTensorCoordinate<decltype(out_k0_k1_b_thread_desc)>,
@@ -65,11 +65,21 @@ struct ConstantMergedTensorDescriptor
         static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                       "wrong! stride of a merged dimension is undefined");

-        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();
+        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();

         return OriginalTensorDesc::GetStride(Number<idim_original>{});
     }

+    // this is a hack to return the stride of the last original dimension of a merged dimension
+    // TODO: refactor this once the concept of "dimension" is used
+    template <index_t IDim>
+    __host__ __device__ static constexpr auto GetLastOriginalDimensionStride(Number<IDim>)
+    {
+        constexpr auto idim_last_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
+
+        return OriginalTensorDesc::GetStride(Number<idim_last_original>{});
+    }
+
     __host__ __device__ static constexpr auto GetLengths()
     {
         return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
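A note on the hunk above: a merged dimension has no single well-defined stride, which is why GetStride keeps its static_assert and the stride of the last original dimension is exposed through a separate helper instead. A minimal standalone sketch of the idea, with hypothetical packed-NCHW constants that do not come from this commit:

    #include <cstddef>

    // For a packed tensor with lengths N=2, C=3, H=4, W=5 the strides are
    // {60, 20, 5, 1}. Merging (H, W) into one dimension of length 20 leaves
    // that dimension without a uniform stride, but its LAST original
    // dimension (W) still has stride 1, which is the value the vectorized
    // access checks later in this commit care about.
    constexpr std::size_t stride_h = 5;
    constexpr std::size_t stride_w = 1; // what GetLastOriginalDimensionStride would report

    static_assert(stride_w == 1, "unit stride allows vectorized access along merge(H, W)");

    int main() { return 0; }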
@@ -13,11 +13,13 @@
 namespace ck {

-// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
+// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
 // memory layout (ordering of dimensions) can be different between src and dst.
-// on a merged dimension that constains multiple original dimensions,
-// its sub-length need to evenly divide the length of the last original dimension
-// so each thread is effectively reading a normal (not merged) tensor
+// This functions assume each thread is reading and writing a normal (not merged) tensor,
+// to simplify index calculations. To satisfy this assumption, the user need to make sure
+// that, on a merged dimension that constains multiple original dimensions, the length of
+// the last original dimension need to be evenly dividable by its sub-lengths. Also, the
+// repeat-length on the merged dimension need to be 1.
 template <index_t BlockSize,
           class Float,
           class SrcDesc,
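The assumption spelled out in the rewritten comment can be illustrated numerically; a minimal sketch with hypothetical lengths, not taken from the repository:

    #include <cstddef>

    // Merging original dimensions (Y, X) with lengths 3 and 8 yields one
    // dimension of length 24. A per-thread sub-length of 4 divides the last
    // original length 8, so a sub-slice never straddles a carry from X into
    // Y, and each thread still reads a normal (not merged) tensor.
    constexpr std::size_t last_original_length = 8; // hypothetical X
    constexpr std::size_t sub_length           = 4; // per-thread slice length

    static_assert(last_original_length % sub_length == 0,
                  "sub-length must evenly divide the last original dimension");

    int main() { return 0; }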
@@ -88,30 +90,55 @@ struct BlockwiseGenericTensorSliceCopy_v1
         constexpr auto data_per_cluster_per_dims = SubLengths{} * ThreadClusterLengths{};

         static_for<0, nDim, 1>{}([&](auto IDim) {
             static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0,
                           "wrong! cannot evenly divide sliced tensor into sub-tensor");

             static_assert(SliceLengths::Get(IDim) % data_per_cluster_per_dims.Get(IDim) == 0,
                           "wrong! cannot evenly divide sliced tensor into cluster");
         });

-        // on a merged dimension that constains multiple original dimensions,
-        // its sub-length need to evenly divide the length of the last original dimension,
-        // so each thread is effectively reading a normal (not merged) tensor
-        static_for<0, nDim, 1>{}([&](auto IDim) {
-            constexpr auto sub_length = SubLengths::Get(IDim);
         constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims;

-            constexpr auto idim_original_src = SrcDesc::GetContainedOriginalDimensions(IDim).Back();
-            static_assert(SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_original_src) %
-                                  sub_length ==
-                              0,
-                          "wrong!");
-
-            constexpr auto idim_original_dst = DstDesc::GetContainedOriginalDimensions(IDim).Back();
-            static_assert(DstDesc::GetOriginalTensorDescriptor().GetLength(idim_original_dst) %
-                                  sub_length ==
-                              0,
-                          "wrong!");
-        });
+        // additional check for merged dimension
+        static_for<0, nDim, 1>{}([&](auto IDim_) {
+            // src
+            static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
+                constexpr auto IDim = decltype(IDim_){};
+
+                // on a merged dimension that constains multiple original dimensions,
+                // the length of the last original dimension need to evenly dividable by its
+                // sub-length,
+                // so each thread is effectively reading a normal (not merged) tensor
+                constexpr auto idim_last_original_src =
+                    SrcDesc::GetContainedOriginalDimensions(IDim).Back();
+                static_assert(
+                    SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_src) %
+                            SubLengths::Get(IDim) ==
+                        0,
+                    "wrong!");
+
+                // merged dimension should have repeat_lengths = 1
+                static_assert(repeat_lengths[IDim] == 1,
+                              "wrong! repeat_lengths shoud be 1 on merged dimension");
+            });
+
+            // dst
+            static_if<DstDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
+                constexpr auto IDim = decltype(IDim_){};
+
+                // on a merged dimension that constains multiple original dimensions,
+                // the length of the last original dimension need to evenly dividable by its
+                // sub-length,
+                // so each thread is effectively reading a normal (not merged) tensor
+                constexpr auto idim_last_original_dst =
+                    DstDesc::GetContainedOriginalDimensions(IDim).Back();
+                static_assert(
+                    DstDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_dst) %
+                            SubLengths::Get(IDim) ==
+                        0,
+                    "wrong!");
+
+                // merged dimension should have repeat_lengths = 1
+                static_assert(repeat_lengths[IDim] == 1,
+                              "wrong! repeat_lengths shoud be 1 on merged dimension");
+            });
+        });

         // calculate mThreadSrcOffset, mThreadDstOffset
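The divisibility asserts above encode simple per-dimension arithmetic; a numeric sketch with hypothetical values:

    // Each dimension of the slice is covered by sub-tensors of SubLengths,
    // spread across ThreadClusterLengths threads and repeated repeat_length
    // times. All constants here are illustrative.
    constexpr int slice_length          = 16; // SliceLengths[d]
    constexpr int sub_length            = 2;  // SubLengths[d]
    constexpr int thread_cluster_length = 4;  // ThreadClusterLengths[d]
    constexpr int data_per_cluster      = sub_length * thread_cluster_length; // 8
    constexpr int repeat_length         = slice_length / data_per_cluster;    // 2

    static_assert(slice_length % data_per_cluster == 0, "must divide evenly");

    int main() { return 0; }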
@@ -376,7 +403,6 @@ struct BlockwiseGenericTensorSliceCopy_v1
 };

 template <index_t BlockSize,
-          class TData,
           class SrcDesc,
           class DstDesc,
           class SrcCoordinate,
@@ -428,16 +454,19 @@ struct BlockwiseGenericTensorSliceCopy_v2
         return RegisterBufferDesc::GetElementSpace();
     }

+    template <class TData>
     __device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const
     {
         mThreadwiseLoad.Run(p_src, p_buffer);
     }

+    template <class TData>
     __device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const
     {
         mThreadwiseStore.Run(p_buffer, p_dst);
     }

+    template <class TData>
     __device__ void Run(const TData* p_src, TData* p_dst) const
     {
         TData p_buffer[GetRegisterBufferSize()];
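This hunk shows the core of the reimplementation: TData moves from the class template (removed in the previous hunk) onto each method, which is why the Float argument disappears from every call site earlier in the commit. A minimal standalone sketch of the pattern (CopySketch is illustrative, not a type from this repository):

    // The copy object is configured only by descriptors/coordinates; the
    // element type is deduced per call, so one object can move float or
    // half data.
    struct CopySketch
    {
        template <class TData>
        void Run(const TData* src, TData* dst) const
        {
            dst[0] = src[0]; // stand-in for the real slice copy
        }
    };

    int main()
    {
        float a = 1.0f, b = 0.0f;
        CopySketch{}.Run(&a, &b); // TData deduced as float
        return int(b != 1.0f);
    }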
@@ -466,16 +495,14 @@ struct BlockwiseGenericTensorSliceCopy_v2
     using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));

     using ThreadwiseLoad =
-        ThreadwiseGenericTensorSliceCopy_v2<TData,
-                                            SrcDesc,
+        ThreadwiseGenericTensorSliceCopy_v2<SrcDesc,
                                             RegisterBufferDesc,
                                             SrcCoordinate,
                                             NormalTensorCoordinate<RegisterBufferDesc>,
                                             SubLengths>;

     using ThreadwiseStore =
-        ThreadwiseGenericTensorSliceCopy_v2<TData,
-                                            RegisterBufferDesc,
+        ThreadwiseGenericTensorSliceCopy_v2<RegisterBufferDesc,
                                             DstDesc,
                                             NormalTensorCoordinate<RegisterBufferDesc>,
                                             DstCoordinate,
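For context, Run in the previous hunk composes the two members defined here: a load stage into a small register-resident buffer followed by a store stage out of it. A minimal sketch with plain arrays and a hypothetical fixed buffer size:

    #include <cstring>

    constexpr int kBufferSize = 4; // stand-in for RegisterBufferDesc::GetElementSpace()

    template <class TData>
    void run_copy(const TData* p_src, TData* p_dst)
    {
        TData p_buffer[kBufferSize];                    // register buffer
        std::memcpy(p_buffer, p_src, sizeof(p_buffer)); // RunLoadRegisterBuffer
        std::memcpy(p_dst, p_buffer, sizeof(p_buffer)); // RunStoreRegisterBuffer
    }

    int main()
    {
        float src[kBufferSize] = {1, 2, 3, 4}, dst[kBufferSize] = {};
        run_copy(src, dst);
        return dst[3] == 4.0f ? 0 : 1;
    }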
@@ -106,8 +106,107 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
 #endif
 }

-template <class TData,
-          class SrcDesc,
+#if 0
+template <class SrcDesc,
+          class DstDesc,
+          class SliceLengths,
+          class SrcDimAccessOrder,
+          class DstDimAccessOrder,
+          index_t SrcVectorAccessDim,
+          index_t DstVectorAccessDim,
+          index_t SrcDataPerAccess,
+          index_t DstDataPerAccess>
+struct ThreadwiseGenericTensorSliceCopy_v1
+{
+    static constexpr index_t nDim = SliceLengths::GetNumOfDimension();
+
+    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_slice_origin,
+                                                             Array<index_t, nDim> dst_slice_origin)
+        : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
+    {
+        static_assert(nDim == SrcDesc::GetNumOfDimension() &&
+                          nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
+                          nDim == SrcDimAccessOrder::GetSize() &&
+                          nDim == DstDimAccessOrder::GetSize(),
+                      "wrong! # of dimensions not the same");
+
+        static_assert(is_valid_sequence_map<SrcDimAccessOrder>::{} &&
+                          is_valid_sequence_map<DstDimAccessOrder>::{},
+                      "wrong! map is not valid");
+
+        static_assert(SliceLengths{}[SrcVectorDim] % SrcDataPerAccess == 0 &&
+                          SliceLengths{DstVectorDim} % DstDataPerAccess == 0,
+                      "wrong! cannot evenly divide");
+
+        // check vectorized memory access
+        constexpr auto src_vector_access_dim = Number<SrcVectorAccessDIm>{};
+        constexpr auto dst_vector_access_dim = Number<DstVectorAccessDIm>{};
+
+        static_if<!SrcDesc::ContainMultipleOriginalDimensions(
+            src_vector_access_dim)>{}([&](auto fwd) {
+            static_assert(
+                (fwd(SrcDesc{}).GetStrides()[SrcVectorAccessDim] == 1 || SrcDataPerAccess == 1),
+                "wrong! vectorized access is allowed only if stride == 1");
+        }).Else{}([&](auto fwd) {
+            static_assert((SrcDesc::GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
+                           SrcDataPerAccess == 1),
+                          "wrong! vectorized access is allowed only if stride == 1");
+        });
+
+        static_if<!DstDesc::ContainMultipleOriginalDimensions(
+            dst_vector_access_dim)>{}([&](auto fwd) {
+            static_assert(
+                (fwd(DstDesc{}).GetStrides()[DstVectorAccessDim] == 1 || DstDataPerAccess == 1),
+                "wrong! vectorized access is allowed only if stride == 1");
+        }).Else{}([&](auto fwd) {
+            static_assert((DstDesc::GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
+                           DstDataPerAccess == 1),
+                          "wrong! vectorized access is allowed only if stride == 1");
+        });
+    }
+
+    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1()
+        : ThreadwiseGenericTensorSliceCopy_v1(make_zero_array<index_t, nDim>(),
+                                              make_zero_array<index_t, nDim>())
+    {
+    }
+
+    __device__ void SetSrcSliceOrigin(Array<index_t, nDim> src_slice_origin)
+    {
+        mSrcSliceOrigin = src_slice_origin;
+    }
+
+    __device__ void SetDstSliceOrigin(Array<index_t, nDim> dst_slice_origin)
+    {
+        mDstSliceOrigin = dst_slice_origin;
+    }
+
+    template <class TData>
+    __device__ void Run(const TData* p_src, TData* p_dst) const
+    {
+        constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
+
+        TData p_buffer[buffer_desc.GetElementSpace()];
+
+        // copy data from src into buffer
+        constexpr auto src_vector_access_dim = Number<SrcVectorAccessDIm>{};
+
+        constexpr auto src_access_lengths = SliceLengths::Modify(
+            src_vector_access_dim, SliceLengths::Get(src_vector_access_dim) / SrcDataPerAccess);
+
+        constexpr auto src_access_lengths_in_src_access_order =
+            src_access_lengths.ReorderGivenNew2Old(SrcDimAccessOrder{});
+
+        static_ford<decltype(src_access_lengths_in_src_access_order)>{}([&](auto src_access_id) {});
+    }
+
+    private:
+    Array<index_t, TData> mSrcSliceOrigin;
+    Array<index_t, TData> mDstSliceOrigin;
+};
+#endif
+
+template <class SrcDesc,
           class DstDesc,
           class SrcCoordinate,
           class DstCoordinate,
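The disabled v1 draft above asserts that vectorized access requires stride == 1; the reason is that a width-w vector access touches w adjacent scalars. A minimal sketch (float4 here is a plain illustrative struct, not the HIP builtin):

    #include <cstring>

    struct float4 { float x, y, z, w; };

    // A 4-wide vectorized load reads p[0..3], four elements adjacent in
    // memory; the dimension being vectorized must therefore have stride 1,
    // since with any larger stride the elements would not be contiguous.
    void vector_load(const float* p, float4& v) { std::memcpy(&v, p, sizeof(v)); }

    int main()
    {
        float data[4] = {0, 1, 2, 3};
        float4 v{};
        vector_load(data, v);
        return v.w == 3.0f ? 0 : 1;
    }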
@@ -116,18 +215,18 @@ struct ThreadwiseGenericTensorSliceCopy_v2
 {
     static constexpr index_t nDim = SrcDesc::GetNumOfDimension();

-    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v2()
-        : mSrcSliceOrigin(make_zero_array<index_t, nDim>()),
-          mDstSliceOrigin(make_zero_array<index_t, nDim>())
-    {
-    }
-
     __device__ constexpr ThreadwiseGenericTensorSliceCopy_v2(SrcCoordinate src_slice_origin,
                                                              DstCoordinate dst_slice_origin)
         : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
     {
     }

+    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v2()
+        : ThreadwiseGenericTensorSliceCopy_v2(make_zero_array<index_t, nDim>(),
+                                              make_zero_array<index_t, nDim>())
+    {
+    }
+
     __device__ void SetSrcSliceOrigin(SrcCoordinate src_slice_origin)
     {
         mSrcSliceOrigin = src_slice_origin;
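A side note on the constructor shuffle above: the new default constructor delegates to the two-argument one instead of duplicating the member initializers. The C++11 idiom in isolation (Copier is illustrative, not a type from this commit):

    #include <array>

    struct Copier
    {
        Copier(std::array<int, 2> src, std::array<int, 2> dst) : mSrc(src), mDst(dst) {}
        Copier() : Copier({0, 0}, {0, 0}) {} // delegating constructor

        std::array<int, 2> mSrc;
        std::array<int, 2> mDst;
    };

    int main() { return Copier{}.mSrc[0]; }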
@@ -148,6 +247,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2
     }
     };

+    template <class TData>
     __device__ void Run(const TData* p_src, TData* p_dst) const
     {
         constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
@@ -216,6 +316,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2
         });
     }

+    // T can be Sequence or Array
     template <class T, bool PositiveDirection>
     __device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
     {
@@ -232,7 +333,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2
         }).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
     }

-    // private:
+    private:
     SrcCoordinate mSrcSliceOrigin;
     DstCoordinate mDstSliceOrigin;
 };
@@ -6,9 +6,12 @@
 namespace ck {

-template <class Seq>
+template <class>
 struct is_valid_sequence_map;

+template <class>
+struct sequence_map_inverse;
+
 template <index_t... Is>
 struct Sequence
 {
@@ -34,6 +37,8 @@ struct Sequence
         return Number<GetImpl(Number<I>{})>{};
     }

+    __host__ __device__ static constexpr auto Get(index_t I) { return GetImpl(I); }
+
     template <index_t I>
     __host__ __device__ constexpr auto operator[](Number<I>) const
     {
@@ -54,6 +59,18 @@ struct Sequence
         return Sequence<Type::Get(Number<IRs>{})...>{};
     }

+    // MapOld2New is Sequence<...>
+    template <class MapOld2New>
+    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
+    {
+        static_assert(MapOld2New::GetSize() == GetSize(),
+                      "wrong! reorder map should have the same size as Sequence to be rerodered");
+
+        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
+
+        return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
+    }
+
     __host__ __device__ static constexpr auto Reverse();

     __host__ __device__ static constexpr auto Front()
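The semantics of the new ReorderGivenOld2New can be pinned down with a small host-side sketch (plain std::array stand-ins for ck::Sequence; compile as C++20 for the constexpr array comparison):

    #include <array>

    // With MapOld2New = {2, 0, 1}, old element 0 moves to new position 2,
    // so {10, 20, 30} reorders to {20, 30, 10}. The implementation above
    // gets there by inverting the map and delegating to ReorderGivenNew2Old.
    constexpr std::array<int, 3> reorder_old2new(std::array<int, 3> s,
                                                 std::array<int, 3> map_old2new)
    {
        std::array<int, 3> r{};
        for (int old_pos = 0; old_pos < 3; ++old_pos)
            r[map_old2new[old_pos]] = s[old_pos];
        return r;
    }

    static_assert(reorder_old2new({10, 20, 30}, {2, 0, 1}) ==
                  std::array<int, 3>{20, 30, 10});

    int main() { return 0; }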
@@ -253,6 +270,7 @@ struct sequence_reverse<Sequence<I0, I1>>
 template <class Seq>
 struct is_valid_sequence_map
 {
+    // not implemented yet, always return true
     static constexpr integral_constant<bool, true> value = integral_constant<bool, true>{};

     // TODO: add proper check for is_valid, something like:
@@ -261,6 +279,33 @@ struct is_valid_sequence_map
     //     typename sequence_sort<Seq>::SortedSeqType>{};
 };

+template <class X2Y, class WorkingY2X, index_t XBegin, index_t XRemain>
+struct sequence_map_inverse_impl
+{
+    private:
+    static constexpr auto new_y2x = WorkingY2X::Modify(X2Y{}[XBegin], XBegin);
+
+    public:
+    using type =
+        typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
+};
+
+template <class X2Y, class WorkingY2X, index_t XBegin>
+struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
+{
+    using type = WorkingY2X;
+};
+
+template <class X2Y>
+struct sequence_map_inverse
+{
+    using type =
+        typename sequence_map_inverse_impl<X2Y,
+                                           typename uniform_sequence_gen<X2Y::GetSize(), 0>::type,
+                                           0,
+                                           X2Y::GetSize()>::type;
+};
+
 template <index_t... Xs, index_t... Ys>
 __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
 {
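What sequence_map_inverse computes, shown with a constexpr-function sketch rather than template recursion (compile as C++20 for the constexpr array comparison):

    #include <array>
    #include <cstddef>

    // For a permutation x2y, build y2x such that y2x[x2y[x]] == x; the loop
    // body mirrors the WorkingY2X::Modify(X2Y{}[XBegin], XBegin) step above.
    constexpr std::array<std::size_t, 3> invert(std::array<std::size_t, 3> x2y)
    {
        std::array<std::size_t, 3> y2x{};
        for (std::size_t x = 0; x < x2y.size(); ++x)
            y2x[x2y[x]] = x;
        return y2x;
    }

    static_assert(invert({2, 0, 1}) == std::array<std::size_t, 3>{1, 2, 0});

    int main() { return 0; }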
@@ -132,7 +132,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

     constexpr auto gridwise_conv =
-#if 1
+#if 0
         GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
 #else
         GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
@@ -379,7 +379,7 @@ int main(int argc, char* argv[])
 #elif 0
     device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
         in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 0
+#elif 1
     device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
                                                          in_nchw,
                                                          wei_kcyx_desc,