mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
FP16 data in-register transpose (#41)
* start fixing 16bit data packing * adding StaticTensor * adding StaticTensor * adding StaticTensor * add missing constexpr * adding static tensor * adding static tensor * adding transpose * add inline asm for transpose 2x2 of half_t * add general transpose_vectors(), but have unnecessary register initialization using v_mov * fix unnecessary register initialization in transpose_vector by using more pass-by-reference * add hardcoded logic for NHWC wrw * improve asm for v_pack * make ThreadwiseTensorSliceTransfer_v3r2 support any tensor * tweak * reorganize file
This commit is contained in:
@@ -30,7 +30,8 @@ struct PassThrough
|
||||
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
__host__ __device__ static void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up)
|
||||
__host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low,
|
||||
const UpIdx& idx_up)
|
||||
{
|
||||
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
@@ -1708,7 +1709,8 @@ struct Vectorize
|
||||
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
__host__ __device__ void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const
|
||||
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
265
composable_kernel/include/tensor_description/static_tensor.hpp
Normal file
265
composable_kernel/include/tensor_description/static_tensor.hpp
Normal file
@@ -0,0 +1,265 @@
|
||||
#ifndef CK_STATIC_TENSOR_HPP
|
||||
#define CK_STATIC_TENSOR_HPP
|
||||
|
||||
#include "ignore.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// StaticTensor for Scalar
|
||||
template <AddressSpaceEnum_t AddressSpace,
|
||||
typename T,
|
||||
typename TensorDesc,
|
||||
bool InvalidElementUseNumericalZeroValue,
|
||||
typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct StaticTensor
|
||||
{
|
||||
static constexpr auto desc_ = TensorDesc{};
|
||||
static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
|
||||
static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();
|
||||
|
||||
__host__ __device__ constexpr StaticTensor() : invalid_element_value_{0} {}
|
||||
|
||||
__host__ __device__ constexpr StaticTensor(T invalid_element_value)
|
||||
: invalid_element_value_{invalid_element_value}
|
||||
{
|
||||
}
|
||||
|
||||
// read access
|
||||
template <typename Idx,
|
||||
typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr const T& operator[](Idx) const
|
||||
{
|
||||
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
|
||||
|
||||
constexpr index_t offset = coord.GetOffset();
|
||||
|
||||
constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
|
||||
|
||||
if constexpr(is_valid)
|
||||
{
|
||||
return data_[Number<offset>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return T{0};
|
||||
}
|
||||
else
|
||||
{
|
||||
return invalid_element_value_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// write access
|
||||
template <typename Idx,
|
||||
typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr T& operator()(Idx)
|
||||
{
|
||||
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
|
||||
|
||||
constexpr index_t offset = coord.GetOffset();
|
||||
|
||||
constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
|
||||
|
||||
if constexpr(is_valid)
|
||||
{
|
||||
return data_(Number<offset>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ignore;
|
||||
}
|
||||
}
|
||||
|
||||
StaticBuffer<AddressSpace, T, element_space_size_, true> data_;
|
||||
T invalid_element_value_ = T{0};
|
||||
};
|
||||
|
||||
// StaticTensor for vector
|
||||
template <AddressSpaceEnum_t AddressSpace,
|
||||
typename S,
|
||||
index_t ScalarPerVector,
|
||||
typename TensorDesc,
|
||||
bool InvalidElementUseNumericalZeroValue,
|
||||
typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct StaticTensorTupleOfVectorBuffer
|
||||
{
|
||||
static constexpr auto desc_ = TensorDesc{};
|
||||
static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
|
||||
static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();
|
||||
|
||||
static constexpr index_t num_of_vector_ =
|
||||
math::integer_divide_ceil(element_space_size_, ScalarPerVector);
|
||||
|
||||
using V = vector_type<S, ScalarPerVector>;
|
||||
|
||||
__host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() : invalid_element_value_{0} {}
|
||||
|
||||
__host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value)
|
||||
: invalid_element_value_{invalid_element_value}
|
||||
{
|
||||
}
|
||||
|
||||
// Get S
|
||||
// Idx is for S, not V
|
||||
template <typename Idx,
|
||||
typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr const S& operator[](Idx) const
|
||||
{
|
||||
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
|
||||
|
||||
constexpr index_t offset = coord.GetOffset();
|
||||
|
||||
constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
|
||||
|
||||
if constexpr(is_valid)
|
||||
{
|
||||
return data_[Number<offset>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return S{0};
|
||||
}
|
||||
else
|
||||
{
|
||||
return invalid_element_value_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set S
|
||||
// Idx is for S, not V
|
||||
template <typename Idx,
|
||||
typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr S& operator()(Idx)
|
||||
{
|
||||
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
|
||||
|
||||
constexpr index_t offset = coord.GetOffset();
|
||||
|
||||
constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
|
||||
|
||||
if constexpr(is_valid)
|
||||
{
|
||||
return data_(Number<offset>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ignore;
|
||||
}
|
||||
}
|
||||
|
||||
// Get X
|
||||
// Idx is for S, not X. Idx should be aligned with X
|
||||
template <typename X,
|
||||
typename Idx,
|
||||
typename enable_if<has_same_scalar_type<S, X>::value &&
|
||||
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr X GetAsType(Idx) const
|
||||
{
|
||||
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
|
||||
|
||||
constexpr index_t offset = coord.GetOffset();
|
||||
|
||||
constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
|
||||
|
||||
if constexpr(is_valid)
|
||||
{
|
||||
return data_.template GetAsType<X>(Number<offset>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
// TODO: is this right way to initialize a vector?
|
||||
return X{0};
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: is this right way to initialize a vector?
|
||||
return X{invalid_element_value_};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set X
|
||||
// Idx is for S, not X. Idx should be aligned with X
|
||||
template <typename X,
|
||||
typename Idx,
|
||||
typename enable_if<has_same_scalar_type<S, X>::value &&
|
||||
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr void SetAsType(Idx, X x)
|
||||
{
|
||||
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
|
||||
|
||||
constexpr index_t offset = coord.GetOffset();
|
||||
|
||||
constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
|
||||
|
||||
if constexpr(is_valid)
|
||||
{
|
||||
data_.template SetAsType<X>(Number<offset>{}, x);
|
||||
}
|
||||
}
|
||||
|
||||
// Get read access to V. No is_valid check
|
||||
// Idx is for S, not V. Idx should be aligned with V
|
||||
template <typename Idx>
|
||||
__host__ __device__ constexpr const V& GetVectorTypeReference(Idx) const
|
||||
{
|
||||
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
|
||||
|
||||
constexpr index_t offset = coord.GetOffset();
|
||||
|
||||
return data_.GetVectorTypeReference(Number<offset>{});
|
||||
}
|
||||
|
||||
// Get read access to V. No is_valid check
|
||||
// Idx is for S, not V. Idx should be aligned with V
|
||||
template <typename Idx>
|
||||
__host__ __device__ constexpr V& GetVectorTypeReference(Idx)
|
||||
{
|
||||
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
|
||||
|
||||
constexpr index_t offset = coord.GetOffset();
|
||||
|
||||
return data_.GetVectorTypeReference(Number<offset>{});
|
||||
}
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpace, S, num_of_vector_, ScalarPerVector, true> data_;
|
||||
S invalid_element_value_ = S{0};
|
||||
};
|
||||
|
||||
template <AddressSpaceEnum_t AddressSpace,
|
||||
typename T,
|
||||
typename TensorDesc,
|
||||
typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
__host__ __device__ constexpr auto make_static_tensor(TensorDesc)
|
||||
{
|
||||
return StaticTensor<AddressSpace, T, TensorDesc, true>{};
|
||||
}
|
||||
|
||||
template <
|
||||
AddressSpaceEnum_t AddressSpace,
|
||||
typename T,
|
||||
typename TensorDesc,
|
||||
typename X,
|
||||
typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false,
|
||||
typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
|
||||
__host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_element_value)
|
||||
{
|
||||
return StaticTensor<AddressSpace, T, TensorDesc, true>{invalid_element_value};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -151,6 +151,20 @@ struct TensorAdaptor
|
||||
|
||||
__host__ __device__ constexpr auto GetElementSize() const { return element_size_; }
|
||||
|
||||
#if 0 // debug
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr index_t GetTopDimensionLength(Number<I> idim) const
|
||||
{
|
||||
// TODO: not implemented
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr index_t GetBottomDimensionLength(Number<I> idim) const
|
||||
{
|
||||
// TODO: not implemented
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
|
||||
@@ -37,7 +37,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
StaticBufferV2<AddressSpaceEnum_t::Vgpr, vector_type<FloatAcc, 16>, MRepeat * NRepeat, true>
|
||||
StaticBufferOfVectorTypeV2<AddressSpaceEnum_t::Vgpr,
|
||||
vector_type<FloatAcc, 16>,
|
||||
MRepeat * NRepeat,
|
||||
true>
|
||||
c_thread_buf_;
|
||||
|
||||
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "cluster_descriptor.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v3r2.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -146,22 +146,22 @@ struct BlockwiseTensorSliceTransfer_v4
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseTensorSliceTransfer_v3<ThreadSliceLengths,
|
||||
DstInMemOp,
|
||||
SrcData,
|
||||
DstData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorDim,
|
||||
DstVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstScalarPerVector,
|
||||
SrcScalarStrideInVector,
|
||||
DstScalarStrideInVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun>;
|
||||
ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths,
|
||||
DstInMemOp,
|
||||
SrcData,
|
||||
DstData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorDim,
|
||||
DstVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstScalarPerVector,
|
||||
SrcScalarStrideInVector,
|
||||
DstScalarStrideInVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
@@ -0,0 +1,802 @@
|
||||
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R2_HPP
|
||||
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "static_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
// TODO: How to fix this? It uses an struct instead of lambda because lambda
|
||||
// doesn't have constructor
|
||||
template <index_t SrcVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstVectorDim,
|
||||
index_t DstScalarPerVector>
|
||||
struct lambda_scalar_per_access_for_src_and_dst
|
||||
{
|
||||
__host__ __device__ constexpr auto operator()(index_t i) const
|
||||
{
|
||||
if(i == SrcVectorDim && i == DstVectorDim)
|
||||
{
|
||||
return math::lcm(SrcScalarPerVector, DstScalarPerVector);
|
||||
}
|
||||
else if(i == SrcVectorDim)
|
||||
{
|
||||
return SrcScalarPerVector;
|
||||
}
|
||||
else if(i == DstVectorDim)
|
||||
{
|
||||
return DstScalarPerVector;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Assume:
|
||||
// 1. src_desc and dst_desc are not known at compile-time
|
||||
// 2. SrcBuffer and DstBuffer are DynamicBuffer
|
||||
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
|
||||
// 4. Use thread buffer
|
||||
template <typename SliceLengths,
|
||||
InMemoryDataOperationEnum_t DstInMemOp,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
|
||||
// RunRead(), will be fused with MoveSrcSliceWindow to
|
||||
// save addr computation
|
||||
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
|
||||
// RunWrite(), will be fused with MoveDstSliceWindow to
|
||||
// save addr computation
|
||||
struct ThreadwiseTensorSliceTransfer_v3r2
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin)
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
|
||||
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename SrcStepHacks>
|
||||
__device__ void
|
||||
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||
{
|
||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
|
||||
"wrong! SrcBuffer and SrcData data type are inconsistent");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// make forward steps
|
||||
const auto src_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(
|
||||
src_desc, forward_step_idx, src_step_hacks[I0][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward steps
|
||||
const auto src_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(
|
||||
src_desc, backward_step_idx, src_step_hacks[I1][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_idx[I0];
|
||||
|
||||
static_for<0, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
|
||||
: ordered_src_access_lengths[i] - 1 -
|
||||
ordered_src_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
}();
|
||||
|
||||
constexpr auto src_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
|
||||
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
|
||||
|
||||
using src_vector_t = typename vector_type_maker_t<SrcData, SrcScalarPerVector>::type;
|
||||
|
||||
// copy data from src_buf to src_thread_scratch_
|
||||
src_thread_scratch_.template SetAsType<src_vector_t>(
|
||||
src_data_idx_seq,
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid));
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &=
|
||||
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// move src coord
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcResetCoordinateAfterRun)
|
||||
{
|
||||
const auto src_reset_step =
|
||||
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void TransferDataFromSrcThreadScratchToDstThreadScratch()
|
||||
{
|
||||
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
// convert from SrcData to DstData here
|
||||
dst_thread_scratch_(idx) = type_convert<DstData>{}(src_thread_scratch_[idx]);
|
||||
});
|
||||
#else
|
||||
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
|
||||
// TODO make this logic more generic for more sub-dword datatype
|
||||
if constexpr(SrcVectorDim != DstVectorDim &&
|
||||
is_same<half_t, remove_cvref_t<SrcData>>::value &&
|
||||
is_same<half_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0)
|
||||
{
|
||||
// each transpose does
|
||||
// DstScalarPerVector # of src vectors in src_thread_scratch_
|
||||
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
|
||||
constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
|
||||
constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
|
||||
|
||||
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
|
||||
// TODO: make this logic generic for all scenario
|
||||
static_assert(SrcVectorDim != DstVectorDim, "wrong");
|
||||
|
||||
constexpr auto src_scalar_step_in_vector = generate_sequence(
|
||||
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_scalar_step_in_vector = generate_sequence(
|
||||
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstVectorDim,
|
||||
DstScalarPerVector>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
|
||||
|
||||
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
|
||||
constexpr auto data_idx = access_idx * scalar_per_access;
|
||||
|
||||
constexpr auto data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
|
||||
|
||||
// TODO type_convert is not used yet!!!!!
|
||||
using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
|
||||
|
||||
// get DstScalarPerVector # of read-only references to src vectors from
|
||||
// src_thread_scratch_
|
||||
const auto src_vector_refs = generate_tie(
|
||||
[&](auto i) -> const src_vector_t& {
|
||||
// i increment corresponds to movement in DstVectorDim
|
||||
return src_thread_scratch_.GetVectorTypeReference(
|
||||
data_idx_seq + i * dst_scalar_step_in_vector);
|
||||
},
|
||||
Number<num_src_vector>{});
|
||||
|
||||
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
|
||||
auto dst_vector_refs = generate_tie(
|
||||
[&](auto i) -> dst_vector_t& {
|
||||
// i increment corresponds to movement in SrcVectorDim
|
||||
return dst_thread_scratch_.GetVectorTypeReference(
|
||||
data_idx_seq + i * src_scalar_step_in_vector);
|
||||
},
|
||||
Number<num_dst_vector>{});
|
||||
|
||||
// do data transpose
|
||||
// TODO type_convert is not used yet!!!!!
|
||||
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
src_vector_refs, dst_vector_refs);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
// convert from SrcData to DstData here
|
||||
dst_thread_scratch_(idx) = type_convert<DstData>{}(src_thread_scratch_[idx]);
|
||||
});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DstBuffer, typename DstStepHacks>
|
||||
__device__ void
|
||||
RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
|
||||
{
|
||||
// if there is transpose, it's done here
|
||||
// TODO move this elsewhere
|
||||
TransferDataFromSrcThreadScratchToDstThreadScratch();
|
||||
|
||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
// src scalar per access on each dim
|
||||
// TODO: don't use this
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_dst_access_lengths =
|
||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||
|
||||
// make forward steps
|
||||
const auto dst_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(
|
||||
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward steps
|
||||
const auto dst_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(
|
||||
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_idx[I0];
|
||||
|
||||
static_for<0, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index
|
||||
constexpr auto dst_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
|
||||
: ordered_dst_access_lengths[i] - 1 -
|
||||
ordered_dst_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_scalar_per_access;
|
||||
}();
|
||||
|
||||
constexpr auto dst_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
using dst_vector_t = typename vector_type_maker_t<DstData, DstScalarPerVector>::type;
|
||||
|
||||
// copy data from dst_thread_scratch_ to dst_buf
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq));
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &=
|
||||
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// move dst coord
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer>
|
||||
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
|
||||
{
|
||||
constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
|
||||
|
||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
||||
|
||||
constexpr auto src_step_hacks =
|
||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||
|
||||
RunRead(src_desc, src_buf, src_step_hacks);
|
||||
}
|
||||
|
||||
template <typename DstBuffer>
|
||||
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
|
||||
{
|
||||
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
|
||||
|
||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
||||
|
||||
constexpr auto dst_step_hacks =
|
||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||
|
||||
RunWrite(dst_desc, dst_buf, dst_step_hacks);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_lengths[I0] - 1;
|
||||
|
||||
static_for<0, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index after last iteration in RunRead(), if it has not being reset by
|
||||
// RunRead()
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_src_data_step = [&]() {
|
||||
Index reset_src_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
|
||||
|
||||
return reset_src_data_step_;
|
||||
}();
|
||||
|
||||
return reset_src_data_step;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_dst_access_lengths =
|
||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_lengths[I0] - 1;
|
||||
|
||||
static_for<0, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
|
||||
// RunWrite()
|
||||
constexpr auto dst_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_scalar_per_access;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_dst_data_step = [&]() {
|
||||
Index reset_dst_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
|
||||
|
||||
return reset_dst_data_step_;
|
||||
}();
|
||||
|
||||
return reset_dst_data_step;
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx)
|
||||
{
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
template <typename SrcMoveSliceWindowStepHack>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx,
|
||||
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||
{
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(
|
||||
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_step_idx)
|
||||
{
|
||||
// if dst coord was not reset by RunWrite(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_access_lengths_and_vector_length = container_push_back(
|
||||
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
|
||||
|
||||
// 1st stage of transforms
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(src_access_lengths_and_vector_length[i],
|
||||
src_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstThreadScratchDescriptor()
|
||||
{
|
||||
// 1st stage of transforms
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
|
||||
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
|
||||
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(dst_access_lengths_and_vector_length[i],
|
||||
dst_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
|
||||
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
|
||||
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
SrcData,
|
||||
SrcScalarPerVector,
|
||||
decltype(src_thread_scratch_desc_),
|
||||
true>
|
||||
src_thread_scratch_;
|
||||
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
DstData,
|
||||
DstScalarPerVector,
|
||||
decltype(dst_thread_scratch_desc_),
|
||||
true>
|
||||
dst_thread_scratch_;
|
||||
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -30,7 +30,11 @@
|
||||
#include "amd_address_space.hpp"
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
#include "static_buffer.hpp"
|
||||
// TODO remove this
|
||||
#include "static_buffer_of_vector_type_v2.hpp"
|
||||
#include "dynamic_buffer.hpp"
|
||||
#include "is_known_at_compile_time.hpp"
|
||||
#include "transpose_vectors.hpp"
|
||||
|
||||
#include "inner_product.hpp"
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@
|
||||
#define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
#endif
|
||||
|
||||
// experimental implementation
|
||||
// experimental implementation for buffer load/store/atomic
|
||||
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
|
||||
#endif
|
||||
@@ -89,6 +89,11 @@
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
// experimental implementation for in-regsiter sub-dword transpose
|
||||
#ifndef CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
|
||||
#define CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE 1
|
||||
#endif
|
||||
|
||||
// pass tensor descriptor by value or void*
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
|
||||
|
||||
@@ -373,19 +373,6 @@ set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>&
|
||||
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
__host__ __device__ constexpr auto to_tuple_of_number(const Container&)
|
||||
{
|
||||
static_assert(is_known_at_compile_time<Container>::value, "wrong!");
|
||||
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr index_t tmp = Container::At(i);
|
||||
return Number<tmp>{};
|
||||
},
|
||||
Container::Size());
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
|
||||
{
|
||||
|
||||
@@ -58,6 +58,18 @@ __host__ __device__ constexpr auto make_vector_type(Number<N>)
|
||||
template <typename TV>
|
||||
struct scalar_type;
|
||||
|
||||
// is_scalar_type
|
||||
template <typename TV>
|
||||
struct is_scalar_type
|
||||
{
|
||||
static constexpr bool value = (scalar_type<remove_cvref_t<TV>>::vector_size == 1);
|
||||
};
|
||||
|
||||
// has_same_scalar_type
|
||||
template <typename X, typename Y>
|
||||
using has_same_scalar_type = is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<Y>>::type>;
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<T __attribute__((ext_vector_type(N)))>
|
||||
{
|
||||
|
||||
21
composable_kernel/include/utility/ignore.hpp
Normal file
21
composable_kernel/include/utility/ignore.hpp
Normal file
@@ -0,0 +1,21 @@
|
||||
#ifndef CK_IGNORE_HPP
|
||||
#define CK_IGNORE_HPP
|
||||
|
||||
// https://en.cppreference.com/w/cpp/utility/tuple/ignore
|
||||
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
struct ignore_t
|
||||
{
|
||||
template <typename T>
|
||||
constexpr void operator=(T&&) const noexcept
|
||||
{
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
inline constexpr detail::ignore_t ignore;
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,49 @@
|
||||
#ifndef IS_KNOWN_AT_COMPILE_TIME_HPP
|
||||
#define IS_KNOWN_AT_COMPILE_TIME_HPP
|
||||
|
||||
#include "config.hpp"
|
||||
#include "integral_constant.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename T>
|
||||
struct is_known_at_compile_time;
|
||||
|
||||
template <>
|
||||
struct is_known_at_compile_time<index_t>
|
||||
{
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <typename T, T X>
|
||||
struct is_known_at_compile_time<integral_constant<T, X>>
|
||||
{
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <index_t... Is>
|
||||
struct is_known_at_compile_time<Sequence<Is...>>
|
||||
{
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
struct is_known_at_compile_time<Tuple<Ts...>>
|
||||
{
|
||||
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
|
||||
{
|
||||
return container_reduce(
|
||||
Tuple<Ts...>{},
|
||||
[](auto x, bool r) {
|
||||
return is_known_at_compile_time<remove_cvref_t<decltype(x)>>::value & r;
|
||||
},
|
||||
true);
|
||||
}
|
||||
|
||||
static constexpr bool value = IsKnownAtCompileTime();
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -5,158 +5,156 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace,
|
||||
// static buffer for scalar
|
||||
template <AddressSpaceEnum_t AddressSpace,
|
||||
typename T,
|
||||
index_t N,
|
||||
bool InvalidElementUseNumericalZeroValue>
|
||||
bool InvalidElementUseNumericalZeroValue> // TODO remove this bool, no longer needed
|
||||
struct StaticBuffer : public StaticallyIndexedArray<T, N>
|
||||
{
|
||||
using type = T;
|
||||
using base = StaticallyIndexedArray<T, N>;
|
||||
|
||||
T invalid_element_value_ = T{0};
|
||||
|
||||
__host__ __device__ constexpr StaticBuffer() : base{} {}
|
||||
|
||||
__host__ __device__ constexpr StaticBuffer(T invalid_element_value)
|
||||
: base{}, invalid_element_value_{invalid_element_value}
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
||||
{
|
||||
return BufferAddressSpace;
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto Get(Number<I> i, bool is_valid_element) const
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return is_valid_element ? At(i) : T{0};
|
||||
}
|
||||
else
|
||||
{
|
||||
return is_valid_element ? At(i) : invalid_element_value_;
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ void Set(Number<I> i, bool is_valid_element, const T& x)
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
At(i) = x;
|
||||
}
|
||||
return AddressSpace;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
|
||||
|
||||
// read access
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const T& operator[](Number<I> i) const
|
||||
{
|
||||
return base::operator[](i);
|
||||
}
|
||||
|
||||
// write access
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr T& operator()(Number<I> i)
|
||||
{
|
||||
return base::operator()(i);
|
||||
}
|
||||
};
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace,
|
||||
typename T,
|
||||
index_t N,
|
||||
bool InvalidElementUseNumericalZeroValue>
|
||||
struct StaticBufferV2 : public StaticallyIndexedArray<T, N>
|
||||
// static buffer for vector
|
||||
template <AddressSpaceEnum_t AddressSpace,
|
||||
typename S,
|
||||
index_t NumOfVector,
|
||||
index_t ScalarPerVector,
|
||||
bool InvalidElementUseNumericalZeroValue, // TODO remove this bool, no longer needed,
|
||||
typename enable_if<is_scalar_type<S>::value, bool>::type = false>
|
||||
struct StaticBufferTupleOfVector
|
||||
: public StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>
|
||||
{
|
||||
using type = T;
|
||||
using base = StaticallyIndexedArray<T, N>;
|
||||
using V = typename vector_type<S, ScalarPerVector>::type;
|
||||
using base = StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>;
|
||||
|
||||
using VecBaseType = typename T::d1_t;
|
||||
static constexpr auto s_per_v = Number<ScalarPerVector>{};
|
||||
static constexpr auto num_of_v_ = Number<NumOfVector>{};
|
||||
|
||||
__host__ __device__ static constexpr index_t GetVectorSize()
|
||||
{
|
||||
return sizeof(typename T::type) / sizeof(VecBaseType);
|
||||
}
|
||||
|
||||
static constexpr index_t vector_size = GetVectorSize();
|
||||
|
||||
VecBaseType invalid_element_value_ = VecBaseType{0};
|
||||
|
||||
T invalid_vec_value_ = T{0};
|
||||
|
||||
__host__ __device__ constexpr StaticBufferV2() : base{} {}
|
||||
|
||||
__host__ __device__ constexpr StaticBufferV2(VecBaseType invalid_element_value)
|
||||
: base{},
|
||||
invalid_vec_value_{invalid_element_value},
|
||||
invalid_element_value_{invalid_element_value}
|
||||
{
|
||||
}
|
||||
__host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
|
||||
|
||||
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
||||
{
|
||||
return BufferAddressSpace;
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& GetVector(Number<I> vec_id)
|
||||
{
|
||||
return this->At(vec_id);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const auto& GetVector(Number<I> vec_id) const
|
||||
{
|
||||
return this->At(vec_id);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& GetElement(Number<I> i, bool)
|
||||
{
|
||||
constexpr auto vec_id = Number<i / vector_size>{};
|
||||
constexpr auto vec_off = Number<i % vector_size>{};
|
||||
|
||||
return this->At(vec_id).template AsType<VecBaseType>()(vec_off);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto GetElement(Number<I> i, bool is_valid_element) const
|
||||
{
|
||||
constexpr auto vec_id = Number<i / vector_size>{};
|
||||
constexpr auto vec_off = Number<i % vector_size>{};
|
||||
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
|
||||
: VecBaseType{0};
|
||||
}
|
||||
else
|
||||
{
|
||||
return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
|
||||
: invalid_element_value_;
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto operator[](Number<I> i) const
|
||||
{
|
||||
return GetElement(i, true);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& operator()(Number<I> i)
|
||||
{
|
||||
return GetElement(i, true);
|
||||
return AddressSpace;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
|
||||
|
||||
// Get S
|
||||
// i is offset of S
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const S& operator[](Number<I> i) const
|
||||
{
|
||||
constexpr auto i_v = i / s_per_v;
|
||||
constexpr auto i_s = i % s_per_v;
|
||||
|
||||
return base::operator[](i_v).template AsType<S>()[i_s];
|
||||
}
|
||||
|
||||
// Set S
|
||||
// i is offset of S
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr S& operator()(Number<I> i)
|
||||
{
|
||||
constexpr auto i_v = i / s_per_v;
|
||||
constexpr auto i_s = i % s_per_v;
|
||||
|
||||
return base::operator()(i_v).template AsType<S>()(i_s);
|
||||
}
|
||||
|
||||
// Get X
|
||||
// i is offset of S, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
index_t I,
|
||||
typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
|
||||
__host__ __device__ constexpr auto GetAsType(Number<I> i) const
|
||||
{
|
||||
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
|
||||
|
||||
static_assert(s_per_v % s_per_x == 0, "wrong! V must one or multiple X");
|
||||
static_assert(i % s_per_x == 0, "wrong!");
|
||||
|
||||
constexpr auto i_v = i / s_per_v;
|
||||
constexpr auto i_x = (i % s_per_v) / s_per_x;
|
||||
|
||||
return base::operator[](i_v).template AsType<X>()[i_x];
|
||||
}
|
||||
|
||||
// Set X
|
||||
// i is offset of S, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
index_t I,
|
||||
typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
|
||||
__host__ __device__ constexpr void SetAsType(Number<I> i, X x)
|
||||
{
|
||||
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
|
||||
|
||||
static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X");
|
||||
static_assert(i % s_per_x == 0, "wrong!");
|
||||
|
||||
constexpr auto i_v = i / s_per_v;
|
||||
constexpr auto i_x = (i % s_per_v) / s_per_x;
|
||||
|
||||
base::operator()(i_v).template AsType<X>()(i_x) = x;
|
||||
}
|
||||
|
||||
// Get read access to vector_type V
|
||||
// i is offset of S, not V. i should be aligned to V
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const auto& GetVectorTypeReference(Number<I> i) const
|
||||
{
|
||||
static_assert(i % s_per_v == 0, "wrong!");
|
||||
|
||||
constexpr auto i_v = i / s_per_v;
|
||||
|
||||
return base::operator[](i_v);
|
||||
}
|
||||
|
||||
// Get write access to vector_type V
|
||||
// i is offset of S, not V. i should be aligned to V
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& GetVectorTypeReference(Number<I> i)
|
||||
{
|
||||
static_assert(i % s_per_v == 0, "wrong!");
|
||||
|
||||
constexpr auto i_v = i / s_per_v;
|
||||
|
||||
return base::operator()(i_v);
|
||||
}
|
||||
};
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
||||
template <AddressSpaceEnum_t AddressSpace, typename T, index_t N>
|
||||
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
|
||||
{
|
||||
return StaticBuffer<BufferAddressSpace, T, N, true>{};
|
||||
}
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
||||
__host__ __device__ constexpr auto make_static_buffer(Number<N>, T invalid_element_value)
|
||||
{
|
||||
return StaticBuffer<BufferAddressSpace, T, N, false>{invalid_element_value};
|
||||
return StaticBuffer<AddressSpace, T, N, true>{};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
#ifndef CK_STATIC_BUFFER_OF_VECTOR_TYPE_V2_HPP
|
||||
#define CK_STATIC_BUFFER_OF_VECTOR_TYPE_V2_HPP
|
||||
|
||||
#include "statically_indexed_array.hpp"
|
||||
|
||||
namespace ck {
|
||||
template <AddressSpaceEnum_t BufferAddressSpace,
|
||||
typename T,
|
||||
index_t N,
|
||||
bool InvalidElementUseNumericalZeroValue>
|
||||
struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
|
||||
{
|
||||
using type = T;
|
||||
using base = StaticallyIndexedArray<T, N>;
|
||||
|
||||
using VecBaseType = typename T::d1_t;
|
||||
|
||||
__host__ __device__ static constexpr index_t GetVectorSize()
|
||||
{
|
||||
return sizeof(typename T::type) / sizeof(VecBaseType);
|
||||
}
|
||||
|
||||
static constexpr index_t vector_size = GetVectorSize();
|
||||
|
||||
VecBaseType invalid_element_value_ = VecBaseType{0};
|
||||
|
||||
T invalid_vec_value_ = T{0};
|
||||
|
||||
__host__ __device__ constexpr StaticBufferOfVectorTypeV2() : base{} {}
|
||||
|
||||
__host__ __device__ constexpr StaticBufferOfVectorTypeV2(VecBaseType invalid_element_value)
|
||||
: base{},
|
||||
invalid_vec_value_{invalid_element_value},
|
||||
invalid_element_value_{invalid_element_value}
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
||||
{
|
||||
return BufferAddressSpace;
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& GetVector(Number<I> vec_id)
|
||||
{
|
||||
return this->At(vec_id);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const auto& GetVector(Number<I> vec_id) const
|
||||
{
|
||||
return this->At(vec_id);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& GetElement(Number<I> i, bool)
|
||||
{
|
||||
constexpr auto vec_id = Number<i / vector_size>{};
|
||||
constexpr auto vec_off = Number<i % vector_size>{};
|
||||
|
||||
return this->At(vec_id).template AsType<VecBaseType>()(vec_off);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto GetElement(Number<I> i, bool is_valid_element) const
|
||||
{
|
||||
constexpr auto vec_id = Number<i / vector_size>{};
|
||||
constexpr auto vec_off = Number<i % vector_size>{};
|
||||
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
|
||||
: VecBaseType{0};
|
||||
}
|
||||
else
|
||||
{
|
||||
return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
|
||||
: invalid_element_value_;
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto operator[](Number<I> i) const
|
||||
{
|
||||
return GetElement(i, true);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& operator()(Number<I> i)
|
||||
{
|
||||
return GetElement(i, true);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -8,20 +8,38 @@
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename T, index_t NSize>
|
||||
__host__ __device__ constexpr auto generate_same_type_tuple()
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<Tuple<Xs...>, Tuple<Ys...>>
|
||||
{
|
||||
return generate_tuple([](auto) -> T { return T{}; }, Number<NSize>{});
|
||||
}
|
||||
using type = Tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
template <typename T, index_t NSize>
|
||||
using same_type_tuple = decltype(generate_same_type_tuple<T, NSize>());
|
||||
template <typename T, index_t N>
|
||||
struct StaticallyIndexedArrayImpl
|
||||
{
|
||||
using type =
|
||||
typename tuple_concat<typename StaticallyIndexedArrayImpl<T, N / 2>::type,
|
||||
typename StaticallyIndexedArrayImpl<T, N - N / 2>::type>::type;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct StaticallyIndexedArrayImpl<T, 0>
|
||||
{
|
||||
using type = Tuple<>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct StaticallyIndexedArrayImpl<T, 1>
|
||||
{
|
||||
using type = Tuple<T>;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename T, index_t NSize>
|
||||
using StaticallyIndexedArray = detail::same_type_tuple<T, NSize>;
|
||||
template <typename T, index_t N>
|
||||
using StaticallyIndexedArray = typename detail::StaticallyIndexedArrayImpl<T, N>::type;
|
||||
|
||||
template <typename X, typename... Xs>
|
||||
__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
|
||||
|
||||
87
composable_kernel/include/utility/transpose_vectors.hpp
Normal file
87
composable_kernel/include/utility/transpose_vectors.hpp
Normal file
@@ -0,0 +1,87 @@
|
||||
#ifndef CK_TRANSPOSE_VECTORS_AMD_HPP
|
||||
#define CK_TRANSPOSE_VECTORS_AMD_HPP
|
||||
|
||||
#include "config.hpp"
|
||||
#include "statically_indexed_array.hpp"
|
||||
#include "data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename S,
|
||||
index_t NX,
|
||||
index_t NY,
|
||||
typename enable_if<is_scalar_type<S>::value, bool>::type = false>
|
||||
struct transpose_vectors;
|
||||
|
||||
// transpose fp16 2x2
|
||||
__device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1)
|
||||
{
|
||||
#if 0
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
const vector_type<half_t, 2> vx0{x0}, vx1{x1};
|
||||
vector_type<half_t, 2> vy0, vy1;
|
||||
|
||||
vy0.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I0];
|
||||
vy0.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I0];
|
||||
|
||||
vy1.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I1];
|
||||
vy1.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I1];
|
||||
|
||||
y0 = vy0.template AsType<half2_t>()[I0];
|
||||
y1 = vy1.template AsType<half2_t>()[I0];
|
||||
#else
|
||||
asm volatile("\n \
|
||||
v_pack_b32_f16 %0, %1, %2 \n \
|
||||
"
|
||||
: "=v"(y0)
|
||||
: "v"(x0), "v"(x1));
|
||||
|
||||
asm volatile("\n \
|
||||
v_pack_b32_f16 %0, %1, %2, op_sel:[1, 1] \n \
|
||||
"
|
||||
: "=v"(y1)
|
||||
: "v"(x0), "v"(x1));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <index_t NX, index_t NY>
|
||||
struct transpose_vectors<half_t, NX, NY>
|
||||
{
|
||||
// we got [NY * NX] ammount of S data to be transposed
|
||||
static constexpr index_t s_per_x = NY;
|
||||
static constexpr index_t s_per_y = NX;
|
||||
|
||||
using S = half_t;
|
||||
using VX = vector_type<half_t, s_per_x>;
|
||||
using VY = vector_type<half_t, s_per_y>;
|
||||
|
||||
__device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
|
||||
StaticallyIndexedArray<VY&, NY>& vy_tuple)
|
||||
{
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
|
||||
|
||||
// loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
|
||||
static_for<0, NY, 2>{}([&](auto iy) {
|
||||
static_for<0, NX, 2>{}([&](auto ix) {
|
||||
// reference to 2 half2_t data from vx_tuple
|
||||
const auto& x_s2_0 = vx_tuple[ix].template AsType<half2_t>()[iy / I2];
|
||||
const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<half2_t>()[iy / I2];
|
||||
|
||||
// reference to 2 half2_t data from vy_tuple
|
||||
auto& y_s2_0 = vy_tuple(iy).template AsType<half2_t>()(ix / I2);
|
||||
auto& y_s2_1 = vy_tuple(iy + I1).template AsType<half2_t>()(ix / I2);
|
||||
|
||||
// transpose
|
||||
transpose_fp16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -117,6 +117,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
|
||||
|
||||
// number of elements in the tuple, usable in constant expressions
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
|
||||
|
||||
// read access
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const auto& At(Number<I>) const
|
||||
{
|
||||
@@ -124,6 +125,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
|
||||
return base::GetElementByKey(detail::TupleElementKey<I>{});
|
||||
}
|
||||
|
||||
// write access
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& At(Number<I>)
|
||||
{
|
||||
@@ -131,12 +133,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
|
||||
return base::GetElementByKey(detail::TupleElementKey<I>{});
|
||||
}
|
||||
|
||||
// read access: const element lookup by compile-time index, delegates to At()
template <index_t I>
__host__ __device__ constexpr const auto& operator[](Number<I> i) const
{
    return At(i);
}
|
||||
|
||||
// write access
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& operator()(Number<I> i)
|
||||
{
|
||||
@@ -162,5 +166,12 @@ __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
|
||||
return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
|
||||
}
|
||||
|
||||
// https://en.cppreference.com/w/cpp/utility/tuple/tie
|
||||
template <typename... Args>
|
||||
constexpr Tuple<Args&...> tie(Args&... args) noexcept
|
||||
{
|
||||
return {args...};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -6,22 +6,6 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename... Ts>
|
||||
struct is_known_at_compile_time<Tuple<Ts...>>
|
||||
{
|
||||
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
|
||||
{
|
||||
return container_reduce(
|
||||
Tuple<Ts...>{},
|
||||
[](auto x, bool r) {
|
||||
return is_known_at_compile_time<remove_cvref_t<decltype(x)>>::value & r;
|
||||
},
|
||||
true);
|
||||
}
|
||||
|
||||
static constexpr bool value = IsKnownAtCompileTime();
|
||||
};
|
||||
|
||||
template <typename F, index_t N>
|
||||
__host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
|
||||
{
|
||||
@@ -29,6 +13,13 @@ __host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, index_t N>
|
||||
__host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
|
||||
{
|
||||
return unpack([&f](auto&&... xs) { return tie(f(xs)...); },
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename F, typename X, index_t... Is>
|
||||
|
||||
@@ -31,21 +31,6 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
|
||||
// convenience variable template mirroring std::is_pointer_v
template <typename T>
inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
|
||||
|
||||
// trait: true iff a value of type T carries its value in the type itself
// (primary template left undefined; specializations opt types in/out)
template <typename T>
struct is_known_at_compile_time;

// a plain index_t holds a runtime value, so it is never compile-time known
template <>
struct is_known_at_compile_time<index_t>
{
    static constexpr bool value = false;
};
|
||||
|
||||
// an integral_constant encodes its value X in the type, so it is always
// compile-time known
template <typename T, T X>
struct is_known_at_compile_time<integral_constant<T, X>>
{
    static constexpr bool value = true;
};
|
||||
|
||||
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
|
||||
__host__ __device__ constexpr Y as_type(X x)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user