Use thread scratch buffer in bf16 conversion.

This commit is contained in:
Ville Pietilä
2025-08-26 09:33:50 +00:00
parent a26f66171b
commit 905cfb6623

View File

@@ -7,7 +7,7 @@
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
@@ -273,10 +273,21 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
}
template <typename SrcSliceOriginIdx, typename SrcBuffer>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcSliceOriginIdx& src_slice_origin_idx,
const SrcBuffer& src_buf)
__device__ void RunRead(const SrcBuffer& src_buf)
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
@@ -309,8 +320,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
const ck::bhalf2_t packed_value= bf16x2_convert_rne<ck::bhalf2_t, float>(val_0, val_1);
// Store the packed value into the thread scratch buffer
thread_scratch_tuple_(0).template Update<InMemoryDataOperationEnum::Set, ck::bhalf2_t>(
i_pair * 2, true, packed_value);
thread_scratch_.template SetAsType<ck::bhalf2_t>(Number<idx_1d_0>{}, packed_value);
});
}
@@ -328,45 +338,25 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
constexpr index_t num_pairs = num_access / 2;
constexpr bool has_odd_element = (num_access % 2 == 1);
static_assert(num_access == buffer_size_, "wrong!num_access != buffer_size_");
// TODO: Enable also odd number of elements.
static_assert(!has_odd_element, "wrong!Slice should have even number of elements.");
static_for<0, num_pairs, 1>{}([&](auto i_pair)
static_for<0, num_access, 1>{}([&](auto idx_1d)
{
const ck::bhalf2_t packed_value =
thread_scratch_tuple_[0].template Get<ck::bhalf2_t>(i_pair * 2);
constexpr auto idx_1d_0 = I2 * i_pair;
constexpr auto idx_1d_1 = I2 * i_pair + I1;
const ck::bhalf_t val = thread_scratch_.template GetAsType<ck::bhalf_t>(Number<idx_1d>{});
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// Store the first of the packed values
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
dst_coord_.GetOffset(),
is_dst_valid,
packed_value[0]);
val);
// Move to next dst coordinate
constexpr auto forward_step_0 = SpaceFillingCurve::GetForwardStep(idx_1d_0);
move_tensor_coordinate(
dst_desc, dst_coord_,
make_tensor_coordinate_step(dst_desc, forward_step_0));
// Store the second of the packed values
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
dst_coord_.GetOffset(),
is_dst_valid,
packed_value[1]);
// Move to next dst coordinate, unless this was the last pair
if constexpr(i_pair.value != num_pairs - 1)
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d_1);
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
@@ -389,104 +379,10 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
RunRead(src_desc, src_slice_origin_idx, src_buf);
RunRead<SrcSliceOriginIdx,SrcBuffer>(src_buf);
RunWrite(dst_desc, dst_buf);
// static_assert(SrcDesc::IsKnownAtCompileTime(),
// "wrong! SrcDesc need to known at compile-time");
// static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
// "wrong! SrcSliceOrigin need to known at compile-time");
// static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
// // SrcDesc and src_slice_origin_idx are known at compile-time
// constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
// constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
// constexpr auto dst_scalar_per_access = generate_sequence(
// detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
// using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
// DimAccessOrder,
// remove_cv_t<decltype(dst_scalar_per_access)>>;
// static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
// constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
// constexpr index_t num_pairs = num_access / 2;
// constexpr bool has_odd_element = (num_access % 2 == 1);
// // TODO: Enable also odd number of elements.
// static_assert(!has_odd_element, "wrong!Slice should have even number of elements.");
// static_for<0, num_pairs, 1>{}([&](auto i_pair)
// {
// constexpr auto idx_1d_0 = I2 * i_pair;
// constexpr auto idx_1d_1 = I2 * i_pair + I1;
// constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0);
// constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1);
// constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_0);
// constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_1);
// const float val_0 = src_buf[Number<src_offset_0>{}];
// const float val_1 = src_buf[Number<src_offset_1>{}];
// const ck::bhalf2_t packed_value= bf16x2_convert_rne<ck::bhalf2_t, float>(val_0, val_1);
// const bool is_dst_valid =
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// // Store the first of the packed values
// dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
// dst_coord_.GetOffset(),
// is_dst_valid,
// packed_value[0]);
// // Move to next dst coordinate
// constexpr auto forward_step_0 = SpaceFillingCurve::GetForwardStep(idx_1d_0);
// move_tensor_coordinate(
// dst_desc, dst_coord_,
// make_tensor_coordinate_step(dst_desc, forward_step_0));
// // Store the second of the packed values
// dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
// dst_coord_.GetOffset(),
// is_dst_valid,
// packed_value[1]);
// // Move to next dst coordinate, unless this was the last pair
// if constexpr(i_pair.value != num_pairs - 1)
// {
// constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d_1);
// move_tensor_coordinate(
// dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
// }
// });
// // move dst coordinate back to slice origin (or not)
// if constexpr(DstResetCoordinateAfterRun)
// {
// const auto dst_reset_step =
// make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
// move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
// }
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
constexpr auto dst_scalar_per_access = generate_sequence(
@@ -525,65 +421,21 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}
__device__ static constexpr auto GetThreadScratchDescriptor()
static constexpr index_t ScratchVectorSize = 2;
__device__ static constexpr index_t GetBufferSize()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, 2>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(src_access_lengths), Number<2>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(src_access_lengths_and_vector_length[i],
src_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
return reduce_on_sequence(SliceLengths{}, math::multiplies{}, Number<1>{});
}
private:
DstCoord dst_coord_;
static constexpr auto thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
using ThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
ck::bhalf_t
2,
decltype(thread_scratch_desc_),
true>;
StaticallyIndexedArray<ThreadScratch, 1> thread_scratch_tuple_;
static constexpr auto buffer_size_ = GetBufferSize();
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
ck::bhalf_t,
buffer_size_ / ScratchVectorSize,
ScratchVectorSize,
true> thread_scratch_;
};
// Assume: