mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
tempsave
This commit is contained in:
@@ -399,27 +399,27 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
// 1. DstDesc is known at compile-time
|
||||
// 2. DstBuffer is StaticBuffer
|
||||
// 3. dst_slice_origin_idx is known at compile-time
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
template <typename SrcDatas,
|
||||
typename DstDatas,
|
||||
typename SrcDescs,
|
||||
typename DstDescs,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
index_t SrcScalarPerVectors,
|
||||
index_t SrcScalarStrideInVectors,
|
||||
bool SrcResetCoordinateAfterRun,
|
||||
bool InvalidElementAsNaN = false,
|
||||
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v2r1
|
||||
{
|
||||
static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) ||
|
||||
static_assert((InvalidElementAsNaN && !std::is_integral<DstDatas>::value) ||
|
||||
(!InvalidElementAsNaN),
|
||||
"Filling invalid element as NaN is only for floating point types");
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
static constexpr index_t nSrc = SrcDescs::Size();
|
||||
static constexpr index_t nSrc = SrcDescs::Size();
|
||||
static constexpr index_t nDst = DstDescs::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
@@ -437,37 +437,36 @@ struct ThreadwiseTensorSliceTransfer_v2r1
|
||||
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_idx)
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDescs& src_descs,
|
||||
const Indexs& src_slice_origin_idxs)
|
||||
{
|
||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
|
||||
src_coords_(generate_tuple([&](auto i) { return make_tensor_coordinate(src_desc[i], src_slice_origin_idx[i]); },
|
||||
nSrc);)
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
template <typename SrcBuffers, typename DstBuffers, typename DstSliceOriginIdxs>
|
||||
__device__ void Run(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs&,
|
||||
const DstSliceOriginIdxs&,
|
||||
DstBuffers& dst_bufs)
|
||||
{
|
||||
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc&,
|
||||
const DstSliceOriginIdx&,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
static_assert(remove_cvref_t<tuple_element_t<i.value, DstDescs>>::IsKnownAtCompileTime(),
|
||||
"wrong! DstDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<tuple_element_t<i.value, DstSliceOriginIdxs>>>::value,
|
||||
"wrong! DstSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename tuple_element_t<i.value, DstBuffer>::type>, remove_cvref_t<tuple_element_t<i.value, DstDatas>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
});
|
||||
|
||||
// DstDesc and dst_slice_origin_idx are known at compile-time
|
||||
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
|
||||
|
||||
Reference in New Issue
Block a user