mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
WMMA gemm_add_relu_add_layernorm (#2989)
* Summary:
- Refactor epilogue (with CShuffle) to support fused operations:
- EpilogueCShuffleBase holds common parts
- EpilogueCShuffle: runs CShuffle and write out
- EpilogueWelfordCShuffle: holds Welford specific arguments, runs CShuffle, write out, Welford first part and Welford write out
- Extend thread transfer v7r3:
- Support for intermediate data type different from src and dst type
- New functionality to write to dst buffer and keep data (to be able to use them for additional operations)
* Address review comments
This commit is contained in:
@@ -43,7 +43,8 @@ template <typename SrcDatas,
|
||||
index_t DstScalarPerVector,
|
||||
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
|
||||
typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
|
||||
index_t NumThreadScratch = 1>
|
||||
index_t NumThreadScratch = 1,
|
||||
typename InterDatas = DstDatas>
|
||||
struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -153,7 +154,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
// loop over space-filling curve
|
||||
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
|
||||
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
|
||||
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
|
||||
auto elm_vectors = generate_vectors<InterDatas, SrcScalarPerVector>();
|
||||
|
||||
bool oob_val = true;
|
||||
|
||||
@@ -226,9 +227,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iDst) -> auto& {
|
||||
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
|
||||
using InterData = remove_cvref_t<tuple_element_t<iDst.value, InterDatas>>;
|
||||
|
||||
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
using elem_op_vec_t =
|
||||
typename vector_type<InterData, elem_op_vec_len>::type;
|
||||
|
||||
return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
|
||||
},
|
||||
@@ -297,17 +299,17 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
__device__ void
|
||||
TransposeFromElmToDst(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
|
||||
using InterData = remove_cvref_t<decltype(InterDatas{}[I0])>;
|
||||
|
||||
using ElmThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
InterData,
|
||||
SrcScalarPerVector,
|
||||
decltype(GetSrcThreadScratchDescriptor()),
|
||||
true>;
|
||||
using DstThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
InterData,
|
||||
DstScalarPerVector,
|
||||
decltype(GetDstThreadScratchDescriptor()),
|
||||
true>;
|
||||
@@ -319,11 +321,11 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_[thread_scratch_id]);
|
||||
|
||||
if constexpr(SrcVectorDim != DstVectorDim &&
|
||||
((is_same<half_t, remove_cvref_t<DstData>>::value &&
|
||||
((is_same<half_t, remove_cvref_t<InterData>>::value &&
|
||||
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
|
||||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
|
||||
(is_same<f8_t, remove_cvref_t<InterData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
|
||||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
|
||||
(is_same<int8_t, remove_cvref_t<InterData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
|
||||
{
|
||||
// each transpose does
|
||||
@@ -356,8 +358,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
constexpr auto data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
|
||||
|
||||
using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
|
||||
using src_vector_t = vector_type_maker_t<InterData, SrcScalarPerVector>;
|
||||
using dst_vector_t = vector_type_maker_t<InterData, DstScalarPerVector>;
|
||||
|
||||
// get DstScalarPerVector # of read-only references to src vectors from
|
||||
// src_thread_scratch_
|
||||
@@ -380,7 +382,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
Number<num_dst_vector>{});
|
||||
|
||||
// do data transpose
|
||||
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
transpose_vectors<InterData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
src_vector_refs, dst_vector_refs);
|
||||
});
|
||||
}
|
||||
@@ -393,6 +395,104 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
dst_vectors_tuple_(thread_scratch_id) = bit_cast<DstVectorTuple>(dst_thread_scratch_.data_);
|
||||
}
|
||||
|
||||
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
|
||||
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
|
||||
// DstVgprDescs: Tuple<const DstVgprDesc0&, const DstVgprDesc1&, ...>
|
||||
// DstVgprBuffers: Tuple<DstVgprBuffer0&, DstVgprBuffer1&, ...>
|
||||
template <typename DstBuffers,
|
||||
typename DstVgprDescs,
|
||||
typename DstVgprBuffers,
|
||||
index_t ThreadScratchId = 0,
|
||||
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
|
||||
__device__ void
|
||||
RunWriteAndStoreVgpr(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
const DstVgprDescs&,
|
||||
DstVgprBuffers dst_vgpr_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// Same functionality of RunWrite but additionally store internal Vgpr in dst_vgpr_buf
|
||||
OOBCheck(thread_scratch_id);
|
||||
TransposeFromElmToDst(thread_scratch_id);
|
||||
|
||||
// Vgpr buffer origin is set internally to 0
|
||||
constexpr auto dst_slice_origin_idx =
|
||||
generate_tuple([&](auto) { return I0; }, Number<nDim>{});
|
||||
constexpr auto dst_scalar_step_in_vector =
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
// loop over space-filling curve
|
||||
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
|
||||
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
// copy data from buf_vectors into dst_bufs
|
||||
using DstData = remove_cvref_t<decltype(DstDatas{}[i])>;
|
||||
using InterData = remove_cvref_t<decltype(InterDatas{}[i])>;
|
||||
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
|
||||
using dst_vector_t =
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
|
||||
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto j) {
|
||||
dst_vector.template AsType<DstData>()(j) =
|
||||
type_convert<DstData>(dst_vectors[i].template AsType<InterData>()[j]);
|
||||
});
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
|
||||
dst_coords_[i]);
|
||||
|
||||
constexpr InMemoryDataOperationEnum DstInMemOp =
|
||||
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
|
||||
|
||||
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_coords_[i].GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
// store Vgpr
|
||||
using DstVgprDesc = remove_cvref_t<decltype(DstVgprDescs{}.At(i))>;
|
||||
static_assert(DstVgprDesc::IsKnownAtCompileTime(),
|
||||
"wrong! DstDesc need to known at compile-time");
|
||||
constexpr auto dst_vgpr_desc = DstVgprDesc{};
|
||||
|
||||
constexpr auto src_data_idx = DstSpaceFillingCurve::GetIndex(iAccess);
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr index_t dst_offset =
|
||||
dst_vgpr_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) +
|
||||
src_data_idx + j * dst_scalar_step_in_vector);
|
||||
|
||||
dst_vgpr_buf(I0)(Number<dst_offset>{}) =
|
||||
is_dst_valid ? dst_vectors[i].template AsType<InterData>()[j]
|
||||
: NumericLimits<InterData>::QuietNaN();
|
||||
});
|
||||
});
|
||||
|
||||
// move coordinate
|
||||
if constexpr(iAccess.value != dst_num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);
|
||||
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
move_tensor_coordinate(dst_descs[i],
|
||||
dst_coords_(i),
|
||||
make_tensor_coordinate_step(dst_descs[i], forward_step));
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
if constexpr(DstResetCoordinateAfterRunFlags::At(i))
|
||||
{
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
|
||||
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
|
||||
template <typename DstBuffers,
|
||||
@@ -402,6 +502,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
DstBuffers dst_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
static_assert(is_same_v<InterDatas, DstDatas>,
|
||||
"RunWrite doesn't support inter data type different from dst data type");
|
||||
|
||||
OOBCheck(thread_scratch_id);
|
||||
TransposeFromElmToDst(thread_scratch_id);
|
||||
|
||||
@@ -630,8 +733,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
|
||||
private:
|
||||
using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
|
||||
using ElmVectorsType = decltype(generate_vectors<DstDatas, SrcScalarPerVector>());
|
||||
using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>());
|
||||
using ElmVectorsType = decltype(generate_vectors<InterDatas, SrcScalarPerVector>());
|
||||
using DstVectorsType = decltype(generate_vectors<InterDatas, DstScalarPerVector>());
|
||||
|
||||
static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
|
||||
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
Reference in New Issue
Block a user