WMMA gemm_add_relu_add_layernorm (#2989)

* Summary:

 - Refactor epilogue (with CShuffle) to support fused operations:
    - EpilogueCShuffleBase holds common parts
    - EpilogueCShuffle: runs CShuffle and writes out
    - EpilogueWelfordCShuffle: holds Welford-specific arguments; runs CShuffle, writes out, performs the Welford first part, and writes out the Welford results

 - Extend thread transfer v7r3:
    - Support for intermediate data type different from src and dst type
    - New functionality to write to dst buffer and keep data (to be able to use them for additional operations)

* Address review comments
This commit is contained in:
Enrico Degregori
2025-10-31 19:19:26 +01:00
committed by GitHub
parent e9596228ff
commit 4ebc48a3cd
23 changed files with 2678 additions and 332 deletions

View File

@@ -43,7 +43,8 @@ template <typename SrcDatas,
index_t DstScalarPerVector,
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
index_t NumThreadScratch = 1>
index_t NumThreadScratch = 1,
typename InterDatas = DstDatas>
struct ThreadwiseTensorSliceTransfer_v7r3
{
static constexpr auto I0 = Number<0>{};
@@ -153,7 +154,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
// loop over space-filling curve
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
auto elm_vectors = generate_vectors<InterDatas, SrcScalarPerVector>();
bool oob_val = true;
@@ -226,9 +227,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto iDst) -> auto& {
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
using InterData = remove_cvref_t<tuple_element_t<iDst.value, InterDatas>>;
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
using elem_op_vec_t =
typename vector_type<InterData, elem_op_vec_len>::type;
return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
},
@@ -297,17 +299,17 @@ struct ThreadwiseTensorSliceTransfer_v7r3
__device__ void
TransposeFromElmToDst(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
using InterData = remove_cvref_t<decltype(InterDatas{}[I0])>;
using ElmThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
InterData,
SrcScalarPerVector,
decltype(GetSrcThreadScratchDescriptor()),
true>;
using DstThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
InterData,
DstScalarPerVector,
decltype(GetDstThreadScratchDescriptor()),
true>;
@@ -319,11 +321,11 @@ struct ThreadwiseTensorSliceTransfer_v7r3
bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_[thread_scratch_id]);
if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<DstData>>::value &&
((is_same<half_t, remove_cvref_t<InterData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
(is_same<f8_t, remove_cvref_t<InterData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
(is_same<int8_t, remove_cvref_t<InterData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
// each transpose does
@@ -356,8 +358,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3
constexpr auto data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
using src_vector_t = vector_type_maker_t<InterData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<InterData, DstScalarPerVector>;
// get DstScalarPerVector # of read-only references to src vectors from
// src_thread_scratch_
@@ -380,7 +382,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
Number<num_dst_vector>{});
// do data transpose
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
transpose_vectors<InterData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
});
}
@@ -393,6 +395,104 @@ struct ThreadwiseTensorSliceTransfer_v7r3
dst_vectors_tuple_(thread_scratch_id) = bit_cast<DstVectorTuple>(dst_thread_scratch_.data_);
}
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
// DstVgprDescs: Tuple<const DstVgprDesc0&, const DstVgprDesc1&, ...>
// DstVgprBuffers: Tuple<DstVgprBuffer0&, DstVgprBuffer1&, ...>
//
// Same functionality as RunWrite, but additionally keeps a copy of every
// written element (in the intermediate type InterData) in dst_vgpr_buf so
// that follow-up fused operations can reuse the data without re-reading it.
// Only the single-destination case is enabled (see the enable_if constraint).
// NOTE(review): out-of-bounds lanes are stored as QuietNaN in the VGPR
// buffer — confirm downstream consumers expect NaN padding.
template <typename DstBuffers,
          typename DstVgprDescs,
          typename DstVgprBuffers,
          index_t ThreadScratchId = 0,
          enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
__device__ void
RunWriteAndStoreVgpr(const DstDescs& dst_descs,
                     DstBuffers dst_bufs,
                     const DstVgprDescs&,
                     DstVgprBuffers dst_vgpr_buf,
                     Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
    // Same functionality as RunWrite, but additionally store the internal
    // VGPR data into dst_vgpr_buf.
    OOBCheck(thread_scratch_id);
    TransposeFromElmToDst(thread_scratch_id);

    // Vgpr buffer origin is set internally to 0.
    constexpr auto dst_slice_origin_idx =
        generate_tuple([&](auto) { return I0; }, Number<nDim>{});

    constexpr auto dst_scalar_step_in_vector =
        generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

    // Loop over the destination space-filling curve.
    static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
        auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];

        static_for<0, nDst, 1>{}([&](auto i) {
            // Convert each scalar from the intermediate type to the
            // destination type, then copy the vector into dst_bufs.
            using DstData   = remove_cvref_t<decltype(DstDatas{}[i])>;
            using InterData = remove_cvref_t<decltype(InterDatas{}[i])>;
            typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
            using dst_vector_t =
                typename vector_type_maker<DstData, DstScalarPerVector>::type::type;

            static_for<0, DstScalarPerVector, 1>{}([&](auto j) {
                dst_vector.template AsType<DstData>()(j) =
                    type_convert<DstData>(dst_vectors[i].template AsType<InterData>()[j]);
            });

            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
                                                                            dst_coords_[i]);

            constexpr InMemoryDataOperationEnum DstInMemOp =
                static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));

            dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
                dst_coords_[i].GetOffset(),
                is_dst_valid,
                dst_vector.template AsType<dst_vector_t>()[I0]);

            // Store the (unconverted) intermediate values in the VGPR buffer.
            using DstVgprDesc = remove_cvref_t<decltype(DstVgprDescs{}.At(i))>;
            // Fixed diagnostic: the assertion checks DstVgprDesc (the original
            // message named DstDesc and was ungrammatical).
            static_assert(DstVgprDesc::IsKnownAtCompileTime(),
                          "wrong! DstVgprDesc needs to be known at compile-time");
            constexpr auto dst_vgpr_desc = DstVgprDesc{};
            constexpr auto src_data_idx  = DstSpaceFillingCurve::GetIndex(iAccess);
            static_for<0, DstScalarPerVector, 1>{}([&](auto j) {
                constexpr index_t dst_offset =
                    dst_vgpr_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) +
                                                  src_data_idx + j * dst_scalar_step_in_vector);
                // Invalid (out-of-bounds) elements are tagged with QuietNaN.
                dst_vgpr_buf(I0)(Number<dst_offset>{}) =
                    is_dst_valid ? dst_vectors[i].template AsType<InterData>()[j]
                                 : NumericLimits<InterData>::QuietNaN();
            });
        });

        // Advance the destination coordinate to the next access position.
        if constexpr(iAccess.value != dst_num_access - 1)
        {
            constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);
            static_for<0, nDst, 1>{}([&](auto i) {
                move_tensor_coordinate(dst_descs[i],
                                       dst_coords_(i),
                                       make_tensor_coordinate_step(dst_descs[i], forward_step));
            });
        }
    });

    // Optionally reset the destination coordinates for the next Run* call.
    static_for<0, nDst, 1>{}([&](auto i) {
        if constexpr(DstResetCoordinateAfterRunFlags::At(i))
        {
            const auto dst_reset_step =
                make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep());
            move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
        }
    });
}
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template <typename DstBuffers,
@@ -402,6 +502,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3
DstBuffers dst_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
static_assert(is_same_v<InterDatas, DstDatas>,
"RunWrite doesn't support inter data type different from dst data type");
OOBCheck(thread_scratch_id);
TransposeFromElmToDst(thread_scratch_id);
@@ -630,8 +733,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3
private:
using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
using ElmVectorsType = decltype(generate_vectors<DstDatas, SrcScalarPerVector>());
using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>());
using ElmVectorsType = decltype(generate_vectors<InterDatas, SrcScalarPerVector>());
using DstVectorsType = decltype(generate_vectors<InterDatas, DstScalarPerVector>());
static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();