mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 05:37:34 +00:00
Fix packed cast tensor slice transfer.
This commit is contained in:
@@ -1646,7 +1646,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
using ThreadwiseTransfer = std::conditional_t<
|
||||
is_gfx650_and_bf16_output(),
|
||||
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
|
||||
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
@@ -1763,20 +1763,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx650_and_bf16_output())
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCastV2<
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
// if constexpr (is_gfx650_and_bf16_output())
|
||||
// {
|
||||
// auto c_thread_packed_cast = PackedCastV2<
|
||||
// M2,
|
||||
// M4,
|
||||
// CShuffleMXdlPerWavePerShuffle,
|
||||
// CShuffleNXdlPerWavePerShuffle
|
||||
// >{};
|
||||
// c_thread_packed_cast.Run(
|
||||
// c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
|
||||
// sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
// c_thread_buf // source buffer
|
||||
// );
|
||||
// }
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
@@ -2093,7 +2093,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
using ThreadwiseTransfer = std::conditional_t<
|
||||
is_gfx650_and_bf16_output(),
|
||||
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
|
||||
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
@@ -2204,20 +2204,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx650_and_bf16_output())
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCastV2<
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
// if constexpr (is_gfx650_and_bf16_output())
|
||||
// {
|
||||
// auto c_thread_packed_cast = PackedCastV2<
|
||||
// M2,
|
||||
// M4,
|
||||
// CShuffleMXdlPerWavePerShuffle,
|
||||
// CShuffleNXdlPerWavePerShuffle
|
||||
// >{};
|
||||
// c_thread_packed_cast.Run(
|
||||
// c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
|
||||
// sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
// c_thread_buf // source buffer
|
||||
// );
|
||||
// }
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
|
||||
@@ -235,12 +235,8 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr index_t SrcScalarPerVector = 1;
|
||||
static constexpr index_t BufScalarPerVector = 2;
|
||||
|
||||
// We cannot use SnakedCurve for packed cast, otherwise the vectorized store will not work correctly.
|
||||
static constexpr bool SnakedCurve = false;
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
@@ -295,17 +291,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
// SFC to access the source buffer and destination buffers.
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(src_scalar_per_access)>,
|
||||
SnakedCurve>;
|
||||
|
||||
typename vector_type_maker<DstData, BufScalarPerVector>::type buf_vector;
|
||||
using buf_vector_t = typename vector_type_maker<DstData, BufScalarPerVector>::type::type;
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>>;
|
||||
|
||||
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
|
||||
|
||||
@@ -329,25 +320,35 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
const float val_0 = src_buf[Number<src_offset_0>{}];
|
||||
const float val_1 = src_buf[Number<src_offset_1>{}];
|
||||
|
||||
buf_vector.template AsType<buf_vector_t>()(I0) = bf16x2_convert_rne<ck::bhalf2_t, float>(val_0, val_1);
|
||||
const ck::bhalf2_t packed_value= bf16x2_convert_rne<ck::bhalf2_t, float>(val_0, val_1);
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
// copy data from buf_vector into dst_buf
|
||||
dst_buf.template Update<DstInMemOp, buf_vector_t>(
|
||||
// Store the first of the packed values
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
buf_vector.template AsType<buf_vector_t>()[Number<0>{}]);
|
||||
packed_value[0]);
|
||||
|
||||
// Move to next dst coordinate
|
||||
constexpr auto forward_step_0 = SpaceFillingCurve::GetForwardStep(idx_1d_0);
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_,
|
||||
make_tensor_coordinate_step(dst_desc, forward_step_0));
|
||||
|
||||
// Store the second of the packed values
|
||||
dst_buf.template Update<DstInMemOp, ck::bhalf_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
packed_value[1]);
|
||||
|
||||
// Move to next dst coordinate, unless this was the last pair
|
||||
if constexpr(i_pair.value != num_pairs - 1)
|
||||
{
|
||||
// Move two steps forward in the space-filling curve.
|
||||
// This works only if we don't use the snaked access pattern.
|
||||
constexpr auto forward_step_md = SpaceFillingCurve::GetStepBetween(idx_1d_0, Number<idx_1d_1 + 1>{});
|
||||
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d_1);
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step_md));
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
|
||||
}
|
||||
});
|
||||
|
||||
@@ -369,8 +370,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3_packed_cast
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>,
|
||||
SnakedCurve>;
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>>;
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
if constexpr(num_access == 0)
|
||||
|
||||
@@ -106,6 +106,9 @@ inline __host__ __device__ void static_cast_float_to_bhalf_packed_v2(float& x, f
|
||||
bhalf2_t bf16x2;
|
||||
} converter;
|
||||
|
||||
// typedef __attribute__((__vector_size__(4))) __bf16 llvm_bf16x2_t;
|
||||
// typedef __attribute__((__vector_size__(8))) float llvm_fp32x2_t;
|
||||
// converter.bf16x2 = __builtin_convertvector(llvm_fp32x2_t{x, y}, llvm_bf16x2_t);
|
||||
converter.bf16x2 = {bf16_convert_rtn<bhalf_t>(x), bf16_convert_rtn<bhalf_t>(y)};
|
||||
x = converter.fp32;
|
||||
}
|
||||
@@ -119,10 +122,11 @@ inline __host__ __device__ void static_cast_float_to_bhalf_packed_v2(float& x, f
|
||||
* @return Converted vector of 2 bhalf_t.
|
||||
*/
|
||||
template<>
|
||||
inline __host__ __device__ constexpr bhalf2_t bf16x2_convert_rne<bhalf2_t, float>(float x, float y)
|
||||
inline __host__ __device__ bhalf2_t bf16x2_convert_rne<bhalf2_t, float>(float x, float y)
|
||||
{
|
||||
// for gfx950, the compiler will use device instruction v_cvt_pk_bf16_f32 to execute packed cast.
|
||||
return {bf16_convert_rtn<bhalf_t>(x), bf16_convert_rtn<bhalf_t>(y)};
|
||||
typedef __attribute__((__vector_size__(4))) __bf16 llvm_bf16x2_t;
|
||||
typedef __attribute__((__vector_size__(8))) float llvm_fp32x2_t;
|
||||
return __builtin_convertvector(llvm_fp32x2_t{x, y}, llvm_bf16x2_t);
|
||||
}
|
||||
|
||||
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
|
||||
|
||||
Reference in New Issue
Block a user