Fix packed cast tensor slice transfer.

This commit is contained in:
Ville Pietilä
2025-08-22 10:26:59 +00:00
parent de93a48b04
commit 4e7f9f7908
3 changed files with 60 additions and 56 deletions

View File

@@ -1646,7 +1646,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -1763,20 +1763,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// if constexpr (is_gfx650_and_bf16_output())
// {
// auto c_thread_packed_cast = PackedCastV2<
// M2,
// M4,
// CShuffleMXdlPerWavePerShuffle,
// CShuffleNXdlPerWavePerShuffle
// >{};
// c_thread_packed_cast.Run(
// c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
// sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
// c_thread_buf // source buffer
// );
// }
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
@@ -2093,7 +2093,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -2204,20 +2204,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// if constexpr (is_gfx650_and_bf16_output())
// {
// auto c_thread_packed_cast = PackedCastV2<
// M2,
// M4,
// CShuffleMXdlPerWavePerShuffle,
// CShuffleNXdlPerWavePerShuffle
// >{};
// c_thread_packed_cast.Run(
// c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
// sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
// c_thread_buf // source buffer
// );
// }
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,

View File

@@ -235,12 +235,8 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t SrcScalarPerVector = 1;
static constexpr index_t BufScalarPerVector = 2;
// We cannot use SnakedCurve for packed cast, otherwise the vectorized store will not work correctly.
static constexpr bool SnakedCurve = false;
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
@@ -295,17 +291,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
// SFC to access the source buffer and destination buffers.
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>,
SnakedCurve>;
typename vector_type_maker<DstData, BufScalarPerVector>::type buf_vector;
using buf_vector_t = typename vector_type_maker<DstData, BufScalarPerVector>::type::type;
remove_cv_t<decltype(dst_scalar_per_access)>>;
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
@@ -329,25 +320,35 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
const float val_0 = src_buf[Number<src_offset_0>{}];
const float val_1 = src_buf[Number<src_offset_1>{}];
buf_vector.template AsType<buf_vector_t>()(I0) = bf16x2_convert_rne<ck::bhalf2_t, float>(val_0, val_1);
const ck::bhalf2_t packed_value= bf16x2_convert_rne<ck::bhalf2_t, float>(val_0, val_1);
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from buf_vector into dst_buf
dst_buf.template Update<DstInMemOp, buf_vector_t>(
// Store the first of the packed values
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
dst_coord_.GetOffset(),
is_dst_valid,
buf_vector.template AsType<buf_vector_t>()[Number<0>{}]);
packed_value[0]);
// Move to next dst coordinate
constexpr auto forward_step_0 = SpaceFillingCurve::GetForwardStep(idx_1d_0);
move_tensor_coordinate(
dst_desc, dst_coord_,
make_tensor_coordinate_step(dst_desc, forward_step_0));
// Store the second of the packed values
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
dst_coord_.GetOffset(),
is_dst_valid,
packed_value[1]);
// Move to next dst coordinate, unless this was the last pair
if constexpr(i_pair.value != num_pairs - 1)
{
// Move two steps forward in the space-filling curve.
// This works only if we don't use the snaked access pattern.
constexpr auto forward_step_md = SpaceFillingCurve::GetStepBetween(idx_1d_0, Number<idx_1d_1 + 1>{});
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d_1);
move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step_md));
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});
@@ -369,8 +370,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>,
SnakedCurve>;
remove_cv_t<decltype(dst_scalar_per_access)>>;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)

View File

@@ -106,6 +106,9 @@ inline __host__ __device__ void static_cast_float_to_bhalf_packed_v2(float& x, f
bhalf2_t bf16x2;
} converter;
// typedef __attribute__((__vector_size__(4))) __bf16 llvm_bf16x2_t;
// typedef __attribute__((__vector_size__(8))) float llvm_fp32x2_t;
// converter.bf16x2 = __builtin_convertvector(llvm_fp32x2_t{x, y}, llvm_bf16x2_t);
converter.bf16x2 = {bf16_convert_rtn<bhalf_t>(x), bf16_convert_rtn<bhalf_t>(y)};
x = converter.fp32;
}
@@ -119,10 +122,11 @@ inline __host__ __device__ void static_cast_float_to_bhalf_packed_v2(float& x, f
* @return Converted vector of 2 bhalf_t.
*/
template<>
inline __host__ __device__ constexpr bhalf2_t bf16x2_convert_rne<bhalf2_t, float>(float x, float y)
inline __host__ __device__ bhalf2_t bf16x2_convert_rne<bhalf2_t, float>(float x, float y)
{
// for gfx950, the compiler will use device instruction v_cvt_pk_bf16_f32 to execute packed cast.
return {bf16_convert_rtn<bhalf_t>(x), bf16_convert_rtn<bhalf_t>(y)};
typedef __attribute__((__vector_size__(4))) __bf16 llvm_bf16x2_t;
typedef __attribute__((__vector_size__(8))) float llvm_fp32x2_t;
return __builtin_convertvector(llvm_fp32x2_t{x, y}, llvm_bf16x2_t);
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed