mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 21:58:13 +00:00
Use thread scratch buffer in bf16 conversion.
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
|
||||
#include "ck/tensor/static_tensor.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
|
||||
@@ -273,10 +273,21 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
}
|
||||
|
||||
template <typename SrcSliceOriginIdx, typename SrcBuffer>
|
||||
__device__ void RunRead(const SrcDesc& src_desc,
|
||||
const SrcSliceOriginIdx& src_slice_origin_idx,
|
||||
const SrcBuffer& src_buf)
|
||||
__device__ void RunRead(const SrcBuffer& src_buf)
|
||||
{
|
||||
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
// SrcDesc and src_slice_origin_idx are known at compile-time
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
@@ -309,8 +320,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
const ck::bhalf2_t packed_value= bf16x2_convert_rne<ck::bhalf2_t, float>(val_0, val_1);
|
||||
|
||||
// Store the packed value into the thread scratch buffer
|
||||
thread_scratch_tuple_(0).template Update<InMemoryDataOperationEnum::Set, ck::bhalf2_t>(
|
||||
i_pair * 2, true, packed_value);
|
||||
thread_scratch_.template SetAsType<ck::bhalf2_t>(Number<idx_1d_0>{}, packed_value);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -328,45 +338,25 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
|
||||
|
||||
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
constexpr index_t num_pairs = num_access / 2;
|
||||
constexpr bool has_odd_element = (num_access % 2 == 1);
|
||||
static_assert(num_access == buffer_size_, "wrong!num_access != buffer_size_");
|
||||
|
||||
// TODO: Also support an odd number of elements.
|
||||
static_assert(!has_odd_element, "wrong!Slice should have even number of elements.");
|
||||
|
||||
static_for<0, num_pairs, 1>{}([&](auto i_pair)
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d)
|
||||
{
|
||||
const ck::bhalf2_t packed_value =
|
||||
thread_scratch_tuple_[0].template Get<ck::bhalf2_t>(i_pair * 2);
|
||||
|
||||
constexpr auto idx_1d_0 = I2 * i_pair;
|
||||
constexpr auto idx_1d_1 = I2 * i_pair + I1;
|
||||
|
||||
const ck::bhalf_t val = thread_scratch_.template GetAsType<ck::bhalf_t>(Number<idx_1d>{});
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
// Store the first of the packed values
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
packed_value[0]);
|
||||
val);
|
||||
|
||||
// Move to next dst coordinate
|
||||
constexpr auto forward_step_0 = SpaceFillingCurve::GetForwardStep(idx_1d_0);
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_,
|
||||
make_tensor_coordinate_step(dst_desc, forward_step_0));
|
||||
|
||||
// Store the second of the packed values
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
packed_value[1]);
|
||||
|
||||
// Move to next dst coordinate, unless this was the last pair
|
||||
if constexpr(i_pair.value != num_pairs - 1)
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d_1);
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
|
||||
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
|
||||
}
|
||||
@@ -389,104 +379,10 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
// SrcDesc and src_slice_origin_idx are known at compile-time
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
RunRead(src_desc, src_slice_origin_idx, src_buf);
|
||||
RunRead<SrcSliceOriginIdx,SrcBuffer>(src_buf);
|
||||
RunWrite(dst_desc, dst_buf);
|
||||
// static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
// "wrong! SrcDesc need to known at compile-time");
|
||||
|
||||
// static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
// "wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
// static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
// // SrcDesc and src_slice_origin_idx are known at compile-time
|
||||
// constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
// constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
// constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
// detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
// using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
// DimAccessOrder,
|
||||
// remove_cv_t<decltype(dst_scalar_per_access)>>;
|
||||
|
||||
// static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
|
||||
|
||||
// constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
// constexpr index_t num_pairs = num_access / 2;
|
||||
// constexpr bool has_odd_element = (num_access % 2 == 1);
|
||||
|
||||
// // TODO: Enable also odd number of elements.
|
||||
// static_assert(!has_odd_element, "wrong!Slice should have even number of elements.");
|
||||
|
||||
// static_for<0, num_pairs, 1>{}([&](auto i_pair)
|
||||
// {
|
||||
// constexpr auto idx_1d_0 = I2 * i_pair;
|
||||
// constexpr auto idx_1d_1 = I2 * i_pair + I1;
|
||||
// constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0);
|
||||
// constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1);
|
||||
|
||||
// constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_0);
|
||||
// constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_1);
|
||||
|
||||
// const float val_0 = src_buf[Number<src_offset_0>{}];
|
||||
// const float val_1 = src_buf[Number<src_offset_1>{}];
|
||||
|
||||
// const ck::bhalf2_t packed_value= bf16x2_convert_rne<ck::bhalf2_t, float>(val_0, val_1);
|
||||
|
||||
// const bool is_dst_valid =
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
// // Store the first of the packed values
|
||||
// dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
// dst_coord_.GetOffset(),
|
||||
// is_dst_valid,
|
||||
// packed_value[0]);
|
||||
|
||||
// // Move to next dst coordinate
|
||||
// constexpr auto forward_step_0 = SpaceFillingCurve::GetForwardStep(idx_1d_0);
|
||||
// move_tensor_coordinate(
|
||||
// dst_desc, dst_coord_,
|
||||
// make_tensor_coordinate_step(dst_desc, forward_step_0));
|
||||
|
||||
// // Store the second of the packed values
|
||||
// dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
// dst_coord_.GetOffset(),
|
||||
// is_dst_valid,
|
||||
// packed_value[1]);
|
||||
|
||||
// // Move to next dst coordinate, unless this was the last pair
|
||||
// if constexpr(i_pair.value != num_pairs - 1)
|
||||
// {
|
||||
// constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d_1);
|
||||
// move_tensor_coordinate(
|
||||
// dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
|
||||
// }
|
||||
// });
|
||||
|
||||
// // move dst coordinate back to slice origin (or not)
|
||||
// if constexpr(DstResetCoordinateAfterRun)
|
||||
// {
|
||||
// const auto dst_reset_step =
|
||||
// make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||
|
||||
// move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
// }
|
||||
}
|
||||
|
||||
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
{
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
@@ -525,65 +421,21 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetThreadScratchDescriptor()
|
||||
static constexpr index_t ScratchVectorSize = 2;
|
||||
|
||||
__device__ static constexpr index_t GetBufferSize()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, 2>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_access_lengths_and_vector_length = container_push_back(
|
||||
sequence_to_tuple_of_number(src_access_lengths), Number<2>{});
|
||||
|
||||
// 1st stage of transforms
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(src_access_lengths_and_vector_length[i],
|
||||
src_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
return reduce_on_sequence(SliceLengths{}, math::multiplies{}, Number<1>{});
|
||||
}
|
||||
private:
|
||||
DstCoord dst_coord_;
|
||||
|
||||
static constexpr auto thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
|
||||
using ThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
ck::bhalf_t
|
||||
2,
|
||||
decltype(thread_scratch_desc_),
|
||||
true>;
|
||||
StaticallyIndexedArray<ThreadScratch, 1> thread_scratch_tuple_;
|
||||
static constexpr auto buffer_size_ = GetBufferSize();
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
ck::bhalf_t,
|
||||
buffer_size_ / ScratchVectorSize,
|
||||
ScratchVectorSize,
|
||||
true> thread_scratch_;
|
||||
};
|
||||
|
||||
// Assume:
|
||||
|
||||
Reference in New Issue
Block a user