mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 04:37:02 +00:00
Remove obsolete packed cast tensor slice transfers.
This commit is contained in:
@@ -49,206 +49,6 @@ struct LinearIndexFinder
|
||||
}
|
||||
};
|
||||
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is known at compile-time
|
||||
// 2. SrcBuffer is StaticBuffer
|
||||
// 3. SrcSliceOrginIdx is known at compile-time
|
||||
// 2. dst:
|
||||
// 1. DstDesc is not known at compile-time
|
||||
// 2. DstBuffer is DynamicBuffer
|
||||
// 3. DstSliceOrginIdx is not known at compile time
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename ElementwiseOperation,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t DstVectorDim,
|
||||
index_t DstScalarPerVector,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool DstResetCoordinateAfterRun,
|
||||
bool PackedInput = false,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v1r3_pass_through
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
static constexpr bool SerpentineAccessPattern = true;
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_pass_through(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_idx,
|
||||
const ElementwiseOperation& element_op)
|
||||
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
|
||||
element_op_{element_op}
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
|
||||
// For now, SrcData must be float and DstData must be ck::bhalf_t
|
||||
static_assert(std::is_same_v<SrcData, float>,
|
||||
"wrong! SrcData must be float");
|
||||
static_assert(std::is_same_v<DstData, ck::bhalf_t>,
|
||||
"wrong! DstData must be bhalf_t");
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
|
||||
__device__ void Run(const SrcDesc&,
|
||||
const SrcSliceOriginIdx&,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
// SrcDesc and src_slice_origin_idx are known at compile-time
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurveDst = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
SerpentineAccessPattern>;
|
||||
|
||||
using SpaceFillingCurveSrc = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
false>;
|
||||
|
||||
static_assert(DstScalarPerVector == SpaceFillingCurveDst::ScalarPerVector,
|
||||
"wrong!DstScalarPerVector != SpaceFillingCurveDst::ScalarPerVector");
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurveDst::GetNumOfAccess();
|
||||
|
||||
static_assert(std::is_same_v<DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5, 6, 7>>,
|
||||
"wrong! DimAccessOrder must be the identity sequence <0, 1, 2, 3, 4, 5, 6, 7>");
|
||||
|
||||
static_assert(1 == SpaceFillingCurveDst::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
|
||||
static_assert(1 == DstScalarPerVector, "wrong!1 != DstScalarPerVector");
|
||||
|
||||
static_assert(SpaceFillingCurveDst::GetNumOfAccess() == SpaceFillingCurveSrc::GetNumOfAccess(),
|
||||
"wrong! SpaceFillingCurveDst and SpaceFillingCurveSrc must have the same number of access.");
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
constexpr index_t idx_src_1d = LinearIndexFinder<SpaceFillingCurveSrc, SpaceFillingCurveDst, idx_1d.value, 0, num_access>::find();
|
||||
static_assert(idx_src_1d != index_t(-1), "wrong! Cannot find linear index.");
|
||||
|
||||
// Map linear index to the packed BF16 index
|
||||
constexpr index_t idx_src_1d_packed = idx_src_1d / 2;
|
||||
constexpr index_t pair_index = idx_src_1d % 2;
|
||||
|
||||
constexpr auto idx_md_src = SpaceFillingCurveSrc::GetIndex(Number<idx_src_1d_packed>{});
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_src);
|
||||
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
ck::bhalf2_t bf16x2;
|
||||
} converter;
|
||||
converter.fp32 = src_buf[Number<src_offset>{}];
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
// copy data from dst_vector into dst_buf
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
converter.bf16x2[pair_index]);
|
||||
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SpaceFillingCurveDst::GetForwardStep(idx_1d);
|
||||
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
|
||||
}
|
||||
});
|
||||
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
{
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
SerpentineAccessPattern>;
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
if constexpr(num_access == 0)
|
||||
{
|
||||
return typename SpaceFillingCurve::Index{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto reset_step =
|
||||
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
|
||||
|
||||
return reset_step;
|
||||
}
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_step_idx)
|
||||
{
|
||||
// if dst coord was not reset by Run(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
private:
|
||||
DstCoord dst_coord_;
|
||||
const ElementwiseOperation element_op_;
|
||||
}; // namespace ThreadwiseTensorSliceTransfer_v1r3_pass_through
|
||||
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is known at compile-time
|
||||
@@ -483,417 +283,6 @@ private:
|
||||
DstCoord dst_coord_;
|
||||
};
|
||||
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is known at compile-time
|
||||
// 2. SrcBuffer is StaticBuffer
|
||||
// 3. SrcSliceOrginIdx is known at compile-time
|
||||
// 2. dst:
|
||||
// 1. DstDesc is not known at compile-time
|
||||
// 2. DstBuffer is DynamicBuffer
|
||||
// 3. DstSliceOrginIdx is not known at compile time
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename ElementwiseOperation,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t DstVectorDim,
|
||||
index_t DstScalarPerVector,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool DstResetCoordinateAfterRun,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v1r3_buffered_packed_cast
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr bool SerpentineAccessPattern = false;
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_buffered_packed_cast(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_idx,
|
||||
const ElementwiseOperation&)
|
||||
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx))
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
|
||||
// Assert that elementwise op is pass through.
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<ElementwiseOperation>, ck::tensor_operation::element_wise::PassThrough>,
|
||||
"wrong! ElementwiseOperation must be PassThrough");
|
||||
|
||||
// For now, SrcData must be float and DstData must be ck::bhalf_t
|
||||
static_assert(std::is_same_v<SrcData, float>,
|
||||
"wrong! SrcData must be float");
|
||||
static_assert(std::is_same_v<DstData, ck::bhalf_t>,
|
||||
"wrong! DstData must be bhalf_t");
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcSliceOriginIdx, typename SrcBuffer>
|
||||
__device__ void RunRead(const SrcBuffer& src_buf)
|
||||
{
|
||||
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
// SrcDesc and src_slice_origin_idx are known at compile-time
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(src_scalar_per_access)>,
|
||||
SerpentineAccessPattern>;
|
||||
|
||||
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
|
||||
|
||||
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
constexpr index_t num_pairs = num_access / 2;
|
||||
constexpr bool has_odd_element = (num_access % 2 == 1);
|
||||
|
||||
// TODO: Enable also odd number of elements.
|
||||
static_assert(!has_odd_element, "wrong!Slice should have even number of elements.");
|
||||
|
||||
static_for<0, num_pairs, 1>{}([&](auto i_pair)
|
||||
{
|
||||
constexpr auto idx_1d_0 = I2 * i_pair;
|
||||
constexpr auto idx_1d_1 = I2 * i_pair + I1;
|
||||
constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0);
|
||||
constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1);
|
||||
|
||||
constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_0);
|
||||
constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_1);
|
||||
|
||||
const float val_0 = src_buf[Number<src_offset_0>{}];
|
||||
const float val_1 = src_buf[Number<src_offset_1>{}];
|
||||
|
||||
const ck::bhalf2_t packed_value= bf16x2_convert_rne<ck::bhalf2_t, float>(val_0, val_1);
|
||||
|
||||
// Store the packed value into the thread scratch buffer
|
||||
thread_scratch_.template SetAsType<ck::bhalf2_t>(Number<idx_1d_0>{}, packed_value);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename DstBuffer>
|
||||
__device__ void RunWrite(const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
SerpentineAccessPattern>;
|
||||
|
||||
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
|
||||
|
||||
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
static_assert(num_access == buffer_size_, "wrong!num_access != buffer_size_");
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d)
|
||||
{
|
||||
const ck::bhalf_t val = thread_scratch_.template GetAsType<ck::bhalf_t>(Number<idx_1d>{});
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
val);
|
||||
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
|
||||
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
|
||||
}
|
||||
});
|
||||
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
|
||||
__device__ void Run(const SrcDesc&,
|
||||
const SrcSliceOriginIdx&,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
RunRead<SrcSliceOriginIdx,SrcBuffer>(src_buf);
|
||||
RunWrite(dst_desc, dst_buf);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
{
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
SerpentineAccessPattern>;
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
if constexpr(num_access == 0)
|
||||
{
|
||||
return typename SpaceFillingCurve::Index{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto reset_step =
|
||||
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
|
||||
|
||||
return reset_step;
|
||||
}
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_step_idx)
|
||||
{
|
||||
// if dst coord was not reset by Run(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
static constexpr index_t ScratchVectorSize = 2;
|
||||
|
||||
__device__ static constexpr index_t GetBufferSize()
|
||||
{
|
||||
return reduce_on_sequence(SliceLengths{}, math::multiplies{}, Number<1>{});
|
||||
}
|
||||
private:
|
||||
DstCoord dst_coord_;
|
||||
|
||||
static constexpr auto buffer_size_ = GetBufferSize();
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
ck::bhalf_t,
|
||||
buffer_size_ / ScratchVectorSize,
|
||||
ScratchVectorSize,
|
||||
true> thread_scratch_;
|
||||
};
|
||||
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is known at compile-time
|
||||
// 2. SrcBuffer is StaticBuffer
|
||||
// 3. SrcSliceOrginIdx is known at compile-time
|
||||
// 2. dst:
|
||||
// 1. DstDesc is not known at compile-time
|
||||
// 2. DstBuffer is DynamicBuffer
|
||||
// 3. DstSliceOrginIdx is not known at compile time
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename ElementwiseOperation,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t DstVectorDim,
|
||||
index_t DstScalarPerVector,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool DstResetCoordinateAfterRun,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v1r3_vectorized
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr bool SnakedAccess = true;
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_vectorized(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_idx,
|
||||
const ElementwiseOperation&)
|
||||
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx))
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
|
||||
// Assert that elementwise op is pass through.
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<ElementwiseOperation>, ck::tensor_operation::element_wise::PassThrough>,
|
||||
"wrong! ElementwiseOperation must be PassThrough");
|
||||
|
||||
// For now, SrcData must be float and DstData must be ck::bhalf_t
|
||||
static_assert(std::is_same_v<SrcData, float>,
|
||||
"wrong! SrcData must be float");
|
||||
static_assert(std::is_same_v<DstData, ck::bhalf_t>,
|
||||
"wrong! DstData must be bhalf_t");
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
|
||||
__device__ void Run(const SrcDesc&,
|
||||
const SrcSliceOriginIdx&,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
// SrcDesc and src_slice_origin_idx are known at compile-time
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_scalar_step_in_vector =
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
SnakedAccess>;
|
||||
|
||||
static_assert(2 == SpaceFillingCurve::ScalarPerVector, "wrong!2 != SpaceFillingCurve::ScalarPerVector");
|
||||
ck::bhalf2_t dst_vector;
|
||||
using dst_vector_t = ck::bhalf2_t;
|
||||
|
||||
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d)
|
||||
{
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
constexpr auto idx_src_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md);
|
||||
constexpr auto idx_src_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md + src_scalar_step_in_vector);
|
||||
|
||||
const float val_0 = src_buf[Number<idx_src_0>{}];
|
||||
const float val_1 = src_buf[Number<idx_src_1>{}];
|
||||
dst_vector = bf16x2_convert_rne<ck::bhalf2_t, float>(val_0, val_1);
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
// copy data from dst_vector into dst_buf
|
||||
dst_buf.template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector);
|
||||
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
|
||||
}
|
||||
});
|
||||
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
{
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
SnakedAccess>;
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
if constexpr(num_access == 0)
|
||||
{
|
||||
return typename SpaceFillingCurve::Index{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto reset_step =
|
||||
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
|
||||
|
||||
return reset_step;
|
||||
}
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_step_idx)
|
||||
{
|
||||
// if dst coord was not reset by Run(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
private:
|
||||
DstCoord dst_coord_;
|
||||
};
|
||||
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is known at compile-time
|
||||
|
||||
Reference in New Issue
Block a user