mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
add vector_type support into thread_copy_v3r1 (#969)
* add vector_type support into thread_copy_v3r1 * remove unncessary type_convert * fixed datatype * fixed dataType * changed API with is_packx_invocable * changed example * add missing cmake file * fixed ci * fixed cmake --------- Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor/static_tensor.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -211,10 +212,44 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
auto src_vector_container = src_vector_type{
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
dst_vector_type op_r_v;
|
||||
|
||||
constexpr auto get_elem_op_vec_len = []() {
|
||||
if constexpr(is_detected<is_pack8_invocable_t, decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
|
||||
return math::min(8, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
|
||||
return math::min(4, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
|
||||
return math::min(2, SrcScalarPerVector);
|
||||
}
|
||||
return 1;
|
||||
};
|
||||
|
||||
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
|
||||
|
||||
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
|
||||
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
|
||||
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
|
||||
// apply the src elementwise op and convert to DstData under the hood if needed
|
||||
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
|
||||
src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
|
||||
});
|
||||
|
||||
// copy data from src_vector_container into src_thread_scratch_
|
||||
src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<src_vector_t>(
|
||||
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
|
||||
.template SetAsType<dst_vector_t>(src_data_idx_seq,
|
||||
op_r_v.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
@@ -267,19 +302,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
{
|
||||
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
// convert from SrcData to DstData here
|
||||
dst_thread_scratch_(idx) =
|
||||
type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
|
||||
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
|
||||
});
|
||||
#else
|
||||
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
|
||||
// TODO make this logic more generic for more sub-dword datatype
|
||||
if constexpr(SrcVectorDim != DstVectorDim &&
|
||||
((is_same<half_t, remove_cvref_t<SrcData>>::value &&
|
||||
is_same<half_t, remove_cvref_t<DstData>>::value &&
|
||||
((is_same<half_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
|
||||
(is_same<int8_t, remove_cvref_t<SrcData>>::value &&
|
||||
is_same<int8_t, remove_cvref_t<DstData>>::value &&
|
||||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
|
||||
{
|
||||
// each transpose does
|
||||
@@ -313,7 +344,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
constexpr auto data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
|
||||
|
||||
using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
|
||||
|
||||
// get DstScalarPerVector # of read-only references to src vectors from
|
||||
@@ -336,17 +367,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
Number<num_dst_vector>{});
|
||||
|
||||
// do data transpose
|
||||
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
src_vector_refs, dst_vector_refs);
|
||||
});
|
||||
}
|
||||
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
// apply the src elementwise op and convert to DstData under the hood if needed
|
||||
DstData dst_v;
|
||||
src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
|
||||
dst_thread_scratch_(idx) = dst_v;
|
||||
});
|
||||
else
|
||||
{
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
|
||||
});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -761,11 +791,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
|
||||
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
|
||||
|
||||
using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
SrcData,
|
||||
SrcScalarPerVector,
|
||||
decltype(src_thread_scratch_desc_),
|
||||
true>;
|
||||
using SrcThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData, // apply data_convert with SrcThreadScratch
|
||||
SrcScalarPerVector,
|
||||
decltype(src_thread_scratch_desc_),
|
||||
true>;
|
||||
|
||||
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
|
||||
@@ -132,9 +132,6 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
Number<num>{});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using has_vec_len = decltype(std::declval<T&>().vec_len);
|
||||
|
||||
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
|
||||
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
|
||||
template <typename SrcBuffers,
|
||||
@@ -159,94 +156,63 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
is_src_valid);
|
||||
});
|
||||
|
||||
if constexpr(is_detected<has_vec_len, decltype(element_op_)>::value)
|
||||
{
|
||||
constexpr auto elem_op_vec_len = decltype(element_op_)::vec_len;
|
||||
constexpr auto get_elem_op_vec_len = []() {
|
||||
if constexpr(is_detected<is_pack8_invocable_t, decltype(element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(element_op_)::is_pack8_invocable)
|
||||
return math::min(8, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack4_invocable_t, decltype(element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(element_op_)::is_pack4_invocable)
|
||||
return math::min(4, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack2_invocable_t, decltype(element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(element_op_)::is_pack2_invocable)
|
||||
return math::min(2, SrcScalarPerVector);
|
||||
}
|
||||
return 1;
|
||||
};
|
||||
|
||||
static_assert(is_same<remove_cvref_t<decltype(elem_op_vec_len)>, index_t>::value,
|
||||
"vec_len in element_op_ type is not index_t");
|
||||
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
|
||||
|
||||
static_assert(elem_op_vec_len == 1 || elem_op_vec_len == 2 ||
|
||||
elem_op_vec_len == 4 || elem_op_vec_len == 8,
|
||||
"vec_len in element_op_ must be 1, 2, 4, 8");
|
||||
// apply pointwise function
|
||||
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
|
||||
// get reference to src data
|
||||
const auto src_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iSrc) -> const auto& {
|
||||
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
|
||||
|
||||
static_assert(SrcScalarPerVector % elem_op_vec_len == 0,
|
||||
"vec_len in element_op_ cannot be divided by SrcScalarPerVector!");
|
||||
using elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
|
||||
|
||||
return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
// get reference to dst data
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iDst) -> auto& {
|
||||
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
|
||||
|
||||
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
|
||||
return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
// apply pointwise function
|
||||
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
|
||||
// get reference to src data
|
||||
const auto src_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iSrc) -> const auto& {
|
||||
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
|
||||
|
||||
using elem_op_vec_t =
|
||||
typename vector_type<SrcData, elem_op_vec_len>::type;
|
||||
|
||||
return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
// get reference to dst data
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iDst) -> auto& {
|
||||
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
|
||||
|
||||
using elem_op_vec_t =
|
||||
typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
|
||||
return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
// apply pointwise function
|
||||
// pointwise function signature:
|
||||
// element_op_(dst_data_refs[I0],
|
||||
// dst_data_refs[I1],
|
||||
// ...,
|
||||
// src_data_refs[I0],
|
||||
// src_data_refs[I1],
|
||||
// ...)
|
||||
unpack2(element_op_, dst_data_refs, src_data_refs);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// apply pointwise function
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
// get reference to src data
|
||||
const auto src_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iSrc) -> const auto& {
|
||||
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
|
||||
|
||||
return src_vectors[iSrc].template AsType<SrcData>()[i];
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
// get reference to dst data
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iDst) -> auto& {
|
||||
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
|
||||
|
||||
return dst_vectors(iDst).template AsType<DstData>()(i);
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
// apply pointwise function
|
||||
// pointwise function signature:
|
||||
// element_op_(dst_data_refs[I0],
|
||||
// dst_data_refs[I1],
|
||||
// ...,
|
||||
// src_data_refs[I0],
|
||||
// src_data_refs[I1],
|
||||
// ...)
|
||||
unpack2(element_op_, dst_data_refs, src_data_refs);
|
||||
});
|
||||
}
|
||||
// pointwise function signature:
|
||||
// element_op_(dst_data_refs[I0],
|
||||
// dst_data_refs[I1],
|
||||
// ...,
|
||||
// src_data_refs[I0],
|
||||
// src_data_refs[I1],
|
||||
// ...)
|
||||
unpack2(element_op_, dst_data_refs, src_data_refs);
|
||||
});
|
||||
|
||||
dst_vectors_tuple_(iAccess) = dst_vectors;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user