mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
enabling padding for chwn format
This commit is contained in:
@@ -7,6 +7,8 @@
|
||||
#include "tensor_coordinate.hpp"
|
||||
#include "tensor_view.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_coordinate_v2.hpp"
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
|
||||
@@ -418,6 +420,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
}
|
||||
};
|
||||
|
||||
// This version uses TensorCoordinate
|
||||
// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
|
||||
// memory layout (ordering of dimensions) can be different between src and dst.
|
||||
template <index_t BlockSize,
|
||||
@@ -518,7 +521,7 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
}
|
||||
|
||||
private:
|
||||
using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
|
||||
using RegisterBufferDesc = decltype(make_native_tensor_descriptor_packed(SubLengths{}));
|
||||
|
||||
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v2r1<SrcDesc,
|
||||
RegisterBufferDesc,
|
||||
@@ -544,6 +547,7 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
};
|
||||
|
||||
// this version use TensorView and TensorCoordinate
|
||||
template <index_t BlockSize,
|
||||
class SrcTensor,
|
||||
class DstTensor,
|
||||
@@ -639,25 +643,25 @@ struct BlockwiseGenericTensorSliceCopy_v3
|
||||
using ThreadBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
|
||||
using ThreadBufferTensor = NormalTensorView<ThreadBufferDesc, data_type>;
|
||||
|
||||
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v3<SrcTensor,
|
||||
ThreadBufferTensor,
|
||||
SubLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>;
|
||||
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v3r1<SrcTensor,
|
||||
ThreadBufferTensor,
|
||||
SubLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>;
|
||||
|
||||
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v3<ThreadBufferTensor,
|
||||
DstTensor,
|
||||
SubLengths,
|
||||
DstDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>;
|
||||
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v3r1<ThreadBufferTensor,
|
||||
DstTensor,
|
||||
SubLengths,
|
||||
DstDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>;
|
||||
|
||||
data_type mpBuffer[ThreadBufferDesc::GetElementSpace()];
|
||||
|
||||
@@ -667,6 +671,125 @@ struct BlockwiseGenericTensorSliceCopy_v3
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SliceLengths,
|
||||
class SubLengths,
|
||||
class ThreadClusterLengths,
|
||||
class ThreadClusterArrangeOrder,
|
||||
class SrcDimAccessOrder,
|
||||
class DstDimAccessOrder,
|
||||
index_t SrcVectorAccessDim,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct BlockwiseGenericTensorSliceCopy_v4
|
||||
{
|
||||
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
|
||||
|
||||
using SrcCoord = typename TensorCoordinate_v2<SrcDesc>::type;
|
||||
using DstCoord = typename TensorCoordinate_v2<DstDesc>::type;
|
||||
|
||||
__device__ constexpr BlockwiseGenericTensorSliceCopy_v4(SrcCoord src_block_slice_origin,
|
||||
DstCoord dst_block_slice_origin)
|
||||
{
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() &&
|
||||
nDim == SubLengths::Size() && nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
|
||||
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
|
||||
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
|
||||
"wrong! BlockSize not consistent with ThreadClusterLengths");
|
||||
|
||||
const auto thread_cluster_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
|
||||
const auto data_cluster_id =
|
||||
reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});
|
||||
|
||||
const auto thread_data_id_begin = data_cluster_id * SubLengths{};
|
||||
|
||||
mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
|
||||
mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());
|
||||
|
||||
mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
|
||||
mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegisterBufferSize()
|
||||
{
|
||||
return RegisterBufferDesc::GetElementSpace();
|
||||
}
|
||||
|
||||
template <class TData>
|
||||
__device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const
|
||||
{
|
||||
mThreadwiseLoad.Run(p_src, p_buffer);
|
||||
}
|
||||
|
||||
template <class TData>
|
||||
__device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const
|
||||
{
|
||||
mThreadwiseStore.Run(p_buffer, p_dst);
|
||||
}
|
||||
|
||||
template <class TData>
|
||||
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
{
|
||||
TData p_buffer[GetRegisterBufferSize()];
|
||||
|
||||
mThreadwiseLoad.Run(p_src, p_buffer);
|
||||
mThreadwiseStore.Run(p_buffer, p_dst);
|
||||
}
|
||||
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void
|
||||
MoveSrcSlicingWindow(T step_sizes,
|
||||
integral_constant<bool, PositiveDirection> positive_direction)
|
||||
{
|
||||
mThreadwiseLoad.MoveSrcSlicingWindow(step_sizes, positive_direction);
|
||||
}
|
||||
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void
|
||||
MoveDstSlicingWindow(T step_sizes,
|
||||
integral_constant<bool, PositiveDirection> positive_direction)
|
||||
{
|
||||
mThreadwiseStore.MoveDstSlicingWindow(step_sizes, positive_direction);
|
||||
}
|
||||
|
||||
private:
|
||||
using RegisterBufferDesc = decltype(make_native_tensor_descriptor_packed(SubLengths{}));
|
||||
|
||||
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v4r2<SrcDesc,
|
||||
RegisterBufferDesc,
|
||||
SubLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>;
|
||||
|
||||
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<RegisterBufferDesc,
|
||||
DstDesc,
|
||||
SubLengths,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>;
|
||||
|
||||
ThreadwiseLoad mThreadwiseLoad;
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "tensor_coordinate.hpp"
|
||||
#include "tensor_view.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_coordinate_v2.hpp"
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
|
||||
@@ -427,6 +429,7 @@ struct ThreadwiseGenericTensorSliceCopy_v1r2
|
||||
Array<index_t, nDim> mDstSliceOrigin;
|
||||
};
|
||||
|
||||
// This version use TensorCoordinate
|
||||
// This threadwise copy allows vector access of src and dst.
|
||||
// It allows the dimensions of vector access to be different on src and dst.
|
||||
// It also allows the vector size to be different on src and dst.
|
||||
@@ -774,6 +777,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
DstCoordinate mDstSliceOrigin;
|
||||
};
|
||||
|
||||
// this version use TensorView and TensorCoordinate
|
||||
template <class SrcTensor,
|
||||
class DstTensor,
|
||||
class SliceLengths,
|
||||
@@ -783,7 +787,7 @@ template <class SrcTensor,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct ThreadwiseGenericTensorSliceCopy_v3
|
||||
struct ThreadwiseGenericTensorSliceCopy_v3r1
|
||||
{
|
||||
static constexpr index_t nDim = SrcTensor::GetNumOfDimension();
|
||||
using data_type = remove_cv_t<typename SrcTensor::data_type>;
|
||||
@@ -791,10 +795,10 @@ struct ThreadwiseGenericTensorSliceCopy_v3
|
||||
using SrcCoordinate = typename SrcTensor::coordinate_type;
|
||||
using DstCoordinate = typename DstTensor::coordinate_type;
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v3(SrcTensor src,
|
||||
SrcCoordinate src_slice_origin,
|
||||
DstTensor dst,
|
||||
DstCoordinate dst_slice_origin)
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v3r1(SrcTensor src,
|
||||
SrcCoordinate src_slice_origin,
|
||||
DstTensor dst,
|
||||
DstCoordinate dst_slice_origin)
|
||||
: mSrc{src},
|
||||
mDst{dst},
|
||||
mSrcSlice{src.Slice(src_slice_origin, SliceLengths{})},
|
||||
@@ -821,8 +825,8 @@ struct ThreadwiseGenericTensorSliceCopy_v3
|
||||
"wrong! vectorized access is not allowed");
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v3()
|
||||
: ThreadwiseGenericTensorSliceCopy_v3(
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v3r1()
|
||||
: ThreadwiseGenericTensorSliceCopy_v3r1(
|
||||
SrcTensor{}, SrcCoordinate{}, DstTensor{}, DstCoordinate{})
|
||||
{
|
||||
}
|
||||
@@ -940,5 +944,154 @@ struct ThreadwiseGenericTensorSliceCopy_v3
|
||||
DstSlice mDstSlice;
|
||||
};
|
||||
|
||||
// This version use multi-index transformation
|
||||
// This threadwise copy allow vector access of src and dst.
|
||||
// It allows the vector size to be different on src and dst.
|
||||
// The dimensions of vector access should be the same on src and dst.
|
||||
// The dimension access order should be the same on src and dst.
|
||||
// It is designed for cases, where one of src and dst is register, and
|
||||
// the other is device memory or LDS
|
||||
template <class SrcDesc,
|
||||
class DstDesc,
|
||||
class SliceLengths,
|
||||
class DimAccessOrder,
|
||||
index_t VectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using SrcCoord = typename TensorCoordinate_v2<SrcDesc>::type;
|
||||
using DstCoord = typename TensorCoordinate_v2<DstDesc>::type;
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2(SrcCoord src_slice_origin,
|
||||
DstCoord dst_slice_origin)
|
||||
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
|
||||
{
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() &&
|
||||
nDim == DimAccessOrder::Size(),
|
||||
"wrong! # of dimensions not the same");
|
||||
|
||||
static_assert(is_valid_sequence_map<DimAccessOrder>{}, "wrong! map is not valid");
|
||||
|
||||
static_assert(
|
||||
SliceLengths{}[VectorAccessDim] % math::lcm(SrcDataPerAccess, DstDataPerAccess) == 0,
|
||||
"wrong! cannot evenly divide");
|
||||
|
||||
// TODO:: sanity-check if vectorized memory access is allowed on src and dst
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2()
|
||||
: ThreadwiseGenericTensorSliceCopy_v4r2(make_zero_array<index_t, nDim>(),
|
||||
make_zero_array<index_t, nDim>())
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(SrcCoord src_slice_origin)
|
||||
{
|
||||
mSrcSliceOrigin = src_slice_origin;
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(DstCoord dst_slice_origin)
|
||||
{
|
||||
mDstSliceOrigin = dst_slice_origin;
|
||||
}
|
||||
|
||||
template <class TData>
|
||||
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
{
|
||||
using src_vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;
|
||||
using dst_vector_t = typename vector_type<TData, DstDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
|
||||
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
|
||||
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
|
||||
|
||||
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
|
||||
|
||||
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
|
||||
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
|
||||
|
||||
ford<decltype(long_vector_access_lengths), DimAccessOrder>{}([&](
|
||||
auto long_vector_access_id) {
|
||||
|
||||
// data id w.r.t slicing-window
|
||||
auto long_vector_data_begin_id = long_vector_access_id;
|
||||
long_vector_data_begin_id(vector_access_dim) =
|
||||
long_vector_size * long_vector_access_id[vector_access_dim];
|
||||
|
||||
// buffer to hold a long-vector
|
||||
TData p_long_vector[long_vector_size];
|
||||
|
||||
// set 0
|
||||
for(index_t i = 0; i < long_vector_size; ++i)
|
||||
{
|
||||
p_long_vector[i] = 0;
|
||||
}
|
||||
|
||||
// load data from src to the long-vector buffer
|
||||
for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
|
||||
{
|
||||
auto scalar_id = make_zero_array<index_t, nDim>();
|
||||
scalar_id(vector_access_dim) = i * src_data_per_access;
|
||||
|
||||
const auto src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
|
||||
|
||||
// check for padding
|
||||
// TODO: still kind of messy
|
||||
if(!src_coord.IsAnyLevelIndexInPaddingArea())
|
||||
{
|
||||
const index_t src_offset =
|
||||
(mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id)).GetOffset();
|
||||
|
||||
const index_t buffer_offset = i * src_data_per_access;
|
||||
|
||||
*reinterpret_cast<src_vector_t*>(&p_long_vector[buffer_offset]) =
|
||||
*reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
|
||||
}
|
||||
}
|
||||
|
||||
// store data from the long-vector buffer to dst
|
||||
for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
|
||||
{
|
||||
auto scalar_id = make_zero_array<index_t, nDim>();
|
||||
scalar_id(vector_access_dim) = i * dst_data_per_access;
|
||||
|
||||
const index_t buffer_offset = i * dst_data_per_access;
|
||||
|
||||
const index_t dst_offset =
|
||||
(mDstSliceOrigin + (long_vector_data_begin_id + scalar_id)).GetOffset();
|
||||
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) =
|
||||
*reinterpret_cast<dst_vector_t*>(&p_long_vector[buffer_offset]);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
static_if<PositiveDirection>{}([&](auto) {
|
||||
mSrcSliceOrigin += step_sizes;
|
||||
}).Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
|
||||
}
|
||||
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void MoveDstSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
static_if<PositiveDirection>{}([&](auto) {
|
||||
mDstSliceOrigin += step_sizes;
|
||||
}).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
|
||||
}
|
||||
|
||||
private:
|
||||
SrcCoord mSrcSliceOrigin;
|
||||
DstCoord mDstSliceOrigin;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user