mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
adding implicit gemm v4r4
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "tensor_coordinate.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
|
||||
@@ -373,6 +374,64 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
class TData,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcCoordinate,
|
||||
class DstCoordinate,
|
||||
class SliceLengths,
|
||||
class SubLengths,
|
||||
class DataClusterLengths,
|
||||
class ThreadClusterArrangeOrder>
|
||||
struct BlockwiseGenericTensorSliceCopy_v2
|
||||
{
|
||||
using ThreadwiseCopy = ThreadwiseGenericTensorSliceCopy_v2<TData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
SrcCoordinate,
|
||||
DstCoordinate,
|
||||
SubLengths>;
|
||||
|
||||
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
|
||||
|
||||
__device__ constexpr BlockwiseGenericTensorSliceCopy_v2(SrcCoordinate src_block_slice_origin,
|
||||
DstCoordinate dst_block_slice_origin)
|
||||
{
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
|
||||
DataClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
|
||||
const auto thread_cluster_multi_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
|
||||
const auto data_cluster_multi_id =
|
||||
reorder_array_given_old2new(thread_cluster_multi_id, ThreadClusterArrangeOrder{});
|
||||
|
||||
const auto thread_data_multi_id_begin = data_cluster_multi_id * SubLengths{};
|
||||
|
||||
mThreadwiseCopy.SetSrcSliceOrigin(src_block_slice_origin + thread_data_multi_id_begin);
|
||||
mThreadwiseCopy.SetDstSliceOrigin(dst_block_slice_origin + thread_data_multi_id_begin);
|
||||
}
|
||||
|
||||
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
{
|
||||
mThreadwiseCopy.Run(p_src, p_dst);
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSlicingWindow(Array<index_t, nDim> step_sizes, bool positive_direction)
|
||||
{
|
||||
mThreadwiseCopy.MoveSrcSlicingWindow(step_sizes, positive_direction);
|
||||
}
|
||||
|
||||
__device__ void MoveDstSlicingWindow(Array<index_t, nDim> step_sizes, bool positive_direction)
|
||||
{
|
||||
mThreadwiseCopy.MoveDstSlicingWindow(step_sizes, positive_direction);
|
||||
}
|
||||
|
||||
// private:
|
||||
ThreadwiseCopy mThreadwiseCopy;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "tensor_coordinate.hpp"
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
|
||||
@@ -105,5 +106,75 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class TData,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcCoordinate,
|
||||
class DstCoordinate,
|
||||
class SliceLengths>
|
||||
struct ThreadwiseGenericTensorSliceCopy_v2
|
||||
{
|
||||
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2()
|
||||
: mSrcSliceOrigin(make_zero_array<index_t, nDim>()),
|
||||
mDstSliceOrigin(make_zero_array<index_t, nDim>())
|
||||
{
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2(SrcCoordinate src_slice_origin,
|
||||
DstCoordinate dst_slice_origin)
|
||||
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(SrcCoordinate src_slice_origin)
|
||||
{
|
||||
mSrcSliceOrigin = src_slice_origin;
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(DstCoordinate dst_slice_origin)
|
||||
{
|
||||
mDstSliceOrigin = dst_slice_origin;
|
||||
}
|
||||
|
||||
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
{
|
||||
static_ford<SliceLengths>{}([&](auto data_id) {
|
||||
p_dst[(mDstSliceOrigin + data_id).GetOffset()] =
|
||||
p_src[(mSrcSliceOrigin + data_id).GetOffset()];
|
||||
|
||||
});
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSlicingWindow(Array<index_t, nDim> step_sizes, bool positive_direction)
|
||||
{
|
||||
if(positive_direction)
|
||||
{
|
||||
mSrcSliceOrigin += step_sizes;
|
||||
}
|
||||
else
|
||||
{
|
||||
mSrcSliceOrigin -= step_sizes;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void MoveDstSlicingWindow(Array<index_t, nDim> step_sizes, bool positive_direction)
|
||||
{
|
||||
if(positive_direction)
|
||||
{
|
||||
mDstSliceOrigin += step_sizes;
|
||||
}
|
||||
else
|
||||
{
|
||||
mDstSliceOrigin -= step_sizes;
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
SrcCoordinate mSrcSliceOrigin;
|
||||
DstCoordinate mDstSliceOrigin;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user