mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
Code refactor (#175)
* format * improving pipeline * fix typo * format * adding thread group * adding thread group * adding thread group * adding gemm pipeline * tweak * refactor * refactor * add missing type convert * refactor * refactor * refactor * clean * fix build * refactor * format * clean up * use remove_cvref_t * clean * clean up * clean up * clean up
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
|
||||
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
|
||||
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "xdlops_gemm.hpp"
|
||||
#include "tensor_adaptor.hpp"
|
||||
#include "thread_group.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -25,7 +24,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr index_t WaveSize = 64;
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
|
||||
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
|
||||
@@ -55,7 +56,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
__device__ static auto GetWaveIdx()
|
||||
{
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
const index_t thread_id = ThisThreadBlock::GetThreadId();
|
||||
|
||||
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
|
||||
@@ -122,8 +123,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
BK0NK1BlockDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(BlockSize == MWaves * NWaves * WaveSize,
|
||||
"BlockSize != MWaves * NWaves * WaveSize\n");
|
||||
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
|
||||
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
|
||||
|
||||
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
|
||||
"wrong!");
|
||||
@@ -339,4 +340,3 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -45,8 +45,8 @@ struct BlockwiseTensorSliceTransfer_v5r1
|
||||
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
|
||||
|
||||
{
|
||||
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
|
||||
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
|
||||
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
|
||||
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
|
||||
nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
|
||||
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
@@ -13,7 +11,7 @@ namespace ck {
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
template <index_t BlockSize,
|
||||
template <typename ThreadGroup,
|
||||
typename SrcElementwiseOperation,
|
||||
typename DstElementwiseOperation,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
@@ -35,7 +33,7 @@ template <index_t BlockSize,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct BlockwiseTensorSliceTransfer_v4r1
|
||||
struct ThreadGroupTensorSliceTransfer_v4r1
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
|
||||
@@ -43,7 +41,7 @@ struct BlockwiseTensorSliceTransfer_v4r1
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseTensorSliceTransfer_v4r1(
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const SrcElementwiseOperation& src_element_op,
|
||||
@@ -58,8 +56,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
|
||||
dst_element_op)
|
||||
|
||||
{
|
||||
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
|
||||
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
|
||||
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
|
||||
nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
|
||||
@@ -69,14 +67,14 @@ struct BlockwiseTensorSliceTransfer_v4r1
|
||||
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! BlockSize too small");
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! ThreadGroup::GetNumOfThread() too small");
|
||||
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
|
||||
@@ -92,8 +90,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
|
||||
const SrcBuffer& src_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
|
||||
}
|
||||
@@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
|
||||
DstBuffer& dst_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
|
||||
}
|
||||
@@ -124,8 +122,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
|
||||
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
|
||||
}
|
||||
@@ -133,8 +131,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
|
||||
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
|
||||
}
|
||||
@@ -169,4 +167,3 @@ struct BlockwiseTensorSliceTransfer_v4r1
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,6 +1,4 @@
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
|
||||
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
@@ -13,10 +11,10 @@ namespace ck {
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
template <index_t BlockSize,
|
||||
template <typename ThreadGroup,
|
||||
typename ElementwiseOperation,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
typename BlockSliceLengths,
|
||||
typename SliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcData,
|
||||
@@ -28,19 +26,19 @@ template <index_t BlockSize,
|
||||
index_t ScalarPerVector,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
struct BlockwiseTensorSliceTransfer_v6r1
|
||||
struct ThreadGroupTensorSliceTransfer_v6r1
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
|
||||
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
|
||||
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const ElementwiseOperation& element_op)
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const ElementwiseOperation& element_op)
|
||||
: threadwise_transfer_(src_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
dst_desc,
|
||||
@@ -48,25 +46,25 @@ struct BlockwiseTensorSliceTransfer_v6r1
|
||||
element_op)
|
||||
|
||||
{
|
||||
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
|
||||
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
|
||||
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
|
||||
nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == DimAccessOrder::Size(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(
|
||||
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! BlockSize too small");
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! ThreadGroup::GetNumOfThread() too small");
|
||||
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
|
||||
@@ -83,8 +81,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
|
||||
}
|
||||
@@ -92,8 +90,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
|
||||
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
|
||||
}
|
||||
@@ -101,8 +99,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
|
||||
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
|
||||
}
|
||||
@@ -130,4 +128,3 @@ struct BlockwiseTensorSliceTransfer_v6r1
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,6 +1,4 @@
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
|
||||
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
@@ -13,10 +11,10 @@ namespace ck {
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. It does not keep reference to tensor descriptor
|
||||
// 3. Run() does not construct new tensor coordinate
|
||||
template <index_t BlockSize,
|
||||
template <typename ThreadGroup,
|
||||
typename ElementwiseOperation,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
typename BlockSliceLengths,
|
||||
typename SliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename Src0Data,
|
||||
@@ -31,21 +29,21 @@ template <index_t BlockSize,
|
||||
bool ThreadTransferSrc0ResetCoordinateAfterRun,
|
||||
bool ThreadTransferSrc1ResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
struct BlockwiseTensorSliceTransfer_v6r2
|
||||
struct ThreadGroupTensorSliceTransfer_v6r2
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
|
||||
|
||||
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
|
||||
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
|
||||
const Index& src0_block_slice_origin,
|
||||
const Src1Desc& src1_desc,
|
||||
const Index& src1_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const ElementwiseOperation& element_op)
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
|
||||
const Index& src0_block_slice_origin,
|
||||
const Src1Desc& src1_desc,
|
||||
const Index& src1_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const ElementwiseOperation& element_op)
|
||||
: threadwise_transfer_(src0_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
src1_desc,
|
||||
@@ -55,26 +53,26 @@ struct BlockwiseTensorSliceTransfer_v6r2
|
||||
element_op)
|
||||
|
||||
{
|
||||
static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
|
||||
nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
|
||||
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
|
||||
static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
|
||||
nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == DimAccessOrder::Size(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(
|
||||
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! BlockSize too small");
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! ThreadGroup::GetNumOfThread() too small");
|
||||
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
|
||||
@@ -95,8 +93,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
|
||||
}
|
||||
@@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
|
||||
|
||||
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
|
||||
}
|
||||
@@ -113,8 +111,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
|
||||
|
||||
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
|
||||
}
|
||||
@@ -122,8 +120,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
|
||||
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
|
||||
}
|
||||
@@ -154,4 +152,3 @@ struct BlockwiseTensorSliceTransfer_v6r2
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,6 +1,4 @@
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
|
||||
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
@@ -13,10 +11,10 @@ namespace ck {
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
template <index_t BlockSize,
|
||||
template <typename ThreadGroup,
|
||||
typename ElementwiseOperation,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
typename BlockSliceLengths,
|
||||
typename SliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename Src0Data,
|
||||
@@ -34,23 +32,23 @@ template <index_t BlockSize,
|
||||
bool ThreadTransferSrc1ResetCoordinateAfterRun,
|
||||
bool ThreadTransferSrc2ResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
struct BlockwiseTensorSliceTransfer_v6r3
|
||||
struct ThreadGroupTensorSliceTransfer_v6r3
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
|
||||
|
||||
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
|
||||
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
|
||||
const Index& src0_block_slice_origin,
|
||||
const Src1Desc& src1_desc,
|
||||
const Index& src1_block_slice_origin,
|
||||
const Src2Desc& src2_desc,
|
||||
const Index& src2_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const ElementwiseOperation& element_op)
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
|
||||
const Index& src0_block_slice_origin,
|
||||
const Src1Desc& src1_desc,
|
||||
const Index& src1_block_slice_origin,
|
||||
const Src2Desc& src2_desc,
|
||||
const Index& src2_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const ElementwiseOperation& element_op)
|
||||
: threadwise_transfer_(src0_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
src1_desc,
|
||||
@@ -62,24 +60,24 @@ struct BlockwiseTensorSliceTransfer_v6r3
|
||||
element_op)
|
||||
|
||||
{
|
||||
static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
|
||||
nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
|
||||
nDim == remove_reference_t<remove_cv_t<Src2Desc>>::GetNumOfDimension() &&
|
||||
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
|
||||
static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<Src2Desc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
|
||||
nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == DimAccessOrder::Size(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(
|
||||
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! BlockSize too small");
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! ThreadGroup::GetNumOfThread() too small");
|
||||
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
@@ -107,8 +105,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.Run(
|
||||
src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
|
||||
@@ -117,8 +115,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
|
||||
|
||||
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
|
||||
}
|
||||
@@ -126,8 +124,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
|
||||
|
||||
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
|
||||
}
|
||||
@@ -135,8 +133,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
|
||||
|
||||
__device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
|
||||
}
|
||||
@@ -144,8 +142,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
|
||||
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
|
||||
}
|
||||
@@ -179,4 +177,3 @@ struct BlockwiseTensorSliceTransfer_v6r3
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
Reference in New Issue
Block a user