From 1ca98e75ced161f6d0cc7fa33b8c59686b12697a Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Mon, 26 Aug 2024 09:48:07 +0000 Subject: [PATCH] tempsave --- .../threadwise_tensor_slice_transfer.hpp | 55 +++++++++---------- 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 3667ae7a3a..0de400f292 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -399,27 +399,27 @@ struct ThreadwiseTensorSliceTransfer_v2 // 1. DstDesc is known at compile-time // 2. DstBuffer is StaticBuffer // 3. dst_slice_origin_idx is known at compile-time -template ::type = false> struct ThreadwiseTensorSliceTransfer_v2r1 { - static_assert((InvalidElementAsNaN && !std::is_integral::value) || + static_assert((InvalidElementAsNaN && !std::is_integral::value) || (!InvalidElementAsNaN), "Filling invalid element as NaN is only for floating point types"); static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nSrc = SrcDescs::Size(); - static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); using Index = MultiIndex; @@ -437,37 +437,36 @@ struct ThreadwiseTensorSliceTransfer_v2r1 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc, - const Index& src_slice_origin_idx) - : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)) + __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDescs& src_descs, + const Indexs& src_slice_origin_idxs) { static_assert(DstDesc::IsKnownAtCompileTime(), "wrong! SrcDesc need to known at compile-time"); static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, "wrong! Not divisible"); + + src_coords_(generate_tuple([&](auto i) { return make_tensor_coordinate(src_desc[i], src_slice_origin_idx[i]); }, + nSrc);) } - __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs&, + const DstSliceOriginIdxs&, + DstBuffers& dst_bufs) { - src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); - } - - template - __device__ void Run(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const DstDesc&, - const DstSliceOriginIdx&, - DstBuffer& dst_buf) - { - static_assert(DstDesc::IsKnownAtCompileTime(), + static_for<0, nDst, 1>{}([&](auto i) { + static_assert(remove_cvref_t>::IsKnownAtCompileTime(), "wrong! DstDesc need to known at compile-time"); - static_assert(is_known_at_compile_time>::value, + static_assert(is_known_at_compile_time>>::value, "wrong! DstSliceOrigin need to known at compile-time"); - - static_assert( - is_same, remove_cvref_t>::value && - "wrong! inconsistent type"); + + static_assert( + is_same::type>, remove_cvref_t>>::value && + "wrong! inconsistent type"); + }); // DstDesc and dst_slice_origin_idx are known at compile-time constexpr auto dst_desc = remove_cvref_t{};