Add back the separate packed cast step.

This commit is contained in:
Ville Pietilä
2025-08-20 11:31:07 +00:00
parent 6fbe1895f1
commit de93a48b04
7 changed files with 635 additions and 47 deletions

View File

@@ -9,6 +9,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -897,7 +898,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<AccDataType,
ThreadwiseTensorSliceTransfer_v1r3_pass_through<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -1006,6 +1007,21 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
@@ -1292,7 +1308,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<AccDataType,
ThreadwiseTensorSliceTransfer_v1r3_pass_through<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -1401,6 +1417,21 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),

View File

@@ -9,6 +9,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -1645,7 +1646,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -1762,6 +1763,21 @@ struct GridwiseGemm_xdl_cshuffle_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
@@ -2077,7 +2093,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -2188,6 +2204,21 @@ struct GridwiseGemm_xdl_cshuffle_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),

View File

@@ -8,6 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
@@ -896,7 +897,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
FloatAcc,
FloatC,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
@@ -1001,6 +1002,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// make sure it's safe to do ds_write
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle
>{};
c_thread_packed_cast.Run(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, // source desc (TensorDescriptor struct)
make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), // source slice origin
c_thread_buf // source buffer
);
}
// VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,

View File

@@ -0,0 +1,141 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
/// In-place packed conversion of a thread-local slice.
///
/// The slice is walked linearly with the M4 index varying fastest, then M2, then the N-xdl index
/// and finally the M-xdl index. Linear elements (2k, 2k + 1) are handed together to
/// static_cast_float_to_bhalf_packed_v2: element 2k is overwritten with the packed result and
/// element 2k + 1 is only read.
template <ck::index_t M2,
          ck::index_t M4,
          ck::index_t CShuffleMXdlPerWavePerShuffle,
          ck::index_t CShuffleNXdlPerWavePerShuffle>
struct PackedCast
{
    /// Packs the slice that starts at the given origin inside @p src_buf.
    ///
    /// The descriptor and the slice origin are passed only for their types; both must be known
    /// at compile time. @p src_buf must be a static buffer and is modified in place.
    template <typename SrcDesc, typename SrcSliceOriginIdx, typename SrcBuffer>
    __device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc need to known at compile-time");
        static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
                      "wrong! SrcSliceOrigin need to known at compile-time");
        static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");

        constexpr auto desc   = remove_cvref_t<SrcDesc>{};
        constexpr auto origin = to_multi_index(SrcSliceOriginIdx{});

        // Number of scalars covered by one slice.
        constexpr index_t slice_size =
            CShuffleMXdlPerWavePerShuffle * CShuffleNXdlPerWavePerShuffle * M2 * M4;

        // Turns a linear position inside the slice into an 8-d coordinate. The slice lengths are
        // <CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, 1, 1, M2, 1, M4, 1>
        // traversed in dimension order <0, 1, 2, 3, 4, 5, 6, 7>; the unit-length dimensions are
        // pinned to 0 and do not take the slice origin into account.
        constexpr auto to_coord = [origin](auto linear) constexpr {
            constexpr index_t in_m4    = linear.value % M4;
            constexpr index_t in_m2    = (linear.value / M4) % M2;
            constexpr index_t in_n_xdl = (linear.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle;
            constexpr index_t in_m_xdl = linear.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle);

            return make_tuple(origin[Number<0>{}] + Number<in_m_xdl>{},
                              origin[Number<1>{}] + Number<in_n_xdl>{},
                              Number<0>{},
                              Number<0>{},
                              origin[Number<4>{}] + Number<in_m2>{},
                              Number<0>{},
                              origin[Number<6>{}] + Number<in_m4>{},
                              Number<0>{});
        };

        constexpr index_t pair_count = slice_size / 2;
        static_assert(slice_size % 2 != 1, "PackedCast does not support odd number of elements");

        static_for<0, pair_count, 1>{}([&](auto pair) {
            constexpr auto first_coord  = to_coord(Number<pair * 2>{});
            constexpr auto second_coord = to_coord(Number<pair * 2 + 1>{});

            constexpr auto first_offset  = desc.CalculateOffset(first_coord);
            constexpr auto second_offset = desc.CalculateOffset(second_coord);

            // The second element is read by value; the first one receives the packed result.
            const float partner = src_buf[Number<second_offset>{}];
            float& slot         = src_buf(Number<first_offset>{});
            static_cast_float_to_bhalf_packed_v2(slot, partner);
        });
    }
};
/// In-place packed conversion of a thread-local slice, driven by a SpaceFillingCurve.
///
/// The slice <CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, 1, 1, M2, 1, M4, 1>
/// is visited one scalar per access in dimension order <0, ..., 7>. Accesses (2k, 2k + 1) are
/// handed together to static_cast_float_to_bhalf_packed_v2: the element of access 2k is
/// overwritten with the packed result and the element of access 2k + 1 is only read.
template <ck::index_t M2,
          ck::index_t M4,
          ck::index_t CShuffleMXdlPerWavePerShuffle,
          ck::index_t CShuffleNXdlPerWavePerShuffle>
struct PackedCastV2
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    /// Packs the slice that starts at the given origin inside @p src_buf.
    ///
    /// The descriptor and the slice origin are passed only for their types; both must be known
    /// at compile time. @p src_buf must be a static buffer and is modified in place.
    template <typename SrcDesc, typename SrcSliceOriginIdx, typename SrcBuffer>
    __device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc need to known at compile-time");
        static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
                      "wrong! SrcSliceOrigin need to known at compile-time");
        static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");

        constexpr auto desc   = remove_cvref_t<SrcDesc>{};
        constexpr auto origin = to_multi_index(SrcSliceOriginIdx{});

        using SliceLengths       = Sequence<CShuffleMXdlPerWavePerShuffle,
                                      CShuffleNXdlPerWavePerShuffle,
                                      1,
                                      1,
                                      M2,
                                      1,
                                      M4,
                                      1>;
        using DimAccessOrder     = Sequence<0, 1, 2, 3, 4, 5, 6, 7>;
        using DstScalarPerAccess = Sequence<1, 1, 1, 1, 1, 1, 1, 1>;
        using Curve = SpaceFillingCurve<SliceLengths, DimAccessOrder, DstScalarPerAccess>;

        static_assert(Curve::ScalarPerVector == 1,
                      "wrong! SpaceFillingCurve::ScalarPerVector must be 1 for PackedCastV2");

        constexpr index_t access_count = Curve::GetNumOfAccess();
        constexpr index_t pair_count   = access_count / 2;
        static_assert(access_count % 2 != 1,
                      "PackedCastV2 does not support odd number of elements");

        static_for<0, pair_count, 1>{}([&](auto i_pair) {
            // 1-d positions of the two accesses forming this pair.
            constexpr auto first_1d  = I2 * i_pair;
            constexpr auto second_1d = I2 * i_pair + I1;

            constexpr auto first_md  = Curve::GetIndex(first_1d);
            constexpr auto second_md = Curve::GetIndex(second_1d);

            constexpr index_t first_offset  = desc.CalculateOffset(origin + first_md);
            constexpr index_t second_offset = desc.CalculateOffset(origin + second_md);

            // The second element is read by value; the first one receives the packed result.
            const float partner = src_buf[Number<second_offset>{}];
            float& slot         = src_buf(Number<first_offset>{});
            static_cast_float_to_bhalf_packed_v2(slot, partner);
        });
    }
};
}

View File

@@ -13,6 +13,200 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
// Threadwise slice transfer that writes one destination scalar per access.
// Run() does not read every source element directly: for access i it loads the float stored at
// the even access (i - i % 2), reinterprets those 32 bits as a bhalf2_t and forwards lane
// (i % 2) to the element-wise operation. The source buffer is therefore expected to hold packed
// bf16 pairs in its even-access slots.
// NOTE(review): presumably the packing is done beforehand by a separate packed-cast step
// (e.g. PackedCastV2) - confirm against the call sites.
// Assume:
// 1. src:
// 1. SrcDesc is known at compile-time
// 2. SrcBuffer is StaticBuffer
// 3. SrcSliceOriginIdx is known at compile-time
// 2. dst:
// 1. DstDesc is not known at compile-time
// 2. DstBuffer is DynamicBuffer
// 3. DstSliceOriginIdx is not known at compile time
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
InMemoryDataOperationEnum DstInMemOp,
index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun,
// PackedInput is not referenced anywhere inside this struct.
bool PackedInput = false,
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r3_pass_through
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t nDim = SliceLengths::Size();
// True when SrcData is float and DstData is ck::bhalf_t; not referenced inside this struct.
static constexpr bool float_input_and_bf16_output_ =
std::is_same_v<SrcData, float> && std::is_same_v<DstData, ck::bhalf_t>;
using Index = MultiIndex<nDim>;
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
// Builds the destination coordinate from the slice origin and stores a copy of the
// element-wise operation. The slice length along DstVectorDim must be divisible by
// DstScalarPerVector.
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_pass_through(const DstDesc& dst_desc,
const Index& dst_slice_origin_idx,
const ElementwiseOperation& element_op)
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
element_op_{element_op}
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
"wrong! Not divisible");
}
// Re-seats the destination coordinate at a new slice origin.
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
// Copies one slice from src_buf to dst_buf, one scalar per access.
// The source descriptor and source slice origin are used only through their types and must be
// known at compile time. The destination coordinate is advanced after every access except the
// last one, and moved back to the slice origin afterwards when DstResetCoordinateAfterRun is set.
template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
// TODO: Use SpaceFillingCurve::ScalarsPerAccess instead of DstScalarPerVector?
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
// The even/odd pairing below is only valid for an 8-d identity access order with exactly
// one scalar per access, so these restrictions are enforced at compile time.
static_assert(std::is_same_v<DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5, 6, 7>>,
"wrong! DimAccessOrder must be the identity sequence <0, 1, 2, 3, 4, 5, 6, 7>");
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
static_assert(1 == DstScalarPerVector, "wrong!1 != DstScalarPerVector");
static_for<0, num_access, 1>{}([&](auto idx_1d)
{
// Odd accesses are mapped onto the preceding even access, because the even slot holds a
// packed bf16x2 value: its first lane is the bf16 value for the even access and its second
// lane is the bf16 value for the odd access that follows it.
// The source slots of the odd accesses themselves are never read.
constexpr auto pair_index = idx_1d % I2;
constexpr auto idx_src_1d = idx_1d - pair_index;
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_src_1d);
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md);
// Reinterpret the 32 bits of the source float as two bf16 lanes.
union
{
float src_float;
bhalf2_t src_bf16x2;
} packed_value;
packed_value.src_float = src_buf[Number<src_offset>{}];
DstData v;
// Apply the element-wise operation to the lane selected by the access parity.
element_op_(v, packed_value.src_bf16x2[pair_index.value]);
dst_vector.template AsType<DstData>()(I0) = v;
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
// Advance the destination coordinate to the next access, except after the last one.
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
}
}
// Returns the index step that leads from the last access of the space-filling curve back to
// its first access, or a default-constructed index when the curve has no accesses.
__device__ static constexpr auto GetDstCoordinateResetStep()
{
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_step;
}
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}
private:
// Current destination coordinate; advanced by Run() and MoveDstSliceWindow().
DstCoord dst_coord_;
// Copy of the element-wise operation applied to every transferred value.
const ElementwiseOperation element_op_;
}; // struct ThreadwiseTensorSliceTransfer_v1r3_pass_through
// Assume:
// 1. src:
@@ -292,7 +486,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
// TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector?
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
@@ -305,12 +499,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
// TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve?
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
DstData v;
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
dst_vector.template AsType<DstData>()(i) = v;
});
@@ -379,6 +574,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}
private:
DstCoord dst_coord_;
const ElementwiseOperation element_op_;

View File

@@ -91,6 +91,25 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(fl
#endif
}
/**
 * @brief Converts two floats to bhalf_t with bf16_convert_rtn, packs them into a bhalf2_t and
 *        writes the 32 packed bits back into the first argument.
 *
 * @param x First float value; converted into element 0 of the pair. On return it holds the bit
 *          pattern of the packed bhalf2_t and no longer represents a float value.
 * @param y Second float value; converted into element 1 of the pair.
 */
inline __host__ __device__ void static_cast_float_to_bhalf_packed_v2(float& x, float y)
{
    const bhalf_t first  = bf16_convert_rtn<bhalf_t>(x);
    const bhalf_t second = bf16_convert_rtn<bhalf_t>(y);

    // View the packed pair through a float-sized slot so it can be stored back into x.
    union
    {
        float fp32;
        bhalf2_t bf16x2;
    } bits;

    bits.bf16x2 = {first, second};
    x           = bits.fp32;
}
/**
* @brief Converts two floats into a vector of 2 16-bit bfloat types (bhalf2_t) using
* rounding to nearest/even (RNE).

View File

@@ -15,45 +15,6 @@
using ck::bhalf_t;
using ck::type_convert;
// Converting a float quiet NaN must yield the bhalf_t bit pattern 0x7FC0.
TEST(BHALF_T, Nan)
{
    const uint16_t expected_bits = 0x7FC0;
    const bhalf_t converted      = type_convert<bhalf_t>(ck::NumericLimits<float>::QuietNaN());
    EXPECT_EQ(ck::bit_cast<bhalf_t>(expected_bits), converted);
}
// Converting float infinity must yield the bhalf_t bit pattern 0x7F80.
TEST(BHALF_T, Inf)
{
    const uint16_t expected_bits = 0x7F80;
    const bhalf_t converted      = type_convert<bhalf_t>(ck::NumericLimits<float>::Infinity());
    EXPECT_EQ(ck::bit_cast<bhalf_t>(expected_bits), converted);
}
// A float -> bhalf_t -> float round trip of the bit pattern 0x81FFFFFF must stay within 2^-7.
TEST(BHALF_T, MantisaOverflow)
{
    const float input      = ck::bit_cast<float>(uint32_t{0x81FFFFFF});
    const float round_trip = type_convert<float>(type_convert<bhalf_t>(input));
    const float abs_tol    = 1.0f / 128.0f; // 2^-7
    ASSERT_NEAR(input, round_trip, abs_tol);
}
// The float with bit pattern 0xFF800000 must survive a float -> bhalf_t -> float round trip.
TEST(BHALF_T, ExpOverflow)
{
    const float input      = ck::bit_cast<float>(uint32_t{0xFF800000});
    const float round_trip = type_convert<float>(type_convert<bhalf_t>(input));
    ASSERT_EQ(round_trip, input);
}
// The all-ones bit pattern is a NaN and must still be a NaN after a bhalf_t round trip.
TEST(BHALF_T, MantisaExpOverflow)
{
    const float input = ck::bit_cast<float>(uint32_t{0xFFFFFFFF});
    ASSERT_TRUE(std::isnan(input));

    const float round_trip = type_convert<float>(type_convert<bhalf_t>(input));
    ASSERT_TRUE(std::isnan(round_trip));
}
__global__ void cast_roundtrip(const float2 input, float2* output)
{
const ck::bhalf2_t bhalf2_val = ck::bf16x2_convert_rne<ck::bhalf2_t, float>(input.x, input.y);
@@ -74,10 +35,26 @@ __global__ void cast(const float input, float* output)
*output = type_convert<float>(bhalf_val);
}
// Kernel: packs (x1, x2) with ck::static_cast_float_to_bhalf_packed_v2 and stores the
// resulting bhalf2_t to *output.
__global__ void packed_cast_in_place(const float x1, const float x2, ck::bhalf2_t* output)
{
    // The helper overwrites its first argument, so work on a local copy of x1.
    float packed_bits = x1;
    ck::static_cast_float_to_bhalf_packed_v2(packed_bits, x2);

    // Reinterpret the float-sized result as the packed bf16 pair.
    union
    {
        float src;
        ck::bhalf2_t dst;
    } view;
    view.src = packed_bits;
    *output  = view.dst;
}
enum struct CastMode : int
{
Standard = 0,
Packed = 1
Packed = 1,
PackedInPlace = 2
};
template <CastMode PackedCast, int NumElements>
@@ -119,6 +96,60 @@ __global__ void test_performance_kernel(float* input, ck::bhalf_t* output)
}
}
// Micro-benchmark kernel comparing the in-place packed float -> bf16 cast with the scalar one.
//
// The kernel copies NumElements floats from `input` into a local array, converts every (i, j)
// pair of them and writes the two results to adjacent slots of a local bf16 array, then copies
// that array to `output`.
//
// PackedCast == CastMode::PackedInPlace uses ck::static_cast_float_to_bhalf_packed_v2; any other
// mode uses ck::bf16_convert_rtn_base for each value separately.
//
// `input` and `output` must each hold at least NumElements elements.
template <CastMode PackedCast, int NumElements>
__global__ void test_in_place_performance_kernel(float* input, ck::bhalf_t* output)
{
    // Each iteration writes slots `index` and `index + 1`, so two slots are the minimum.
    static_assert(NumElements >= 2, "test_in_place_performance_kernel needs at least 2 elements");

    ck::bhalf_t buffer_bf16[NumElements];
    float buffer_float[NumElements];

    // Initialize input data
    for(int i = 0; i < NumElements; i++)
    {
        buffer_float[i] = input[i];
    }

    if constexpr(PackedCast == CastMode::PackedInPlace)
    {
        union
        {
            float src;
            ck::bhalf2_t dst;
        } workspace;

        for(int i = 0; i < NumElements; i++)
        {
            for(int j = 0; j < NumElements; j++)
            {
                // Clamp so that index + 1 stays inside the buffer.
                int index = (i + j) % NumElements;
                index     = index < NumElements - 1 ? index : NumElements - 2;

                workspace.src = buffer_float[i];
                ck::static_cast_float_to_bhalf_packed_v2(workspace.src, buffer_float[j]);

                // Store the two lanes separately: `index` may be odd, and buffer_bf16 is only
                // aligned for bhalf_t, so writing through a bhalf2_t pointer would be a
                // misaligned access.
                buffer_bf16[index]     = workspace.dst[0];
                buffer_bf16[index + 1] = workspace.dst[1];
            }
        }
    }
    else
    {
        for(int i = 0; i < NumElements; i++)
        {
            for(int j = 0; j < NumElements; j++)
            {
                // Clamp so that index + 1 stays inside the buffer.
                int index = (i + j) % NumElements;
                index     = index < NumElements - 1 ? index : NumElements - 2;

                buffer_bf16[index]     = ck::bf16_convert_rtn_base(buffer_float[i]);
                buffer_bf16[index + 1] = ck::bf16_convert_rtn_base(buffer_float[j]);
            }
        }
    }

    // Copy results back to output
    for(int i = 0; i < NumElements; i++)
    {
        output[i] = buffer_bf16[i];
    }
}
template <int NumElements>
void run_performance_test()
{
@@ -165,6 +196,91 @@ void run_performance_test()
ASSERT_LT(packed_time, baseline_time);
}
template <int NumElements>
void run_in_place_performance_test()
{
float* input_dev;
ck::bhalf_t* output_dev;
std::vector<ck::bhalf_t> output_host(NumElements);
hip_check_error(hipMalloc(&input_dev, sizeof(float) * NumElements));
hip_check_error(hipMalloc(&output_dev, sizeof(ck::bhalf_t) * NumElements));
// Initialize input data on the device
std::vector<float> input_host(NumElements);
for (int i = 0; i < NumElements; i++)
{
input_host[i] = 3.14f * static_cast<float>(i) - 1.7f;
}
hip_check_error(hipMemcpy(input_dev, input_host.data(), sizeof(float) * NumElements, hipMemcpyHostToDevice));
StreamConfig stream_config;
stream_config.time_kernel_ = true;
auto baseline_kernel = test_in_place_performance_kernel<CastMode::Standard, NumElements>;
auto packed_kernel = test_in_place_performance_kernel<CastMode::PackedInPlace, NumElements>;
constexpr dim3 grid_size(1);
constexpr dim3 block_size(1);
constexpr size_t shared_mem_size = 0;
const float baseline_time = launch_and_time_kernel(stream_config, baseline_kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev);
hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost));
const float packed_time = launch_and_time_kernel(stream_config, packed_kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev);
hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost));
// Cleanup
hip_check_error(hipFree(input_dev));
hip_check_error(hipFree(output_dev));
std::cout << "Packed cast time ( " << NumElements << " elements): " << packed_time << " ms" << std::endl;
std::cout << "Baseline cast time ( " << NumElements << " elements): " << baseline_time << " ms" << std::endl;
// Check if packed cast is faster than baseline
ASSERT_LT(packed_time, baseline_time);
}
// Converting a float quiet NaN must yield the bhalf_t bit pattern 0x7FC0.
TEST(BHALF_T, Nan)
{
    const uint16_t expected_bits = 0x7FC0;
    const bhalf_t converted      = type_convert<bhalf_t>(ck::NumericLimits<float>::QuietNaN());
    EXPECT_EQ(ck::bit_cast<bhalf_t>(expected_bits), converted);
}
// Converting float infinity must yield the bhalf_t bit pattern 0x7F80.
TEST(BHALF_T, Inf)
{
    const uint16_t expected_bits = 0x7F80;
    const bhalf_t converted      = type_convert<bhalf_t>(ck::NumericLimits<float>::Infinity());
    EXPECT_EQ(ck::bit_cast<bhalf_t>(expected_bits), converted);
}
// A float -> bhalf_t -> float round trip of the bit pattern 0x81FFFFFF must stay within 2^-7.
TEST(BHALF_T, MantisaOverflow)
{
    const float input      = ck::bit_cast<float>(uint32_t{0x81FFFFFF});
    const float round_trip = type_convert<float>(type_convert<bhalf_t>(input));
    const float abs_tol    = 1.0f / 128.0f; // 2^-7
    ASSERT_NEAR(input, round_trip, abs_tol);
}
// The float with bit pattern 0xFF800000 must survive a float -> bhalf_t -> float round trip.
TEST(BHALF_T, ExpOverflow)
{
    const float input      = ck::bit_cast<float>(uint32_t{0xFF800000});
    const float round_trip = type_convert<float>(type_convert<bhalf_t>(input));
    ASSERT_EQ(round_trip, input);
}
// The all-ones bit pattern is a NaN and must still be a NaN after a bhalf_t round trip.
TEST(BHALF_T, MantisaExpOverflow)
{
    const float input = ck::bit_cast<float>(uint32_t{0xFFFFFFFF});
    ASSERT_TRUE(std::isnan(input));

    const float round_trip = type_convert<float>(type_convert<bhalf_t>(input));
    ASSERT_TRUE(std::isnan(round_trip));
}
TEST(BHALF_T, Performance)
{
if (ck::get_device_name() == "gfx950")
@@ -365,3 +481,41 @@ TEST(BHALF_T, CastOnDevice)
ASSERT_NEAR(float_val_after_cast_host, -float_vals[idx], abs_tol);
}
}
// Runs the in-place packed cast on the device for one pair of floats and checks that both
// bf16 lanes convert back to the original values within 2^-7.
TEST(BHALF_T, PackedCast_in_place)
{
    const float v1 = 3.14f;
    const float v2 = -1.618f;

    ck::bhalf2_t* bhalf2_val_d;
    hip_check_error(hipMalloc(&bhalf2_val_d, sizeof(ck::bhalf2_t)));

    packed_cast_in_place<<<1, 1>>>(v1, v2, bhalf2_val_d);
    hip_check_error(hipGetLastError());

    ck::bhalf2_t bhalf2_val_h;
    hip_check_error(
        hipMemcpy(&bhalf2_val_h, bhalf2_val_d, sizeof(ck::bhalf2_t), hipMemcpyDeviceToHost));

    // Release the device buffer before asserting: a failing ASSERT_* returns from the test
    // body immediately and would otherwise skip the hipFree call.
    hip_check_error(hipFree(bhalf2_val_d));

    // Convert back to floats
    const float fval1 = type_convert<float>(bhalf2_val_h[0]);
    const float fval2 = type_convert<float>(bhalf2_val_h[1]);

    const float abs_tol = 1.0f / 128.0f; // 2^-7
    ASSERT_NEAR(fval1, v1, abs_tol);
    ASSERT_NEAR(fval2, v2, abs_tol);
}
// Runs the in-place packed cast timing comparison for 32, 64 and 128 elements.
// Skipped unless the device name is "gfx950".
TEST(BHALF_T, PackedCast_in_place_performance)
{
    if(ck::get_device_name() != "gfx950")
    {
        GTEST_SKIP() << "Packed cast performance test requires gfx950.";
    }

    run_in_place_performance_test<32>();
    run_in_place_performance_test<64>();
    run_in_place_performance_test<128>();
}