Optimize LDS write order for packed cast.

This commit is contained in:
Ville Pietilä
2025-08-28 15:01:55 +00:00
parent 9f66d9fbca
commit c69539fe3c

View File

@@ -13,6 +13,42 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
// Helper function to compare multi-indices element by element
template <typename Idx1, typename Idx2>
constexpr bool indices_equal(const Idx1& idx1, const Idx2& idx2)
{
if constexpr (Idx1::Size() != Idx2::Size()) {
return false;
} else {
bool equal = true;
static_for<0, Idx1::Size(), 1>{}([&](auto i) {
equal = equal && (idx1.At(i) == idx2.At(i));
});
return equal;
}
}
template <typename SpaceFillingCurveSrc, typename SpaceFillingCurveDst, index_t DstIdx, index_t SearchIdx, index_t MaxIdx>
struct LinearIndexFinder
{
static constexpr index_t find()
{
if constexpr (SearchIdx >= MaxIdx) {
return index_t(-1); // Not found
} else {
constexpr auto dst_md = SpaceFillingCurveDst::GetIndex(Number<DstIdx>{});
constexpr auto src_md = SpaceFillingCurveSrc::GetIndex(Number<SearchIdx>{});
if constexpr (indices_equal(dst_md, src_md)) {
return SearchIdx;
} else {
return LinearIndexFinder<SpaceFillingCurveSrc, SpaceFillingCurveDst, DstIdx, SearchIdx + 1, MaxIdx>::find();
}
}
}
};
// Assume:
// 1. src:
// 1. SrcDesc is known at compile-time
@@ -44,7 +80,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3_pass_through
static constexpr index_t nDim = SliceLengths::Size();
static constexpr bool SerpentineAccessPattern = false;
static constexpr bool SerpentineAccessPattern = true;
static constexpr bool float_input_and_bf16_output_ =
std::is_same_v<SrcData, float> && std::is_same_v<DstData, ck::bhalf_t>;
@@ -96,86 +132,63 @@ struct ThreadwiseTensorSliceTransfer_v1r3_pass_through
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
using SpaceFillingCurveDst = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>,
SerpentineAccessPattern>;
// TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector?
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
using SpaceFillingCurveSrc = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>,
false>;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_assert(DstScalarPerVector == SpaceFillingCurveDst::ScalarPerVector,
"wrong!DstScalarPerVector != SpaceFillingCurveDst::ScalarPerVector");
constexpr auto num_access = SpaceFillingCurveDst::GetNumOfAccess();
static_assert(std::is_same_v<DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5, 6, 7>>,
"wrong! DimAccessOrder must be the identity sequence <0, 1, 2, 3, 4, 5, 6, 7>");
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
static_assert(1 == SpaceFillingCurveDst::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
static_assert(1 == DstScalarPerVector, "wrong!1 != DstScalarPerVector");
constexpr index_t num_pairs = num_access / 2;
constexpr bool has_odd_element = (num_access % 2 == 1);
static_assert(SpaceFillingCurveDst::GetNumOfAccess() == SpaceFillingCurveSrc::GetNumOfAccess(),
"wrong! SpaceFillingCurveDst and SpaceFillingCurveSrc must have the same number of access.");
// TODO: Enable also odd number of elements.
static_assert(!has_odd_element, "wrong!Slice should have even number of elements.");
static_for<0, num_pairs, 1>{}([&](auto i_pair)
{
constexpr auto idx_md = SpaceFillingCurve::GetIndex(i_pair);
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md);
union
{
float src_float;
bhalf2_t src_bf16x2;
} packed_value;
packed_value.src_float = src_buf[Number<src_offset>{}];
static_for<0, num_access, 1>{}([&](auto idx_1d) {
constexpr index_t idx_src_1d = LinearIndexFinder<SpaceFillingCurveSrc, SpaceFillingCurveDst, idx_1d.value, 0, num_access>::find();
static_assert(idx_src_1d != index_t(-1), "wrong! Cannot find linear index.");
// Map linear index to the packed BF16 index
constexpr index_t idx_src_1d_packed = idx_src_1d / 2;
constexpr index_t pair_index = idx_src_1d % 2;
constexpr auto idx_md_src = SpaceFillingCurveSrc::GetIndex(Number<idx_src_1d_packed>{});
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_src);
union
{
float fp32;
ck::bhalf2_t bf16x2;
} converter;
converter.fp32 = src_buf[Number<src_offset>{}];
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
constexpr auto idx_1d_0 = I2 * i_pair;
constexpr auto idx_1d_1 = I2 * i_pair + I1;
const auto dst_offset_0 = dst_coord_.GetOffset();
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
dst_coord_.GetOffset(),
is_dst_valid,
converter.bf16x2[pair_index]);
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurveDst::GetForwardStep(idx_1d);
// Move to the next dst coordinate
constexpr auto forward_step_0 = SpaceFillingCurve::GetForwardStep(idx_1d_0);
move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step_0));
const auto dst_offset_1 = dst_coord_.GetOffset();
if (dst_offset_1 - dst_offset_0 == 1)
{
// Store the packed value directly since we have consequtive locations is the dst buffer.
dst_buf.template Update<DstInMemOp, ck::bhalf2_t>(
dst_coord_.GetOffset(),
is_dst_valid,
packed_value.src_bf16x2);
}
else
{
// Store the packed value into the dst buffer one by one.
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
dst_offset_0,
is_dst_valid,
packed_value.src_bf16x2[0]);
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
dst_offset_1,
is_dst_valid,
packed_value.src_bf16x2[1]);
}
// Move to next dst coordinate, unless this was the last pair
if constexpr(i_pair.value != num_pairs - 1)
{
constexpr auto forward_step_1 = SpaceFillingCurve::GetForwardStep(idx_1d_1);
move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step_1));
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});