Add the vectorized option for packed cast.

This commit is contained in:
Ville Pietilä
2025-08-26 12:31:01 +00:00
parent 905cfb6623
commit 2302ea9bc6
5 changed files with 234 additions and 55 deletions

View File

@@ -898,7 +898,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_pass_through<AccDataType,
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -935,8 +936,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
1,
InMemoryDataOperationEnum::Set,
1,
true>
>;
true>>;
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{
@@ -1007,20 +1007,20 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// if constexpr (is_gfx650_and_bf16_output())
// {
// auto c_thread_packed_cast = PackedCastV2<
// M2,
// M4,
// CShuffleMXdlPerWavePerShuffle,
// CShuffleNXdlPerWavePerShuffle
// >{};
// c_thread_packed_cast.Run(
// c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
// sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
// c_thread_buf // source buffer
// );
// }
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
@@ -1308,7 +1308,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_pass_through<AccDataType,
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -1345,8 +1346,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
1,
InMemoryDataOperationEnum::Set,
1,
true>
>;
true>>;
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{
@@ -1417,20 +1417,20 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// if constexpr (is_gfx650_and_bf16_output())
// {
// auto c_thread_packed_cast = PackedCastV2<
// M2,
// M4,
// CShuffleMXdlPerWavePerShuffle,
// CShuffleNXdlPerWavePerShuffle
// >{};
// c_thread_packed_cast.Run(
// c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
// sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
// c_thread_buf // source buffer
// );
// }
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,

View File

@@ -1646,7 +1646,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
ThreadwiseTensorSliceTransfer_v1r3_vectorized<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -2093,7 +2093,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
ThreadwiseTensorSliceTransfer_v1r3_vectorized<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),

View File

@@ -897,7 +897,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
FloatAcc,
FloatC,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
@@ -1002,20 +1002,20 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// make sure it's safe to do ds_write
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle
>{};
c_thread_packed_cast.Run(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, // source desc (TensorDescriptor struct)
make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), // source slice origin
c_thread_buf // source buffer
);
}
// if constexpr (is_gfx650_and_bf16_output())
// {
// auto c_thread_packed_cast = PackedCastV2<
// M2,
// M4,
// CShuffleMRepeatPerShuffle,
// CShuffleNRepeatPerShuffle
// >{};
// c_thread_packed_cast.Run(
// c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, // source desc (TensorDescriptor struct)
// make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), // source slice origin
// c_thread_buf // source buffer
// );
// }
// VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(

View File

@@ -438,6 +438,185 @@ private:
true> thread_scratch_;
};
// Assume:
// 1. src:
// 1. SrcDesc is known at compile-time
// 2. SrcBuffer is StaticBuffer
// 3. SrcSliceOrginIdx is known at compile-time
// 2. dst:
// 1. DstDesc is not known at compile-time
// 2. DstBuffer is DynamicBuffer
// 3. DstSliceOrginIdx is not known at compile time
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
InMemoryDataOperationEnum DstInMemOp,
index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun,
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r3_vectorized
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr bool SnakedAccess = false;
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_vectorized(const DstDesc& dst_desc,
const Index& dst_slice_origin_idx,
const ElementwiseOperation&)
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx))
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
"wrong! Not divisible");
// Assert that elementwise op is pass through.
static_assert(
std::is_same_v<remove_cvref_t<ElementwiseOperation>, ck::tensor_operation::element_wise::PassThrough>,
"wrong! ElementwiseOperation must be PassThrough");
// For now, SrcData must be float and DstData must be ck::bhalf_t
static_assert(std::is_same_v<SrcData, float>,
"wrong! SrcData must be float");
static_assert(std::is_same_v<DstData, ck::bhalf_t>,
"wrong! DstData must be bhalf_t");
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto src_scalar_step_in_vector =
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>,
SnakedAccess>;
static_assert(2 == SpaceFillingCurve::ScalarPerVector, "wrong!2 != SpaceFillingCurve::ScalarPerVector");
ck::bhalf2_t dst_vector;
using dst_vector_t = ck::bhalf2_t;
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, num_access, 1>{}([&](auto idx_1d)
{
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
constexpr auto idx_src_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md);
constexpr auto idx_src_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md + src_scalar_step_in_vector);
const float val_0 = src_buf[Number<idx_src_0>{}];
const float val_1 = src_buf[Number<idx_src_1>{}];
dst_vector = bf16x2_convert_rne<ck::bhalf2_t, float>(val_0, val_1);
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector);
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
}
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>,
SnakedAccess>;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_step;
}
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}
private:
DstCoord dst_coord_;
};
// Assume:
// 1. src:
// 1. SrcDesc is known at compile-time

View File

@@ -106,10 +106,10 @@ inline __host__ __device__ void static_cast_float_to_bhalf_packed_v2(float& x, f
bhalf2_t bf16x2;
} converter;
// typedef __attribute__((__vector_size__(4))) __bf16 llvm_bf16x2_t;
// typedef __attribute__((__vector_size__(8))) float llvm_fp32x2_t;
// converter.bf16x2 = __builtin_convertvector(llvm_fp32x2_t{x, y}, llvm_bf16x2_t);
converter.bf16x2 = {bf16_convert_rtn<bhalf_t>(x), bf16_convert_rtn<bhalf_t>(y)};
typedef __attribute__((__vector_size__(4))) __bf16 llvm_bf16x2_t;
typedef __attribute__((__vector_size__(8))) float llvm_fp32x2_t;
converter.bf16x2 = __builtin_convertvector(llvm_fp32x2_t{x, y}, llvm_bf16x2_t);
//converter.bf16x2 = {bf16_convert_rtn<bhalf_t>(x), bf16_convert_rtn<bhalf_t>(y)};
x = converter.fp32;
}