mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 21:58:13 +00:00
Optimize LDS write order for packed cast.
This commit is contained in:
@@ -13,6 +13,42 @@
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Helper function to compare multi-indices element by element
|
||||
template <typename Idx1, typename Idx2>
|
||||
constexpr bool indices_equal(const Idx1& idx1, const Idx2& idx2)
|
||||
{
|
||||
if constexpr (Idx1::Size() != Idx2::Size()) {
|
||||
return false;
|
||||
} else {
|
||||
bool equal = true;
|
||||
static_for<0, Idx1::Size(), 1>{}([&](auto i) {
|
||||
equal = equal && (idx1.At(i) == idx2.At(i));
|
||||
});
|
||||
return equal;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SpaceFillingCurveSrc, typename SpaceFillingCurveDst, index_t DstIdx, index_t SearchIdx, index_t MaxIdx>
|
||||
struct LinearIndexFinder
|
||||
{
|
||||
static constexpr index_t find()
|
||||
{
|
||||
if constexpr (SearchIdx >= MaxIdx) {
|
||||
return index_t(-1); // Not found
|
||||
} else {
|
||||
constexpr auto dst_md = SpaceFillingCurveDst::GetIndex(Number<DstIdx>{});
|
||||
constexpr auto src_md = SpaceFillingCurveSrc::GetIndex(Number<SearchIdx>{});
|
||||
|
||||
if constexpr (indices_equal(dst_md, src_md)) {
|
||||
return SearchIdx;
|
||||
} else {
|
||||
return LinearIndexFinder<SpaceFillingCurveSrc, SpaceFillingCurveDst, DstIdx, SearchIdx + 1, MaxIdx>::find();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is known at compile-time
|
||||
@@ -44,7 +80,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3_pass_through
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
static constexpr bool SerpentineAccessPattern = false;
|
||||
static constexpr bool SerpentineAccessPattern = true;
|
||||
|
||||
static constexpr bool float_input_and_bf16_output_ =
|
||||
std::is_same_v<SrcData, float> && std::is_same_v<DstData, ck::bhalf_t>;
|
||||
@@ -96,86 +132,63 @@ struct ThreadwiseTensorSliceTransfer_v1r3_pass_through
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
using SpaceFillingCurveDst = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
SerpentineAccessPattern>;
|
||||
|
||||
// TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector?
|
||||
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
|
||||
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
|
||||
using SpaceFillingCurveSrc = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
false>;
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
static_assert(DstScalarPerVector == SpaceFillingCurveDst::ScalarPerVector,
|
||||
"wrong!DstScalarPerVector != SpaceFillingCurveDst::ScalarPerVector");
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurveDst::GetNumOfAccess();
|
||||
|
||||
static_assert(std::is_same_v<DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5, 6, 7>>,
|
||||
"wrong! DimAccessOrder must be the identity sequence <0, 1, 2, 3, 4, 5, 6, 7>");
|
||||
|
||||
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
|
||||
static_assert(1 == SpaceFillingCurveDst::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
|
||||
static_assert(1 == DstScalarPerVector, "wrong!1 != DstScalarPerVector");
|
||||
|
||||
constexpr index_t num_pairs = num_access / 2;
|
||||
constexpr bool has_odd_element = (num_access % 2 == 1);
|
||||
static_assert(SpaceFillingCurveDst::GetNumOfAccess() == SpaceFillingCurveSrc::GetNumOfAccess(),
|
||||
"wrong! SpaceFillingCurveDst and SpaceFillingCurveSrc must have the same number of access.");
|
||||
|
||||
// TODO: Enable also odd number of elements.
|
||||
static_assert(!has_odd_element, "wrong!Slice should have even number of elements.");
|
||||
|
||||
static_for<0, num_pairs, 1>{}([&](auto i_pair)
|
||||
{
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(i_pair);
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md);
|
||||
|
||||
union
|
||||
{
|
||||
float src_float;
|
||||
bhalf2_t src_bf16x2;
|
||||
} packed_value;
|
||||
|
||||
packed_value.src_float = src_buf[Number<src_offset>{}];
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
constexpr index_t idx_src_1d = LinearIndexFinder<SpaceFillingCurveSrc, SpaceFillingCurveDst, idx_1d.value, 0, num_access>::find();
|
||||
static_assert(idx_src_1d != index_t(-1), "wrong! Cannot find linear index.");
|
||||
|
||||
// Map linear index to the packed BF16 index
|
||||
constexpr index_t idx_src_1d_packed = idx_src_1d / 2;
|
||||
constexpr index_t pair_index = idx_src_1d % 2;
|
||||
|
||||
constexpr auto idx_md_src = SpaceFillingCurveSrc::GetIndex(Number<idx_src_1d_packed>{});
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_src);
|
||||
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
ck::bhalf2_t bf16x2;
|
||||
} converter;
|
||||
converter.fp32 = src_buf[Number<src_offset>{}];
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
constexpr auto idx_1d_0 = I2 * i_pair;
|
||||
constexpr auto idx_1d_1 = I2 * i_pair + I1;
|
||||
|
||||
const auto dst_offset_0 = dst_coord_.GetOffset();
|
||||
// copy data from dst_vector into dst_buf
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
converter.bf16x2[pair_index]);
|
||||
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SpaceFillingCurveDst::GetForwardStep(idx_1d);
|
||||
|
||||
// Move to the next dst coordinate
|
||||
constexpr auto forward_step_0 = SpaceFillingCurve::GetForwardStep(idx_1d_0);
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step_0));
|
||||
|
||||
const auto dst_offset_1 = dst_coord_.GetOffset();
|
||||
|
||||
if (dst_offset_1 - dst_offset_0 == 1)
|
||||
{
|
||||
// Store the packed value directly since we have consequtive locations is the dst buffer.
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf2_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
packed_value.src_bf16x2);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Store the packed value into the dst buffer one by one.
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
dst_offset_0,
|
||||
is_dst_valid,
|
||||
packed_value.src_bf16x2[0]);
|
||||
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
dst_offset_1,
|
||||
is_dst_valid,
|
||||
packed_value.src_bf16x2[1]);
|
||||
}
|
||||
|
||||
// Move to next dst coordinate, unless this was the last pair
|
||||
if constexpr(i_pair.value != num_pairs - 1)
|
||||
{
|
||||
|
||||
constexpr auto forward_step_1 = SpaceFillingCurve::GetForwardStep(idx_1d_1);
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step_1));
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user