bf16A_Int8B with fastgelu/bias (#1264)

* changed the copy function to v7r2

* adding multi_abd

* in-progress

* add post-load oob check

* debugging

* adjust instances

* add run_lds

* add elemntwise_op

* replace multi_abd_device with v3

* clean up

* clean

* clean

* Added LDSType

* profiling

* adjust oobcheck

* add missing file

* refactor

* clean

* add examples
This commit is contained in:
zjing14
2024-04-26 07:26:30 -05:00
committed by GitHub
parent b4032629e5
commit 0d0150db20
37 changed files with 4752 additions and 970 deletions

View File

@@ -41,7 +41,8 @@ template <typename ThreadGroup,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags>
typename ThreadTransferDstResetCoordinateAfterRunFlags,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v7r2
{
static constexpr index_t nDim =
@@ -100,7 +101,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
@@ -117,29 +118,33 @@ struct ThreadGroupTensorSliceTransfer_v7r2
}
}
template <typename SrcBuffers>
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_descs, src_bufs);
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
}
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename DstBuffers>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
threadwise_transfer_.RunWrite(dst_descs, dst_bufs);
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
else
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs));
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
}
}
@@ -206,7 +211,8 @@ struct ThreadGroupTensorSliceTransfer_v7r2
SrcScalarPerVector,
DstScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags>;
ThreadTransferDstResetCoordinateAfterRunFlags,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};