From 5e5c27a63b1637556a17e17546147da6cb6d732e Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 16 May 2019 13:22:40 -0500 Subject: [PATCH] adding implicit gemm v3 --- .../ConstantMergedTensorDescriptor.hip.hpp | 14 +- src/include/ConstantTensorDescriptor.hip.hpp | 125 +++++++-- src/include/Sequence.hip.hpp | 249 +++++++++++++++++- src/include/blockwise_3d_tensor_op.hip.hpp | 2 +- src/include/blockwise_4d_tensor_op.hip.hpp | 2 +- src/include/blockwise_tensor_slice_op.hip.hpp | 20 +- src/include/common.hip.hpp | 12 +- src/include/functional.hip.hpp | 23 +- ...on_implicit_gemm_v3_nchw_cyxk_nkhw.hip.hpp | 7 +- 9 files changed, 380 insertions(+), 74 deletions(-) diff --git a/src/include/ConstantMergedTensorDescriptor.hip.hpp b/src/include/ConstantMergedTensorDescriptor.hip.hpp index ab73c8a49d..3c31da0c3c 100644 --- a/src/include/ConstantMergedTensorDescriptor.hip.hpp +++ b/src/include/ConstantMergedTensorDescriptor.hip.hpp @@ -30,11 +30,6 @@ struct ConstantMergedTensorDescriptor }); } - __host__ __device__ static constexpr index_t GetNumOfOriginalDimension() - { - return TensorDesc::GetNumOfDimension(); - } - __host__ __device__ static constexpr index_t GetNumOfDimension() { constexpr auto merged_dim_ranges = std::make_tuple(MergedDimRanges...); @@ -51,11 +46,16 @@ struct ConstantMergedTensorDescriptor }; constexpr index_t num_lost_dim = static_const_reduce_n{}( - f_calculate_num_of_lost_dim, mod_conv::plus{}); + f_calculate_num_of_lost_dim, std::plus{}); return TensorDesc::GetNumOfDimension() - num_lost_dim; } + __host__ __device__ static constexpr index_t GetNumOfOriginalDimension() + { + return TensorDesc::GetNumOfDimension(); + } + template __host__ __device__ static constexpr bool IsMergedDimension(Number) { @@ -71,7 +71,7 @@ struct ConstantMergedTensorDescriptor template __host__ __device__ static constexpr bool GetStride(Number) { - static_assert(!IsMergedDimension(Number{}, "wrong! A merged dimension does not have uniform stride") + static_assert(!IsMergedDimension(Number{}, "wrong! stride of a merged dimension is undefined") // not implemented } diff --git a/src/include/ConstantTensorDescriptor.hip.hpp b/src/include/ConstantTensorDescriptor.hip.hpp index 6dcd16e167..64c7f4408d 100644 --- a/src/include/ConstantTensorDescriptor.hip.hpp +++ b/src/include/ConstantTensorDescriptor.hip.hpp @@ -85,24 +85,35 @@ struct ConstantTensorDescriptor __host__ __device__ static constexpr index_t GetElementSize() { - return accumulate_on_sequence(Lengths{}, mod_conv::multiplies{}, Number<1>{}); + return accumulate_on_sequence(Lengths{}, std::multiplies{}, Number<1>{}); } +#if 0 // c++14 doesn't support constexpr lambdas, has to use this trick instead - struct GetElementSpace_f + struct f_GetElementSpace_impl { template __host__ __device__ constexpr index_t operator()(IDim idim) const { return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim); } + __host__ __device__ constexpr index_t operator()(index_t length, index_t stride) const + { + return (length - 1) * stride; + } }; +#endif template > __host__ __device__ static constexpr index_t GetElementSpace(Align align = Align{}) { +#if 0 index_t element_space_unaligned = - static_const_reduce_n{}(GetElementSpace_f{}, mod_conv::plus{}) + 1; + static_const_reduce_n{}(f_GetElementSpace_impl{}, std::plus{}) + 1; +#else + constexpr index_t element_space_unaligned = accumulate_on_sequence( + (GetLengths() - Number<1>{}) * GetStrides(), std::plus{}, Number<1>{}); +#endif return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get()); } @@ -140,9 +151,9 @@ struct ConstantTensorDescriptor constexpr auto multi_id = Sequence{}; constexpr auto seq_tmp = - transform_sequences(mod_conv::multiplies{}, multi_id, GetStrides()); + transform_sequences(std::multiplies{}, multi_id, GetStrides()); - return accumulate_on_sequence(seq_tmp, mod_conv::plus{}, Number<0>{}); + return accumulate_on_sequence(seq_tmp, std::plus{}, Number<0>{}); } __host__ __device__ static Array GetMultiIndex(index_t id) @@ -167,34 +178,112 @@ struct ConstantTensorDescriptor } template - __host__ __device__ static constexpr auto Extract(Number... /*extracted_dims...*/) + __host__ __device__ static constexpr auto Extract(Number... extract_dims) { - static_assert(sizeof...(IDims) <= GetNumOfDimension(), "wrong!"); + static_assert(sizeof...(IDims) <= GetNumOfDimension(), + "wrong! too many number of dimensions to be extracted"); - constexpr auto extracted_lengths = Sequence{})...>{}; - constexpr auto extracted_strides = Sequence{})...>{}; - - return make_ConstantTensorDescriptor(extracted_lenghts, extracted_strides); + return make_ConstantTensorDescriptor(Lengths{}.Extract(extract_dims), + Strides{}.Extract(extract_dims)); } template __host__ __device__ static constexpr auto Slice(Number, Number) { - // not implemented + return make_ConstantTensorDescriptor(Lengths{}.Modify(Number{}, Number{}), + Strides{}); } - template - __host__ device__ static constexpr auto Fold(Number, Sequence) + template + __host__ device__ static constexpr auto Fold(Number, Number...) { - // not implemented - // need to check the Length dimension to be folded is dividable by FoldLengths + constexpr auto fold_intervals = Sequence{}; + + constexpr fold_intervals_product = + accumulate_on_sequence(fold_intervals, std::multiplies{}, Number<1>{}); + + constexpr auto unfold_length = GetLength(Number{}); + constexpr auto unfold_stride = GetStride(Number{}); + + // length of the dimension to be folded needs to be dividable by fold_interval_product, + // otherwise, folding is invalid + static_assert(unfold_length % fold_interval_product == 0, + "wrong! length on the dimension to be folded cannot be evenly divided!"); + + // folded lengths + constexpr auto fold_lengths = + Sequence{}.Append(fold_intervals); + + // folded strides + constexpr auto fold_strides = transform_sequences(mod_conv::scales{}, + reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies{}); + + // left and right lengths + constexpr auto lengths_pair = GetLengths().Split(Number{}); + constexpr auto left_lengths = lengths_pair.first; + constexpr auto right_lengths = lengths_pair.second.PopFront(); + + // left and right strides + constexpr auto strides_pair = GetStrides().Split(Number{}); + constexpr auto left_strides = strides_pair.first; + constexpr auto right_strides = strides_pair.second.PopFront(); + + return make_ConstantTensorDescriptor(left_lengths.Append(fold_lengths).Append(right_lengths), + left_strides.Append(fold_strides).Append(right_strides)); } template __host__ __device__ static constexpr auto Unfold(Number, Number) { - // not implemented - // need to check the dimensions to be unfold are packed, otherwise, Unfold is not permitted + static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim && + FirstUnfoldDim <= LastUnfoldDim, + "wrong! should have FirstUnfoldDim <= LastUnfoldDim!"); + + // dimensions to be unfold need to be in descending order (w.r.t. strides), and need to be + // packed in memory, otherwise, unfolding is invalid + static_for{}([&](auto IDim) { + static_assert( + GetStride(IDim) >= GetStride(Number{}), + "wrong! dimensions to be unfolded need to be in descending order w.r.t strides"); + + static_assert(GetStride(IDim + 1) * GetLength(IDim + 1) == GetStride(IDim), + "wrong! dimensions to be unfolded need to be packed"); + }); + + // lengths + constexpr auto lens_pair1 = Lengths{}.Split(Number{}); + + constexpr auto right_lengths = lens_pair1.second; + + constexpr auto lens_pair2 = lens_pair1.first.Split(Number{}); + + constexpr auto left_lengths = lens_pair2.first; + + constexpr auto fold_lengths = lens_pair2.second; + + constexpr index_t unfold_length = + accumulate_on_sequence(fold_lengths, std::multiplies{}, Number<1>{}); + + constexpr auto new_strides = + left_strides.PopBack(Number{}).Append(right_strides); + + // strides + constexpr auto strides_pair1 = Strides{}.Split(Number{}); + + constexpr auto right_strides = strides_pair1.second; + + constexpr auto strides_pair2 = strides_pair1.first.Split(Number{}); + + constexpr auto left_strides = strides_pair2.first; + + constexpr auto fold_strides = strides_pair2.second; + + constexpr index_t unfold_stride = fold_strides.Back(); + + constexpr auto new_strides = + left_strides.PushBack(Number{}).Append(right_strides); + + return make_ConstantTensorDescriptor(new_lengths, new_strides); } template diff --git a/src/include/Sequence.hip.hpp b/src/include/Sequence.hip.hpp index 3826e4df97..6b87885780 100644 --- a/src/include/Sequence.hip.hpp +++ b/src/include/Sequence.hip.hpp @@ -2,6 +2,15 @@ #include "constant_integral.hip.hpp" #include "functional.hip.hpp" +struct EmptySequence +{ + template + __host__ __device__ constexpr Seq Append(Seq) const + { + return {}; + } +}; + template struct Sequence { @@ -39,6 +48,11 @@ struct Sequence assert(false); } + __host__ __device__ constexpr auto Reverse() const + { + // not implemented + } + __host__ __device__ constexpr index_t Front() const { return mData[0]; } __host__ __device__ constexpr index_t Back() const { return mData[mSize - 1]; } @@ -59,25 +73,192 @@ struct Sequence __host__ __device__ constexpr auto PopBack() const; - template - __host__ __device__ constexpr auto Insert(Number, Number) const + template + __host__ __device__ constexpr auto Append(Sequence) const { - index_t data[mSize + 1]; + return Sequence{}; + } - static_for<0, I, 1>{}([&](auto Iter) { - constexpr index_t iter = Iter.Get(); - data[iter] = mData[iter]; - }); + __host__ __device__ constexpr auto Append(EmptySequence) const { return Type{}; } - data[I] = X; + template + __host__ __device__ constexpr auto Extract(Number...) const + { + return Sequence)...>{}; + } - static_for{}([&](auto Iter) { - constexpr index_t iter = Iter.Get(); - data[iter + 1] = mData[iter]; + template + struct split_impl + { + template + __host__ __device__ constexpr auto operator()(FirstSeq, SecondSeq) const + { + constexpr new_first = FirstSeq{}.PushBack(Number{}); + constexpr new_second = SecondSeq{}.PopFront(); + + static_if<(N > 0)>{}([&](auto fwd) { + return split_impl{}(new_first, fwd(new_second)); + }).else_([&](auto fwd) { return std::make_pair(new_first, fwd(new_second)); }); + } + }; + + // split one sequence to two sequnces: [0, I) and [I, nSize) + // return type is std::pair + template + __host__ __device__ constexpr auto Split(Number) const + { + static_assert(I <= nSize, "wrong! split position is too high!"); + + static_if<(I == 0)>{}( + [&](auto fwd) { return std::make_pair(EmptySequence<>{}, fwd(Type{})); }); + + static_if<(I == nSize)>{}( + [&](auto fwd) { return std::make_pair(Type<>{}, fwd(EmptySequence<>{})); }); + + static_if<(I > 0 && I < nSize)>{}([&](auto fforwader) { + constexpr auto first = Sequence {} + constexpr auto second = Type{}.PopFront(); + + return split_impl{}(first, fwd(second)); }); } + + template + __host__ __device__ constexpr auto Modify(Number, Number) const + { + constexpr auto first_second = Split(Number{}); + + constexpr auto left = first_second.first; + constexpr auto right = first_second.second.PopFront(); + + return left.PushBack(Number{}).Append(right); + } }; +template +__host__ __device__ auto make_increasing_sequence(Number, Number, Number) +{ + static_assert(IBegin < IEnd, (IEnd - IBegin) % Increment == 0, "wrong!"); + + // not implemented +} + +template +__host__ __device__ auto make_uniform_sequence(Number, Number); +{ + // not implemented +} + +template +__host__ __device__ constexpr auto operator+(Sequence, Sequence) const +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs + Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator-(Sequence seq_x, Sequence seq_y) const +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + static_for<0, xs.GetSize(), 1>{}([&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I)); }); + + return Sequence<(Xs - Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator*(Sequence, Sequence)const +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs * Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator/(Sequence, Sequence) const +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs / Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator%(Sequence, Sequence) const +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs % Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator%(Sequence, Sequence) const +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs % Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator+(Sequence, Number) const +{ + return seq_x + make_uniform_sequence(Number, Number{}); +} + +template +__host__ __device__ constexpr auto operator-(Sequence, Number) const +{ + return seq_x - make_uniform_sequence(Number, Number{}); +} + +template +__host__ __device__ constexpr auto operator*(Sequence, Number)const +{ + return seq_x * make_uniform_sequence(Number, Number{}); +} + +template +__host__ __device__ constexpr auto operator/(Sequence, Number) const +{ + return seq_x / make_uniform_sequence(Number, Number{}); +} + +template +__host__ __device__ constexpr auto operator%(Sequence seq_x, Number y) const +{ + return seq_x % make_uniform_sequence(Number, Number{}); +} + +template +__host__ __device__ constexpr auto operator+(Number, Sequence) const +{ + return make_uniform_sequence(Number{}, Number{}) + Sequence{}; +} + +template +__host__ __device__ constexpr auto operator-(Number, Sequence) const +{ + return make_uniform_sequence(Number{}, Number{}) - Sequence{}; +} + +template +__host__ __device__ constexpr auto operator*(Number, Sequence)const +{ + return make_uniform_sequence(Number{}, Number{}) * Sequence{}; +} + +template +__host__ __device__ constexpr auto operator/(Number, Sequence) const +{ + return make_uniform_sequence(Number{}, Number{}) / Sequence{}; +} + +template +__host__ __device__ constexpr auto operator%(Number, Sequence) const +{ + return make_uniform_sequence(Number{}, Number{}) % Sequence{}; +} + template __host__ __device__ constexpr auto sequence_pop_front(Sequence) { @@ -177,6 +358,12 @@ __host__ __device__ constexpr auto #if 1 // TODO: fix these mess +template +__host__ __device__ constexpr auto transform_sequences(F f, Sequence) +{ + return Sequence{}; +} + template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) { @@ -248,7 +435,7 @@ __host__ __device__ constexpr auto Sequence::PopBack() const } template -struct accumulate_on_sequence_f +struct accumulate_on_sequence_impl { template __host__ __device__ constexpr index_t operator()(IDim) const @@ -262,6 +449,42 @@ __host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce, Number /*initial_value*/) { constexpr index_t a = - static_const_reduce_n{}(accumulate_on_sequence_f{}, Reduce{}); + static_const_reduce_n{}(accumulate_on_sequence_impl{}, Reduce{}); return Reduce{}(a, I); } + +template +struct scan_sequence_impl +{ + template + __host__ __device__ constexpr auto operator()(ScanedSeq, RemainSeq, Reduce) const + { + static_assert(RemainSeq{}.GetSize() == NRemain, + "wrong! RemainSeq and NRemain not consistent!"); + + constexpr index_t a = Reduce{}(ScanedSeq{}.Back(), RemainSeq{}.Front()); + constexpr auto scaned_seq = ScanedSeq{}.PushBack(Number{}); + + static_if<(NRemain > 1)>{}([&](auto fwd) { + return scan_sequence_impl{}( + scaned_seq, RemainSeq{}.PopFront(), fwd(Reduce{})); + }).else_([&](auto fwd) { return fwd(scaned_seq); }); + } +}; + +template +__host__ __device__ constexpr auto scan_sequence(Seq, Reduce) +{ + constexpr auto scaned_seq = Sequence{}; + constexpr auto remain_seq = Seq{}.PopFront(); + + constexpr index_t remain_size = Seq::GetSize() - 1; + + return scan_sequence_impl{}(scaned_seq, remain_seq, Reduce{}); +} + +template +__host__ __device__ constexpr auto reverse_scan_sequence(Seq, Reduce) +{ + return scan_seqeunce(Seq{}.Reverse(), Reduce{}).Reverse(); +} diff --git a/src/include/blockwise_3d_tensor_op.hip.hpp b/src/include/blockwise_3d_tensor_op.hip.hpp index f82dec7f7b..3e9ae6920c 100644 --- a/src/include/blockwise_3d_tensor_op.hip.hpp +++ b/src/include/blockwise_3d_tensor_op.hip.hpp @@ -156,7 +156,7 @@ struct Blockwise3dTensorCopy3 "wrrong! BlockSize is not big enough for ThreadPerDims!"); constexpr index_t num_active_thread = - accumulate_on_sequence(ThreadPerDims{}, mod_conv::multiplies{}, Number<1>{}); + accumulate_on_sequence(ThreadPerDims{}, std::multiplies{}, Number<1>{}); if(BlockSize > num_active_thread) { diff --git a/src/include/blockwise_4d_tensor_op.hip.hpp b/src/include/blockwise_4d_tensor_op.hip.hpp index 8235575a2f..17c05571a2 100644 --- a/src/include/blockwise_4d_tensor_op.hip.hpp +++ b/src/include/blockwise_4d_tensor_op.hip.hpp @@ -495,7 +495,7 @@ struct Blockwise4dTensorCopy3 "wrrong! BlockSize is not big enough for ThreadPerDims!"); constexpr index_t num_active_thread = - accumulate_on_sequence(ThreadPerDims{}, mod_conv::multiplies{}, Number<1>{}); + accumulate_on_sequence(ThreadPerDims{}, std::multiplies{}, Number<1>{}); if(BlockSize > num_active_thread) { diff --git a/src/include/blockwise_tensor_slice_op.hip.hpp b/src/include/blockwise_tensor_slice_op.hip.hpp index 5f7284dc2a..5a1b8cb9e8 100644 --- a/src/include/blockwise_tensor_slice_op.hip.hpp +++ b/src/include/blockwise_tensor_slice_op.hip.hpp @@ -133,7 +133,7 @@ struct BlockwiseTensorSliceReorderCopy_v3 constexpr auto thread_sub_tensor_lengths = SrcSubLengths{}; constexpr auto src_data_per_cluster_per_dims = transform_sequences( - mod_conv::multiplies{}, thread_sub_tensor_lengths, SrcClusterLengths{}); + std::multiplies{}, thread_sub_tensor_lengths, SrcClusterLengths{}); constexpr auto repeat_lengths = transform_sequences(mod_conv::integer_divide_ceiler{}, @@ -141,7 +141,7 @@ struct BlockwiseTensorSliceReorderCopy_v3 src_data_per_cluster_per_dims); constexpr auto thread_tensor_lengths = transform_sequences( - mod_conv::multiplies{}, thread_sub_tensor_lengths, repeat_lengths); + std::multiplies{}, thread_sub_tensor_lengths, repeat_lengths); constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths); @@ -154,7 +154,7 @@ struct BlockwiseTensorSliceReorderCopy_v3 constexpr auto thread_sub_tensor_lengths = SrcSubLengths{}; constexpr auto src_data_per_cluster_per_dims = transform_sequences( - mod_conv::multiplies{}, thread_sub_tensor_lengths, SrcClusterLengths{}); + std::multiplies{}, thread_sub_tensor_lengths, SrcClusterLengths{}); constexpr auto repeat_lengths = transform_sequences(mod_conv::integer_divide_ceiler{}, @@ -162,7 +162,7 @@ struct BlockwiseTensorSliceReorderCopy_v3 src_data_per_cluster_per_dims); constexpr auto thread_tensor_lengths = transform_sequences( - mod_conv::multiplies{}, thread_sub_tensor_lengths, repeat_lengths); + std::multiplies{}, thread_sub_tensor_lengths, repeat_lengths); constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths); @@ -170,10 +170,10 @@ struct BlockwiseTensorSliceReorderCopy_v3 constexpr auto repeat_multi_id = decltype(repeat_multi_id_){}; constexpr auto src_data_multi_id = transform_sequences( - mod_conv::multiplies{}, repeat_multi_id, src_data_per_cluster_per_dims); + std::multiplies{}, repeat_multi_id, src_data_per_cluster_per_dims); constexpr auto clipboard_data_multi_id = transform_sequences( - mod_conv::multiplies{}, repeat_multi_id, thread_sub_tensor_lengths); + std::multiplies{}, repeat_multi_id, thread_sub_tensor_lengths); constexpr index_t src_offset = SrcDesc{}.Get1dIndex(src_data_multi_id); constexpr index_t clipboard_offset = @@ -194,7 +194,7 @@ struct BlockwiseTensorSliceReorderCopy_v3 constexpr auto thread_sub_tensor_lengths = SrcSubLengths{}; constexpr auto src_data_per_cluster_per_dims = transform_sequences( - mod_conv::multiplies{}, thread_sub_tensor_lengths, SrcClusterLengths{}); + std::multiplies{}, thread_sub_tensor_lengths, SrcClusterLengths{}); constexpr auto repeat_lengths = transform_sequences(mod_conv::integer_divide_ceiler{}, @@ -202,7 +202,7 @@ struct BlockwiseTensorSliceReorderCopy_v3 src_data_per_cluster_per_dims); constexpr auto thread_tensor_lengths = transform_sequences( - mod_conv::multiplies{}, thread_sub_tensor_lengths, repeat_lengths); + std::multiplies{}, thread_sub_tensor_lengths, repeat_lengths); constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths); @@ -210,10 +210,10 @@ struct BlockwiseTensorSliceReorderCopy_v3 constexpr auto repeat_multi_id = decltype(repeat_multi_id_){}; constexpr auto clipboard_data_multi_id = transform_sequences( - mod_conv::multiplies{}, repeat_multi_id, thread_sub_tensor_lengths); + std::multiplies{}, repeat_multi_id, thread_sub_tensor_lengths); constexpr auto src_data_multi_id = transform_sequences( - mod_conv::multiplies{}, repeat_multi_id, src_data_per_cluster_per_dims); + std::multiplies{}, repeat_multi_id, src_data_per_cluster_per_dims); // reorder src_data_multi_id to get dst_data_multi_id constexpr auto dst_data_multi_id = src_data_multi_id.ReorderGivenNew2Old(MapDst2Src{}); diff --git a/src/include/common.hip.hpp b/src/include/common.hip.hpp index bf7249ca70..2c5ee402ae 100644 --- a/src/include/common.hip.hpp +++ b/src/include/common.hip.hpp @@ -27,16 +27,10 @@ struct is_same }; namespace mod_conv { // namespace mod_conv -template -struct multiplies +template +struct scales { - __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; } -}; - -template -struct plus -{ - __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; } + __host__ __device__ constexpr T operator()(T a) const { return s * a; } }; template diff --git a/src/include/functional.hip.hpp b/src/include/functional.hip.hpp index 177270367f..79b5b22cc3 100644 --- a/src/include/functional.hip.hpp +++ b/src/include/functional.hip.hpp @@ -10,6 +10,14 @@ struct forwarder } }; +#if 0 +template +__host__ __device__ constexpr auto unpacker(F f) +{ + return [=](auto xs_array){ f(xs...); }; +} +#endif + // Emulate compile time if statement for C++14 // Get the idea from // "https://baptiste-wicht.com/posts/2015/07/simulate-static_if-with-c11c14.html" @@ -87,7 +95,7 @@ struct static_for_impl } }; -// F signature: F(Number) +// F signature: F(Number) template struct static_for { @@ -97,9 +105,8 @@ struct static_for static_assert((NEnd - NBegin) % Increment == 0, "Wrong! should satisfy (NEnd - NBegin) % Increment == 0"); - static_if < NBegin{}([&](auto forwarder) { - static_for_impl{}(f); - }); + static_if<(NBegin < End)>{}( + [&](auto fwd) { static_for_impl{}(f); }); } }; @@ -127,11 +134,3 @@ struct static_const_reduce_n<1> return f(Number<0>{}); } }; - -#if 0 -template -__host__ __device__ constexpr auto unpacker(F f) -{ - return [=](auto xs_array){ f(xs...); }; -} -#endif diff --git a/src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hip.hpp index be0804d508..0ff514e398 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hip.hpp @@ -91,7 +91,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw // input tensor // memory layout descriptor in device memory [N0, N1, N2, C, H, W] constexpr auto in_n0_n1_n2_c_h_w_global_mem_desc = - in_n_c_h_w_global_desc.Fold(I0, Sequence{}); + in_n_c_h_w_global_desc.Fold(I0, Number{}, Number{}); // merged tensor descriptor in device memory [N1, N2, C, B], src of blockwise copy constexpr auto in_n1_n2_c_b_global_merged_desc = @@ -132,7 +132,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw // weight tensor // tensor descriptor in device memory, src of blockwise copy - constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(Sequence<0, 3>{}); + constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3); // tensor descriptor in LDS, dst of blockwise copy // be careful of LDS alignment @@ -257,7 +257,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw // output memory layout descriptor in device memory constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc = - out_n_k_h_w_global.Fold(I1, Sequence{}).Fold(I0, Sequence{}); + out_n_k_h_w_global.Fold(I1, Number{}, Number{}) + .Fold(I0, Number{}, Number{}); // output merged tensor descriptor in device memory, dst of threadwise copy constexpr auto out_k0_k1_k2_n1_b_n2_global_merged_desc =