mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
Fused packed cast improvements.
This commit is contained in:
@@ -82,9 +82,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3_pass_through
|
||||
|
||||
static constexpr bool SerpentineAccessPattern = true;
|
||||
|
||||
static constexpr bool float_input_and_bf16_output_ =
|
||||
std::is_same_v<SrcData, float> && std::is_same_v<DstData, ck::bhalf_t>;
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
@@ -101,6 +98,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3_pass_through
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
|
||||
// For now, SrcData must be float and DstData must be ck::bhalf_t
|
||||
static_assert(std::is_same_v<SrcData, float>,
|
||||
"wrong! SrcData must be float");
|
||||
static_assert(std::is_same_v<DstData, ck::bhalf_t>,
|
||||
"wrong! DstData must be bhalf_t");
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
@@ -273,7 +276,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr bool SerpentineAccessPattern = false;
|
||||
|
||||
// Currently these must match, either true or false for both.
|
||||
// Otherwise we increase the register usage and end-up using the scratch memory.
|
||||
// The linear pattern seems to be better tahn serpentine.
|
||||
// Howeverm for LDS writes, serpentine pattern could be better.
|
||||
static constexpr bool SerpentineAccessPatternDst = true;
|
||||
static constexpr bool SerpentineAccessPatternSrc = true;
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
@@ -332,14 +341,24 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
// Note: We cannot easily have different SFC for src and dst because this will lead to increased register usage.
|
||||
using SpaceFillingCurveDst = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
SerpentineAccessPattern>;
|
||||
|
||||
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
|
||||
SerpentineAccessPatternDst>;
|
||||
|
||||
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
using SpaceFillingCurveSrc = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
SerpentineAccessPatternSrc>;
|
||||
|
||||
static_assert(1 == SpaceFillingCurveDst::ScalarPerVector,
|
||||
"wrong!DstScalarPerVector != SpaceFillingCurveDst::ScalarPerVector");
|
||||
|
||||
static_assert(SpaceFillingCurveDst::GetNumOfAccess() == SpaceFillingCurveSrc::GetNumOfAccess(),
|
||||
"wrong! SpaceFillingCurveDst and SpaceFillingCurveSrc must have the same number of access.");
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurveDst::GetNumOfAccess();
|
||||
constexpr index_t num_pairs = num_access / 2;
|
||||
constexpr bool has_odd_element = (num_access % 2 == 1);
|
||||
|
||||
@@ -349,13 +368,23 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
ck::float2_t float2_buffer;
|
||||
static_for<0, num_pairs, 1>{}([&](auto i_pair)
|
||||
{
|
||||
constexpr auto idx_1d_0 = I2 * i_pair;
|
||||
constexpr auto idx_1d_1 = I2 * i_pair + I1;
|
||||
constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0);
|
||||
constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1);
|
||||
|
||||
constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_0);
|
||||
constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_1);
|
||||
// First of the pair of the float values
|
||||
constexpr auto idx_1d_dst_0 = I2 * i_pair;
|
||||
constexpr index_t idx_1d_src_0 = SerpentineAccessPatternDst != SerpentineAccessPatternSrc
|
||||
? LinearIndexFinder<SpaceFillingCurveSrc, SpaceFillingCurveDst, idx_1d_dst_0.value, 0, num_access>::find()
|
||||
: idx_1d_dst_0.value;
|
||||
static_assert(idx_1d_src_0 != index_t(-1), "wrong! Cannot find first linear index.");
|
||||
constexpr auto idx_md_src_0 = SpaceFillingCurveSrc::GetIndex(Number<idx_1d_src_0>{});
|
||||
constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_src_0);
|
||||
|
||||
// Second pf the pair of the float values
|
||||
constexpr auto idx_1d_dst_1 = I2 * i_pair + I1;
|
||||
constexpr index_t idx_1d_src_1 = SerpentineAccessPatternDst != SerpentineAccessPatternSrc
|
||||
? LinearIndexFinder<SpaceFillingCurveSrc, SpaceFillingCurveDst, idx_1d_dst_1.value, 0, num_access>::find()
|
||||
: idx_1d_dst_1.value;
|
||||
static_assert(idx_1d_src_1 != index_t(-1), "wrong! Cannot find first linear index.");
|
||||
constexpr auto idx_md_src_1 = SpaceFillingCurveSrc::GetIndex(Number<idx_1d_src_1>{});
|
||||
constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_src_1);
|
||||
|
||||
if constexpr (src_offset_1 - src_offset_0 == 1)
|
||||
{
|
||||
@@ -373,42 +402,29 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
// Store the packed value into the dst buffer one by one.
|
||||
const auto dst_offset_0 = dst_coord_.GetOffset();
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
dst_offset_0,
|
||||
is_dst_valid,
|
||||
packed_value[0]);
|
||||
|
||||
// Move to the next dst coordinate
|
||||
constexpr auto forward_step_0 = SpaceFillingCurve::GetForwardStep(idx_1d_0);
|
||||
constexpr auto forward_step_0 = SpaceFillingCurveDst::GetForwardStep(idx_1d_dst_0);
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step_0));
|
||||
|
||||
const auto dst_offset_1 = dst_coord_.GetOffset();
|
||||
|
||||
if (dst_offset_1 - dst_offset_0 == 1)
|
||||
{
|
||||
// Store the packed value directly since we have consequtive locations is the dst buffer.
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf2_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
packed_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Store the packed value into the dst buffer one by one.
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
dst_offset_0,
|
||||
is_dst_valid,
|
||||
packed_value[0]);
|
||||
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
dst_offset_1,
|
||||
is_dst_valid,
|
||||
packed_value[1]);
|
||||
}
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
dst_offset_1,
|
||||
is_dst_valid,
|
||||
packed_value[1]);
|
||||
|
||||
// Move to next dst coordinate, unless this was the last pair
|
||||
if constexpr(i_pair.value != num_pairs - 1)
|
||||
{
|
||||
|
||||
constexpr auto forward_step_1 = SpaceFillingCurve::GetForwardStep(idx_1d_1);
|
||||
constexpr auto forward_step_1 = SpaceFillingCurveDst::GetForwardStep(idx_1d_dst_1);
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step_1));
|
||||
}
|
||||
@@ -433,7 +449,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
SerpentineAccessPattern>;
|
||||
SerpentineAccessPatternDst>;
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
if constexpr(num_access == 0)
|
||||
|
||||
Reference in New Issue
Block a user