WIP: PackedCast v3.

This commit is contained in:
Ville Pietilä
2025-08-13 15:13:35 +00:00
parent 11baf3de0c
commit ade741dd45

View File

@@ -14,6 +14,146 @@
namespace ck {
template <typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
InMemoryDataOperationEnum DstInMemOp,
index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun>
struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t SrcScalarPerVector = 1;
static constexpr index_t DstScalarPerVector = 2
using BaseTransfer = ThreadwiseTensorSliceTransfer_v1r3<
float, ck::bhalf_t, SrcDesc, DstDesc, ElementwiseOperation, SliceLengths,
DimAccessOrder, DstVectorDim, SrcScalarPerVector, DstInMemOp,
DstScalarStrideInVector, DstResetCoordinateAfterRun, PackedInput>;
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_packed_cast(const DstDesc& dst_desc,
const Index& dst_slice_origin_idx,
const ElementwiseOperation& element_op) : base_transfer_(dst_desc, dst_slice_origin_idx, element_op)
{
}
template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
__device__ void Run(const SrcDesc& src_desc,
const SrcSliceOriginIdx& src_slice_origin_idx,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");
static_assert(base_transfer.SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, SrcScalarPerVector>{}, Number<base_transfer.nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
static_assert(std::is_same_v<DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5, 6, 7>>,
"wrong! DimAccessOrder must be the identity sequence <0, 1, 2, 3, 4, 5, 6, 7>");
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
constexpr index_t num_pairs = num_access / 2;
constexpr bool has_odd_element = (num_access % 2 == 1);
static_for<0, num_pairs, 1>{}([&](auto i_pair)
{
constexpr auto idx_1d_0 = I2 * i_pair;
constexpr auto idx_1d_1 = I2 * i_pair + I1;
constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0);
constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1);
constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_0);
constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_1);
float& val_0 = src_buf(Number<src_offset_0>{});
float& val_1 = src_buf(Number<src_offset_1>{});
//static_cast_float_to_bhalf_packed_v2(val_0, val_1);
// Fill the packed value into the dst_vector
});
static_for<0, num_access, 1>{}([&](auto idx_1d)
{
// We need map the odd indices to the even indices, since
// the even indices contain a packed bf16x2 value, where
// the first value contains the bf16 value for the corresponding even index
// and the second value contains the bf16 value for the odd index following the even index.
// The odd indices are not used, so we can just ignore them.
constexpr auto pair_index = idx_1d % I2;
constexpr auto idx_src_1d = idx_1d - pair_index;
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_src_1d);
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md);
union
{
float src_float;
bhalf16_t src_bf16x2;
} packed_value;
packed_value.src_float = src_buf[Number<src_offset>{}];
DstData v;
base_transfer.get_element_op()(v, packed_value.src_bf16x2[pair_index.value]);
dst_vector.template AsType<DstData>()(I0) = v;
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
dst_desc, base_transfer_.get_dst_coord(), make_tensor_coordinate_step(dst_desc, forward_step));
}
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, base_transfer.GetDstCoordinateResetStep());
move_tensor_coordinate(dst_desc, base_transfer_.get_dst_coord(), dst_reset_step);
}
}
private:
BaseTransfer base_transfer_;
};
// Assume:
// 1. src:
// 1. SrcDesc is known at compile-time
@@ -35,7 +175,6 @@ template <typename SrcData,
InMemoryDataOperationEnum DstInMemOp,
index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun,
bool PackedInput = false,
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r3
{
@@ -45,9 +184,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static constexpr index_t nDim = SliceLengths::Size();
static constexpr bool float_input_and_bf16_output_ =
std::is_same_v<SrcData, float> && std::is_same_v<DstData, ck::bhalf_t>;
using Index = MultiIndex<nDim>;
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
@@ -110,60 +246,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr (PackedInput && float_input_and_bf16_output_)
{
static_assert(std::is_same_v<DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5, 6, 7>>,
"wrong! DimAccessOrder must be the identity sequence <0, 1, 2, 3, 4, 5, 6, 7>");
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
static_assert(1 == DstScalarPerVector, "wrong!1 != DstScalarPerVector");
static_for<0, num_access, 1>{}([&](auto idx_1d)
{
// We need map the odd indices to the even indices, since
// the even indices contain a packed bf16x2 value, where
// the first value contains the bf16 value for the corresponding even index
// and the second value contains the bf16 value for the odd index following the even index.
// The odd indices are not used, so we can just ignore them.
constexpr auto pair_index = idx_1d % I2;
constexpr auto idx_src_1d = idx_1d - pair_index;
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_src_1d);
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md);
union
{
float src_float;
bhalf16_t src_bf16x2;
} packed_value;
packed_value.src_float = src_buf[Number<src_offset>{}];
DstData v;
element_op_(v, packed_value.src_bf16x2[pair_index.value]);
dst_vector.template AsType<DstData>()(I0) = v;
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});
}
else
{
static_for<0, num_access, 1>{}([&](auto idx_1d) {
static_for<0, num_access, 1>{}([&](auto idx_1d) {
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
// copy data from src_buf into dst_vector
@@ -196,7 +279,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});
}
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
@@ -246,7 +328,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}
private:
DestCoord& get_dest_coord() { return dst_coord_; }
const ElementwiseOperation& get_element_op() const { return element_op_; }
private:
DstCoord dst_coord_;
const ElementwiseOperation element_op_;
}; // struct ThreadwiseTensorSliceTransfer_v1r3