mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 21:58:13 +00:00
WIP: PackedCast v3.
This commit is contained in:
@@ -14,6 +14,146 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
|
||||
template <typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename ElementwiseOperation,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t DstVectorDim,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool DstResetCoordinateAfterRun>
|
||||
struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr index_t SrcScalarPerVector = 1;
|
||||
static constexpr index_t DstScalarPerVector = 2
|
||||
|
||||
using BaseTransfer = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
float, ck::bhalf_t, SrcDesc, DstDesc, ElementwiseOperation, SliceLengths,
|
||||
DimAccessOrder, DstVectorDim, SrcScalarPerVector, DstInMemOp,
|
||||
DstScalarStrideInVector, DstResetCoordinateAfterRun, PackedInput>;
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_packed_cast(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_idx,
|
||||
const ElementwiseOperation& element_op) : base_transfer_(dst_desc, dst_slice_origin_idx, element_op)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcSliceOriginIdx& src_slice_origin_idx,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(base_transfer.SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
// SrcDesc and src_slice_origin_idx are known at compile-time
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, SrcScalarPerVector>{}, Number<base_transfer.nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>>;
|
||||
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
|
||||
using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
|
||||
|
||||
static_assert(std::is_same_v<DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5, 6, 7>>,
|
||||
"wrong! DimAccessOrder must be the identity sequence <0, 1, 2, 3, 4, 5, 6, 7>");
|
||||
|
||||
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
|
||||
|
||||
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
constexpr index_t num_pairs = num_access / 2;
|
||||
constexpr bool has_odd_element = (num_access % 2 == 1);
|
||||
|
||||
static_for<0, num_pairs, 1>{}([&](auto i_pair)
|
||||
{
|
||||
constexpr auto idx_1d_0 = I2 * i_pair;
|
||||
constexpr auto idx_1d_1 = I2 * i_pair + I1;
|
||||
constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0);
|
||||
constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1);
|
||||
|
||||
constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_0);
|
||||
constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_1);
|
||||
|
||||
float& val_0 = src_buf(Number<src_offset_0>{});
|
||||
float& val_1 = src_buf(Number<src_offset_1>{});
|
||||
|
||||
//static_cast_float_to_bhalf_packed_v2(val_0, val_1);
|
||||
// Fill the packed value into the dst_vector
|
||||
});
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d)
|
||||
{
|
||||
// We need map the odd indices to the even indices, since
|
||||
// the even indices contain a packed bf16x2 value, where
|
||||
// the first value contains the bf16 value for the corresponding even index
|
||||
// and the second value contains the bf16 value for the odd index following the even index.
|
||||
// The odd indices are not used, so we can just ignore them.
|
||||
constexpr auto pair_index = idx_1d % I2;
|
||||
constexpr auto idx_src_1d = idx_1d - pair_index;
|
||||
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_src_1d);
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md);
|
||||
|
||||
union
|
||||
{
|
||||
float src_float;
|
||||
bhalf16_t src_bf16x2;
|
||||
} packed_value;
|
||||
|
||||
packed_value.src_float = src_buf[Number<src_offset>{}];
|
||||
|
||||
DstData v;
|
||||
base_transfer.get_element_op()(v, packed_value.src_bf16x2[pair_index.value]);
|
||||
dst_vector.template AsType<DstData>()(I0) = v;
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
// copy data from dst_vector into dst_buf
|
||||
dst_buf.template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
|
||||
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
|
||||
|
||||
move_tensor_coordinate(
|
||||
dst_desc, base_transfer_.get_dst_coord(), make_tensor_coordinate_step(dst_desc, forward_step));
|
||||
}
|
||||
});
|
||||
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, base_transfer.GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_desc, base_transfer_.get_dst_coord(), dst_reset_step);
|
||||
}
|
||||
}
|
||||
private:
|
||||
BaseTransfer base_transfer_;
|
||||
};
|
||||
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is known at compile-time
|
||||
@@ -35,7 +175,6 @@ template <typename SrcData,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool DstResetCoordinateAfterRun,
|
||||
bool PackedInput = false,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
{
|
||||
@@ -45,9 +184,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
static constexpr bool float_input_and_bf16_output_ =
|
||||
std::is_same_v<SrcData, float> && std::is_same_v<DstData, ck::bhalf_t>;
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
@@ -110,60 +246,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
if constexpr (PackedInput && float_input_and_bf16_output_)
|
||||
{
|
||||
static_assert(std::is_same_v<DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5, 6, 7>>,
|
||||
"wrong! DimAccessOrder must be the identity sequence <0, 1, 2, 3, 4, 5, 6, 7>");
|
||||
|
||||
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
|
||||
static_assert(1 == DstScalarPerVector, "wrong!1 != DstScalarPerVector");
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d)
|
||||
{
|
||||
// We need map the odd indices to the even indices, since
|
||||
// the even indices contain a packed bf16x2 value, where
|
||||
// the first value contains the bf16 value for the corresponding even index
|
||||
// and the second value contains the bf16 value for the odd index following the even index.
|
||||
// The odd indices are not used, so we can just ignore them.
|
||||
constexpr auto pair_index = idx_1d % I2;
|
||||
constexpr auto idx_src_1d = idx_1d - pair_index;
|
||||
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_src_1d);
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md);
|
||||
|
||||
union
|
||||
{
|
||||
float src_float;
|
||||
bhalf16_t src_bf16x2;
|
||||
} packed_value;
|
||||
|
||||
packed_value.src_float = src_buf[Number<src_offset>{}];
|
||||
|
||||
DstData v;
|
||||
element_op_(v, packed_value.src_bf16x2[pair_index.value]);
|
||||
dst_vector.template AsType<DstData>()(I0) = v;
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
// copy data from dst_vector into dst_buf
|
||||
dst_buf.template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
|
||||
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
|
||||
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
|
||||
}
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
|
||||
// copy data from src_buf into dst_vector
|
||||
@@ -196,7 +279,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
@@ -246,7 +328,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
private:
|
||||
DestCoord& get_dest_coord() { return dst_coord_; }
|
||||
|
||||
const ElementwiseOperation& get_element_op() const { return element_op_; }
|
||||
|
||||
private:
|
||||
DstCoord dst_coord_;
|
||||
const ElementwiseOperation element_op_;
|
||||
}; // namespace ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
Reference in New Issue
Block a user