mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
Refactor threadwise copy using sfcurve (#101)
* add space_filling_curve
* cleanup and move space_filling_curve into test
* WIP: start refactoring threadwise_transfer_v1r3
* threadwise_copy works but needs further refactoring
* add some comments
* add SpaceFillingCurve::GetIndices()
* minor changes
* removed GetIndices; refactored GetDstCoordinateResetStep
* add DynamicBuffer::Transfer, but Add is not tested
* rebased against develop
* threadwise_copy_v6r1/v6r2/v6r3 using space-filling curve start to work
* minor changes
* refactored threadcopy v3r1, v2; removed old implementations
* clang-format
* cleanup
* fix a typo in v6r3
* format
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 0619ebf70b]
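Editor's note: the refactor below replaces the hand-rolled zig-zag traversal in each Run() with one loop driven by the SpaceFillingCurve helper. A minimal sketch of the resulting access pattern, using only calls that appear in the diff (the curve alias and descriptor names are stand-ins for the per-transfer ones):

    // Sketch only: how every refactored Run() below iterates. The third curve
    // parameter is the per-dimension scalars-per-access sequence, as in the diff.
    using Curve = SpaceFillingCurve<SliceLengths, DimAccessOrder, ScalarsPerAccess>;

    constexpr auto num_accesses = Curve::GetNumOfAccess();

    static_for<0, num_accesses, 1>{}([&](auto idx_1d) {
        // multi-dimensional index of this access, known at compile time
        constexpr auto idx_md = Curve::GetIndex(idx_1d);

        // ... vector load/store at idx_md ...

        if constexpr(idx_1d.value != num_accesses - 1)
        {
            // step from this access to the next one on the curve
            constexpr auto forward_step = Curve::GetForwardStep(idx_1d);
            move_tensor_coordinate(desc, coord, make_tensor_coordinate_step(desc, forward_step));
        }
    });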
@@ -4,6 +4,7 @@
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"

namespace ck {

@@ -67,8 +68,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3

    using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(
        const DstDesc& dst_desc,
        const Index& dst_slice_origin_idx,
@@ -85,16 +84,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
        dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }

    template <typename SrcSliceOriginIdx,
              typename SrcBuffer,
              typename DstBuffer,
              typename DstStepHacks>
    template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf,
                        const DstStepHacks& dst_step_hacks)
                        DstBuffer& dst_buf)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc need to known at compile-time");
@@ -108,9 +103,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3
        constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
        constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto dst_scalar_per_access = generate_sequence(
@@ -119,85 +111,26 @@ struct ThreadwiseTensorSliceTransfer_v1r3
        constexpr auto dst_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;

        constexpr auto dim_access_order = DimAccessOrder{};
        // TODO: Use SpaceFillingCurve::ScalarsPerAccess instead of DstScalarPerVector?
        static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
                      "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
        typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
        using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();

        // make forward steps
        const auto dst_forward_steps = generate_tuple(
            [&](auto i) {
                Index forward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(
                    dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
            },
            Number<nDim>{});

        // make backward steps
        const auto dst_backward_steps = generate_tuple(
            [&](auto i) {
                Index backward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(
                    dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
            },
            Number<nDim>{});

        // loop over tensor and copy
        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
            // judge move forward or move backward
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();

            // calculate dst data index
            constexpr auto dst_data_idx = [&]() {
                Index ordered_idx;

                static_for<0, nDim, 1>{}([&](auto i) {
                    ordered_idx(i) = forward_sweep[i]
                                         ? ordered_access_idx[i]
                                         : ordered_access_lengths[i] - 1 - ordered_access_idx[i];
                });

                return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                       dst_scalar_per_access;
            }();

            typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;

            using dst_vector_t =
                typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
        static_for<0, num_accesses, 1>{}([&](auto idx_1d) {
            constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);

            // copy data from src_buf into dst_vector
            // TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve?
            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                constexpr index_t src_offset = src_desc.CalculateOffset(
                    src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
                    src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

                SrcData dst_v;

@@ -212,69 +145,18 @@ struct ThreadwiseTensorSliceTransfer_v1r3
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            // copy data from dst_vector into dst_buf
            if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set)
            dst_buf.template Update<DstInMemOp, dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);

            if constexpr(idx_1d.value != num_accesses - 1)
            {
                dst_buf.template Set<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

                move_tensor_coordinate(
                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
            }
            else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
            {
                dst_buf.template AtomicAdd<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
            }
            else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Add)
            {

                typename vector_type_maker<DstData, DstScalarPerVector>::type tmp;
                tmp.template AsType<dst_vector_t>()(Number<0>{}) =
                    dst_buf.template Get<dst_vector_t>(dst_coord_.GetOffset(), is_dst_valid);

                static_for<0, DstScalarPerVector, 1>{}([&](auto t) {
                    dst_vector.template AsType<DstData>()(t) += tmp.template AsType<DstData>()[t];
                });

                dst_buf.template Set<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
            }

            constexpr auto move_on_dim = [&]() constexpr
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
            }
            ();

            // move
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
                    }
                }
            });
        });

        // move dst coordinate back to slice origin (or not)
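Editor's note: the forward_sweep parity rule removed above decides the sweep direction of each dimension from the parity of the mixed-radix prefix of the ordered index, which produces a snake path in which consecutive accesses differ in exactly one dimension by one step. A scalar model of that rule, with hypothetical 2x3 access lengths:

    #include <cstdio>

    // Scalar model of forward_sweep: dimension 1 is swept forward iff the
    // index along dimension 0 is even.
    int main()
    {
        constexpr int len0 = 2, len1 = 3; // hypothetical access lengths
        for(int i0 = 0; i0 < len0; ++i0)
        {
            const bool forward = (i0 % 2 == 0);
            for(int k = 0; k < len1; ++k)
            {
                const int i1 = forward ? k : len1 - 1 - k;
                std::printf("(%d,%d) ", i0, i1); // prints (0,0) (0,1) (0,2) (1,2) (1,1) (1,0)
            }
        }
    }

Each hop changes one coordinate by one step, which is what lets the refactor replace the per-dimension forward/backward step tables with a single SpaceFillingCurve::GetForwardStep.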
@@ -287,82 +169,20 @@ struct ThreadwiseTensorSliceTransfer_v1r3
            }
        }

    template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        constexpr index_t ntransform_dst = remove_cvref_t<DstDesc>::GetNumOfTransform();

        constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};

        constexpr auto dst_step_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

        Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_step_hacks);
    }

    __device__ static constexpr auto GetDstCoordinateResetStep()
    {
        constexpr auto I0 = Number<0>{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;

        constexpr auto dim_access_order = DimAccessOrder{};
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();
        constexpr auto reset_step =
            SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, Number<0>{});

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        // judge move forward or move backward during the last iteration
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_access_lengths[I0] - 1;

                static_for<1, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate dst data index after last iteration in Run(), if it has not been reset by
        // RunWrite()
        constexpr auto dst_data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                   dst_scalar_per_access;
        }();

        //
        constexpr auto reset_dst_data_step = [&]() {
            Index reset_dst_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });

            return reset_dst_data_step_;
        }();

        return reset_dst_data_step;
        return reset_step;
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
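Editor's note: the old GetDstCoordinateResetStep() recomputed the position of the last access by replaying the sweep parity and then negated it; the replacement asks the curve directly for the step from the last access back to the first, since GetStepBetween(last, first) is the index difference between the two accesses. A scalar check against the 2x3 snake above (hypothetical sizes):

    #include <cassert>

    int main()
    {
        // Snake order over a 2x3 grid ends at (1,0); first access is (0,0).
        const int first[2] = {0, 0};
        const int last[2]  = {1, 0};
        int reset[2];
        for(int d = 0; d < 2; ++d)
            reset[d] = first[d] - last[d]; // step between last and first access
        // Matches the negated last data index the removed code built element-wise.
        assert(reset[0] == -1 && reset[1] == 0);
    }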
@@ -383,7 +203,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
    private:
    DstCoord dst_coord_;
    const DstElementwiseOperation dst_element_op_;
}; // namespace ck
}; // struct ThreadwiseTensorSliceTransfer_v1r3

// Assume:
// 1. src:
@@ -428,16 +248,12 @@ struct ThreadwiseTensorSliceTransfer_v2
        src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

    template <typename SrcBuffer,
              typename DstBuffer,
              typename DstSliceOriginIdx,
              typename SrcStepHacks>
    template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstSliceOriginIdx&,
                        DstBuffer& dst_buf,
                        const SrcStepHacks& src_step_hacks)
                        DstBuffer& dst_buf)
    {
        static_assert(DstDesc::IsKnownAtCompileTime(),
                      "wrong! DstDesc need to known at compile-time");
@@ -453,9 +269,6 @@ struct ThreadwiseTensorSliceTransfer_v2
        constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
        constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
@@ -464,80 +277,19 @@ struct ThreadwiseTensorSliceTransfer_v2
        constexpr auto src_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto dim_access_order = DimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        // make forward steps
        const auto src_forward_steps = generate_tuple(
            [&](auto i) {
                Index forward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(
                    src_desc, forward_step_idx, src_step_hacks[I0][i]);
            },
            Number<nDim>{});

        // make backward steps
        const auto src_backward_steps = generate_tuple(
            [&](auto i) {
                Index backward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(
                    src_desc, backward_step_idx, src_step_hacks[I1][i]);
            },
            Number<nDim>{});
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(src_scalar_per_access)>>;

        // loop over tensor and copy
        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
            // judge move forward or move backward
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();

            // calculate src data index
            constexpr auto src_data_idx = [&]() {
                Index ordered_idx;

                static_for<0, nDim, 1>{}([&](auto i) {
                    ordered_idx(i) = forward_sweep[i]
                                         ? ordered_access_idx[i]
                                         : ordered_access_lengths[i] - 1 - ordered_access_idx[i];
                });

                return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                       src_scalar_per_access;
            }();
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();

        static_for<0, num_accesses, 1>{}([&](auto idx_1d) {
            typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;

            using src_vector_t =
                typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
            constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);

            const bool is_src_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
@@ -555,38 +307,13 @@ struct ThreadwiseTensorSliceTransfer_v2
                dst_buf(Number<dst_offset>{}) = src_vector.template AsType<SrcData>()[i];
            });

            constexpr auto move_on_dim = [&]() constexpr
            if constexpr(idx_1d.value != num_accesses - 1)
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
                move_tensor_coordinate(
                    src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
            }
            ();

            // move
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            src_desc, src_coord_, src_forward_steps[dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            src_desc, src_coord_, src_backward_steps[dim_access_order[i]]);
                    }
                }
            });
        });

        // move src coordinate back to slice origin (or not)
@@ -599,82 +326,20 @@ struct ThreadwiseTensorSliceTransfer_v2
        }
    }

    template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstSliceOriginIdx&,
                        DstBuffer& dst_buf)
    {
        constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();

        constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};

        constexpr auto src_step_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

        Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_step_hacks);
    }

    __device__ static constexpr auto GetSrcCoordinateResetStep()
    {
        constexpr auto I0 = Number<0>{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(src_scalar_per_access)>>;

        constexpr auto dim_access_order = DimAccessOrder{};
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();
        constexpr auto reset_step =
            SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, Number<0>{});

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        // judge move forward or move backward during the last iteration
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_access_lengths[I0] - 1;

                static_for<1, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate src data index after last iteration in Run(), if it has not been reset by
        // RunWrite()
        constexpr auto src_data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                   src_scalar_per_access;
        }();

        //
        constexpr auto reset_src_data_step = [&]() {
            Index reset_src_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });

            return reset_src_data_step_;
        }();

        return reset_src_data_step;
        return reset_step;
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
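Editor's note: for reference, the curve's template parameters, as used throughout this diff, are the slice lengths, the dimension visit order, and the per-dimension scalars-per-access. Assuming GetNumOfAccess() counts vector accesses (consistent with the loops above and below), a hypothetical instantiation:

    // Hypothetical instantiation; Sequence and SpaceFillingCurve are the ck::
    // utilities used in the diff.
    using Curve = SpaceFillingCurve<Sequence<4, 8>,   // SliceLengths
                                    Sequence<0, 1>,   // DimAccessOrder
                                    Sequence<1, 4>>;  // scalars per access (vector dim = 1)

    // 4*8 scalars moved 4 at a time -> 8 vector accesses (assumed semantics).
    static_assert(Curve::GetNumOfAccess() == 8, "");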
@@ -5,6 +5,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "static_tensor.hpp"
#include "tensor_space_filling_curve.hpp"

namespace ck {

@@ -123,73 +124,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    SrcDimAccessOrder,
                                                    remove_cv_t<decltype(src_scalar_per_access)>>;

        constexpr auto src_dim_access_order = SrcDimAccessOrder{};

        constexpr auto ordered_src_access_lengths =
            container_reorder_given_new2old(src_access_lengths, src_dim_access_order);

        // make forward steps
        const auto src_forward_steps = generate_tuple(
            [&](auto i) {
                Index forward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(src_desc, forward_step_idx);
            },
            Number<nDim>{});

        // make backward steps
        const auto src_backward_steps = generate_tuple(
            [&](auto i) {
                Index backward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(src_desc, backward_step_idx);
            },
            Number<nDim>{});
        // loop over space-filling curve
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();

        // loop over tensor and copy
        static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
            // judge move forward or move backward
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_src_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();

            // calculate src data index
            constexpr auto src_data_idx = [&]() {
                Index ordered_idx;

                static_for<0, nDim, 1>{}([&](auto i) {
                    ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
                                                      : ordered_src_access_lengths[i] - 1 -
                                                            ordered_src_access_idx[i];
                });

                return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
                       src_scalar_per_access;
            }();
        static_for<0, num_accesses, 1>{}([&](auto idx_1d) {
            constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);

            constexpr auto src_data_idx_seq = generate_sequence_v2(
                [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
@@ -218,39 +162,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                .template SetAsType<src_vector_t>(
                    src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);

            constexpr auto move_on_dim = [&]() constexpr
            // move coordinate
            if constexpr(idx_1d.value != num_accesses - 1)
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &=
                            ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
                move_tensor_coordinate(
                    src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
            }
            ();

            // move src coord
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
                    }
                }
            });
        });

        // move src coordinate back to slice origin (or not)
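Editor's note: v3r1 stages the slice in a static (register) tensor, so the access index must be a compile-time Sequence rather than a runtime array; that is what the generate_sequence_v2 call above produces. A minimal sketch of the same idea in standard C++, with std::integer_sequence standing in for ck::Sequence:

    #include <utility>

    // Lift the elements of a constexpr array into a compile-time integer
    // sequence, the way generate_sequence_v2 lifts src_data_idx into
    // src_data_idx_seq above. The array contents are hypothetical.
    constexpr int idx[2] = {1, 4};

    template <std::size_t... Is>
    constexpr auto lift(std::index_sequence<Is...>)
    {
        return std::integer_sequence<int, idx[Is]...>{};
    }

    using IdxSeq = decltype(lift(std::make_index_sequence<2>{})); // integer_sequence<int, 1, 4>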
@@ -374,73 +292,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DstDimAccessOrder,
                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;

        constexpr auto dst_dim_access_order = DstDimAccessOrder{};

        constexpr auto ordered_dst_access_lengths =
            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

        // make forward steps
        const auto dst_forward_steps = generate_tuple(
            [&](auto i) {
                Index forward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(dst_desc, forward_step_idx);
            },
            Number<nDim>{});

        // make backward steps
        const auto dst_backward_steps = generate_tuple(
            [&](auto i) {
                Index backward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(dst_desc, backward_step_idx);
            },
            Number<nDim>{});
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();

        // loop over tensor and copy
        static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
            // judge move forward or move backward
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_dst_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();

            // calculate dst data index
            constexpr auto dst_data_idx = [&]() {
                Index ordered_idx;

                static_for<0, nDim, 1>{}([&](auto i) {
                    ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
                                                      : ordered_dst_access_lengths[i] - 1 -
                                                            ordered_dst_access_idx[i];
                });

                return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
                       dst_scalar_per_access;
            }();
        static_for<0, num_accesses, 1>{}([&](auto idx_1d) {
            constexpr auto dst_data_idx = SpaceFillingCurve::GetIndex(idx_1d);

            constexpr auto dst_data_idx_seq = generate_sequence_v2(
                [&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
@@ -470,39 +330,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                is_dst_valid,
                dst_vector_container.template AsType<dst_vector_t>()[I0]);

            constexpr auto move_on_dim = [&]() constexpr
            // move coordinate
            if constexpr(idx_1d.value != num_accesses - 1)
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &=
                            ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
                move_tensor_coordinate(
                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
            }
            ();

            // move dst coord
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
                    }
                }
            });
        });

        // move dst coordinate back to slice origin (or not)
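Editor's note: the move_on_dim predicate removed in these hunks is an odometer carry rule: coordinate i advances only when it is not at its last position and every faster-varying dimension j > i has just reached its end. A scalar model with a hypothetical 3-digit odometer:

    #include <cassert>

    int main()
    {
        const int len[3] = {2, 2, 3};
        const int idx[3] = {0, 1, 2}; // current ordered access index

        bool move[3];
        for(int i = 0; i < 3; ++i)
        {
            move[i] = idx[i] < len[i] - 1;
            for(int j = i + 1; j < 3; ++j)
                move[i] = move[i] && (idx[j] == len[j] - 1);
        }
        // Dims 1 and 2 sit at their last positions, so only dim 0 carries.
        assert(move[0] && !move[1] && !move[2]);
    }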
@@ -522,55 +356,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    SrcDimAccessOrder,
                                                    remove_cv_t<decltype(src_scalar_per_access)>>;

        constexpr auto src_dim_access_order = SrcDimAccessOrder{};
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();
        constexpr auto reset_step =
            SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, Number<0>{});

        constexpr auto ordered_src_access_lengths =
            container_reorder_given_new2old(src_access_lengths, src_dim_access_order);

        // judge move forward or move backward during the last iteration
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_src_access_lengths[I0] - 1;

                static_for<1, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate src data index after last iteration in RunRead(), if it has not been reset by
        // RunRead()
        constexpr auto src_data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
                   src_scalar_per_access;
        }();

        //
        constexpr auto reset_src_data_step = [&]() {
            Index reset_src_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });

            return reset_src_data_step_;
        }();

        return reset_src_data_step;
        return reset_step;
    }

    __device__ static constexpr auto GetDstCoordinateResetStep()
@@ -580,55 +374,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DstDimAccessOrder,
                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;

        constexpr auto dst_dim_access_order = DstDimAccessOrder{};
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();
        constexpr auto reset_step =
            SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, Number<0>{});

        constexpr auto ordered_dst_access_lengths =
            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

        // judge move forward or move backward during the last iteration
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_dst_access_lengths[I0] - 1;

                static_for<1, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate dst data index after last iteration in RunWrite(), if it has not been reset by
        // RunWrite()
        constexpr auto dst_data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
                   dst_scalar_per_access;
        }();

        //
        constexpr auto reset_dst_data_step = [&]() {
            Index reset_dst_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });

            return reset_dst_data_step_;
        }();

        return reset_dst_data_step;
        return reset_step;
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
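Editor's note: v6r1 below collapses the explicit Set/AtomicAdd branches into a single dst_buf.template Update<DstInMemOp, ...> call. A hedged sketch of what such a dispatch looks like; the enum values mirror the diff, but this is an illustration, not the actual DynamicBuffer implementation:

    // Sketch of an Update-style dispatch over InMemoryDataOperationEnum_t.
    template <InMemoryDataOperationEnum_t Op, typename X, typename Buf>
    __device__ void update(Buf& buf, index_t offset, bool is_valid, const X& x)
    {
        if constexpr(Op == InMemoryDataOperationEnum_t::Set)
        {
            buf.template Set<X>(offset, is_valid, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum_t::AtomicAdd)
        {
            buf.template AtomicAdd<X>(offset, is_valid, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum_t::Add)
        {
            // read-modify-write, as in the v1r3 Add path above; assumes X
            // supports operator+ (the real code adds lane by lane)
            auto tmp = buf.template Get<X>(offset, is_valid);
            buf.template Set<X>(offset, is_valid, x + tmp);
        }
    }

Per the commit message, the Add path of the new DynamicBuffer transfer is not yet tested.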
@@ -4,6 +4,7 @@
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"

namespace ck {

@@ -40,9 +41,6 @@ struct ThreadwiseTensorSliceTransfer_v6r1
    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

    static constexpr auto I0 = Number<0>{};

    __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
@@ -79,70 +77,14 @@ struct ThreadwiseTensorSliceTransfer_v6r1
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(scalar_per_access)>>;

        constexpr auto dim_access_order = DimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        auto make_forward_steps = [&](auto desc) {
            return generate_tuple(
                [&](auto i) {
                    Index forward_step_idx;

                    static_for<0, nDim, 1>{}([&](auto j) {
                        forward_step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0;
                    });

                    return make_tensor_coordinate_step(desc, forward_step_idx);
                },
                Number<nDim>{});
        };

        auto make_backward_steps = [&](auto desc) {
            return generate_tuple(
                [&](auto i) {
                    Index backward_step_idx;

                    static_for<0, nDim, 1>{}([&](auto j) {
                        backward_step_idx(j) = (i.value == j.value) ? -scalar_per_access[i] : 0;
                    });

                    return make_tensor_coordinate_step(desc, backward_step_idx);
                },
                Number<nDim>{});
        };

        // make forward steps
        const auto src_forward_steps = make_forward_steps(src_desc);
        const auto dst_forward_steps = make_forward_steps(dst_desc);

        // make backward steps
        const auto src_backward_steps = make_backward_steps(src_desc);
        const auto dst_backward_steps = make_backward_steps(dst_desc);

        // loop over slice window
        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
            // judge move forward or move backward
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();
        // loop over space-filling curve
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();

        static_for<0, num_accesses, 1>{}([&](auto idx_1d) {
            using src_vector_type = vector_type_maker_t<SrcData, ScalarPerVector>;
            using src_vector_t = typename src_vector_type::type;

@@ -168,59 +110,20 @@ struct ThreadwiseTensorSliceTransfer_v6r1
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            // copy data from dst_vector into dst_buf
            if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set)
            {
                dst_buf.template Set<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector_container.template AsType<dst_vector_t>()[I0]);
            }
            else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
            {
                dst_buf.template AtomicAdd<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector_container.template AsType<dst_vector_t>()[I0]);
            }

            constexpr auto move_on_dim = [&]() constexpr
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
            }
            ();
            dst_buf.template Update<DstInMemOp, dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_vector_container.template AsType<dst_vector_t>()[I0]);

            // move coordinate
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            src_desc, src_coord_, src_forward_steps[dim_access_order[i]]);

                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            src_desc, src_coord_, src_backward_steps[dim_access_order[i]]);

                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
                    }
                }
            });
            if constexpr(idx_1d.value != num_accesses - 1)
            {
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
                move_tensor_coordinate(
                    src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
                move_tensor_coordinate(
                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
            }
        });

        // move coordinate back to slice origin (or not)
@@ -243,59 +146,18 @@ struct ThreadwiseTensorSliceTransfer_v6r1

    __device__ static constexpr auto GetCoordinateResetStep()
    {
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(scalar_per_access)>>;

        constexpr auto dim_access_order = DimAccessOrder{};
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();
        constexpr auto reset_step =
            SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, Number<0>{});

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        // judge move forward or move backward during the last iteration
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_access_lengths[I0] - 1;

                static_for<1, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate data index after last iteration in Run(), if it has not been reset
        constexpr auto data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                   scalar_per_access;
        }();

        //
        constexpr auto reset_data_step = [&]() {
            Index reset_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_data_step_(i) = -data_idx[i]; });

            return reset_data_step_;
        }();

        return reset_data_step;
        return reset_step;
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
@@ -332,7 +194,7 @@ struct ThreadwiseTensorSliceTransfer_v6r1
    SrcCoord src_coord_;
    DstCoord dst_coord_;
    const ElementwiseOperation element_op_;
};
}; // namespace ck

} // namespace ck
#endif

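Editor's note: the transfers stage each access in a register vector built by vector_type_maker and view it either as scalars or as one machine vector through AsType. A simplified model of that dual view, for illustration only (the real ck helper is more general and also supports multi-vector views):

    // Rough model of vector_type_maker<T, N>: storage viewable as N scalars
    // or as a single N-wide vector. ext_vector_type is a clang/HIP extension.
    template <typename T, int N>
    struct vector_storage
    {
        using vector_t = T __attribute__((ext_vector_type(N)));
        union
        {
            T scalars[N];
            vector_t vec;
        } data;
    };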
@@ -4,6 +4,7 @@
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"

namespace ck {

@@ -44,10 +45,6 @@ struct ThreadwiseTensorSliceTransfer_v6r2
    using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{}));
    using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    using Src0CoordStep = decltype(make_tensor_coordinate_step(Src0Desc{}, Index{}));
    using Src1CoordStep = decltype(make_tensor_coordinate_step(Src1Desc{}, Index{}));
    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

    static constexpr auto I0 = Number<0>{};

    __device__ constexpr ThreadwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
@@ -96,72 +93,14 @@ struct ThreadwiseTensorSliceTransfer_v6r2
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(scalar_per_access)>>;

        constexpr auto dim_access_order = DimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        auto make_forward_steps = [&](auto desc) {
            return generate_tuple(
                [&](auto i) {
                    Index forward_step_idx;

                    static_for<0, nDim, 1>{}([&](auto j) {
                        forward_step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0;
                    });

                    return make_tensor_coordinate_step(desc, forward_step_idx);
                },
                Number<nDim>{});
        };

        auto make_backward_steps = [&](auto desc) {
            return generate_tuple(
                [&](auto i) {
                    Index backward_step_idx;

                    static_for<0, nDim, 1>{}([&](auto j) {
                        backward_step_idx(j) = (i.value == j.value) ? -scalar_per_access[i] : 0;
                    });

                    return make_tensor_coordinate_step(desc, backward_step_idx);
                },
                Number<nDim>{});
        };

        // make forward steps
        const auto src0_forward_steps = make_forward_steps(src0_desc);
        const auto src1_forward_steps = make_forward_steps(src1_desc);
        const auto dst_forward_steps = make_forward_steps(dst_desc);

        // make backward steps
        const auto src0_backward_steps = make_backward_steps(src0_desc);
        const auto src1_backward_steps = make_backward_steps(src1_desc);
        const auto dst_backward_steps = make_backward_steps(dst_desc);

        // loop over slice window
        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
            // judge move forward or move backward
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();

        // loop over space-filling curve
        static_for<0, num_accesses, 1>{}([&](auto idx_1d) {
            using src0_vector_type = vector_type_maker_t<Src0Data, ScalarPerVector>;
            using src0_vector_t = typename src0_vector_type::type;

@@ -197,65 +136,22 @@ struct ThreadwiseTensorSliceTransfer_v6r2
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            // copy data from dst_vector into dst_buf
            if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set)
            {
                dst_buf.template Set<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector_container.template AsType<dst_vector_t>()[I0]);
            }
            else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
            {
                dst_buf.template AtomicAdd<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector_container.template AsType<dst_vector_t>()[I0]);
            }

            constexpr auto move_on_dim = [&]() constexpr
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
            }
            ();
            dst_buf.template Update<DstInMemOp, dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_vector_container.template AsType<dst_vector_t>()[I0]);

            // move coordinate
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            src0_desc, src0_coord_, src0_forward_steps[dim_access_order[i]]);

                        move_tensor_coordinate(
                            src1_desc, src1_coord_, src1_forward_steps[dim_access_order[i]]);

                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            src0_desc, src0_coord_, src0_backward_steps[dim_access_order[i]]);

                        move_tensor_coordinate(
                            src1_desc, src1_coord_, src1_backward_steps[dim_access_order[i]]);

                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
                    }
                }
            });
            if constexpr(idx_1d.value != num_accesses - 1)
            {
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
                move_tensor_coordinate(
                    src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step));
                move_tensor_coordinate(
                    src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step));
                move_tensor_coordinate(
                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
            }
        });

        // move coordinate back to slice origin (or not)
@@ -286,59 +182,18 @@ struct ThreadwiseTensorSliceTransfer_v6r2

    __device__ static constexpr auto GetCoordinateResetStep()
    {
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(scalar_per_access)>>;

        constexpr auto dim_access_order = DimAccessOrder{};
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();
        constexpr auto reset_step =
            SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, Number<0>{});

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        // judge move forward or move backward during the last iteration
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_access_lengths[I0] - 1;

                static_for<1, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate data index after last iteration in Run(), if it has not been reset
        constexpr auto data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                   scalar_per_access;
        }();

        //
        constexpr auto reset_data_step = [&]() {
            Index reset_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_data_step_(i) = -data_idx[i]; });

            return reset_data_step_;
        }();

        return reset_data_step;
        return reset_step;
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
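Editor's note: v6r3 below extends the same pattern to three source buffers. Because every buffer shares SliceLengths and DimAccessOrder, a single curve step drives all coordinates; only the conversion to a concrete coordinate step is per-descriptor. In sketch form, mirroring the Run() below:

    if constexpr(idx_1d.value != num_accesses - 1)
    {
        constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

        // same curve step, one coordinate per buffer
        move_tensor_coordinate(src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step));
        move_tensor_coordinate(src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step));
        move_tensor_coordinate(src2_desc, src2_coord_, make_tensor_coordinate_step(src2_desc, forward_step));
        move_tensor_coordinate(dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
    }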
@@ -4,6 +4,7 @@
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "tensor_space_filling_curve.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -48,11 +49,6 @@ struct ThreadwiseTensorSliceTransfer_v6r3
|
||||
using Src2Coord = decltype(make_tensor_coordinate(Src2Desc{}, Index{}));
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
using Src0CoordStep = decltype(make_tensor_coordinate_step(Src0Desc{}, Index{}));
|
||||
using Src1CoordStep = decltype(make_tensor_coordinate_step(Src1Desc{}, Index{}));
|
||||
using Src2CoordStep = decltype(make_tensor_coordinate_step(Src2Desc{}, Index{}));
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
|
||||
@@ -112,74 +108,14 @@ struct ThreadwiseTensorSliceTransfer_v6r3
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(scalar_per_access)>>;

        constexpr auto dim_access_order = DimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        auto make_forward_steps = [&](auto desc) {
            return generate_tuple(
                [&](auto i) {
                    Index forward_step_idx;

                    static_for<0, nDim, 1>{}([&](auto j) {
                        forward_step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0;
                    });

                    return make_tensor_coordinate_step(desc, forward_step_idx);
                },
                Number<nDim>{});
        };

        auto make_backward_steps = [&](auto desc) {
            return generate_tuple(
                [&](auto i) {
                    Index backward_step_idx;

                    static_for<0, nDim, 1>{}([&](auto j) {
                        backward_step_idx(j) = (i.value == j.value) ? -scalar_per_access[i] : 0;
                    });

                    return make_tensor_coordinate_step(desc, backward_step_idx);
                },
                Number<nDim>{});
        };

        // make forward steps
        const auto src0_forward_steps = make_forward_steps(src0_desc);
        const auto src1_forward_steps = make_forward_steps(src1_desc);
        const auto src2_forward_steps = make_forward_steps(src2_desc);
        const auto dst_forward_steps = make_forward_steps(dst_desc);

        // make backward steps
        const auto src0_backward_steps = make_backward_steps(src0_desc);
        const auto src1_backward_steps = make_backward_steps(src1_desc);
        const auto src2_backward_steps = make_backward_steps(src2_desc);
        const auto dst_backward_steps = make_backward_steps(dst_desc);

        // loop over slice window
        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
            // judge move forward or move backward
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();

        // loop over space-filling curve
        static_for<0, num_accesses, 1>{}([&](auto idx_1d) {
            using src0_vector_type = vector_type_maker_t<Src0Data, ScalarPerVector>;
            using src0_vector_t = typename src0_vector_type::type;

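The deleted forward_sweep block encodes the snake-order rule that the space-filling curve now owns: flatten the ordered indices of all dimensions slower than dimension i into one counter, and sweep dimension i forward exactly when that counter is even. A minimal host-side sketch of the same parity rule, using plain arrays in place of StaticallyIndexedArray (the names here are illustrative, not the library's):

    #include <cstdio>

    // Sketch only: the parity rule behind forward_sweep.
    // lengths/idx are the ordered access lengths and the current ordered index.
    void forward_sweep(const int* lengths, const int* idx, bool* sweep, int ndim)
    {
        sweep[0] = true; // the slowest dimension always sweeps forward
        for(int i = 1; i < ndim; ++i)
        {
            // flatten the indices of all dimensions slower than i
            int tmp = idx[0];
            for(int j = 1; j < i; ++j)
                tmp = tmp * lengths[j] + idx[j];
            // even counter -> forward pass, odd counter -> backward pass
            sweep[i] = (tmp % 2 == 0);
        }
    }

    int main()
    {
        const int lengths[2] = {2, 3};
        for(int i0 = 0; i0 < 2; ++i0)
        {
            const int idx[2] = {i0, 0};
            bool sweep[2];
            forward_sweep(lengths, idx, sweep, 2);
            std::printf("row %d: fast dim sweeps %s\n", i0, sweep[1] ? "forward" : "backward");
        }
        return 0;
    }

With lengths {2, 3}, row 0 sweeps the fast dimension forward and row 1 sweeps it backward, which is exactly the zig-zag traversal the new loop over the SpaceFillingCurve reproduces.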
@@ -224,72 +160,24 @@ struct ThreadwiseTensorSliceTransfer_v6r3
            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            // copy data from dst_vector into dst_buf
            if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set)
            {
                dst_buf.template Set<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector_container.template AsType<dst_vector_t>()[I0]);
            }
            else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
            {
                dst_buf.template AtomicAdd<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector_container.template AsType<dst_vector_t>()[I0]);
            }

            constexpr auto move_on_dim = [&]() constexpr
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
            }
            ();
            dst_buf.template Update<DstInMemOp, dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_vector_container.template AsType<dst_vector_t>()[I0]);

            // move coordinate
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            src0_desc, src0_coord_, src0_forward_steps[dim_access_order[i]]);

                        move_tensor_coordinate(
                            src1_desc, src1_coord_, src1_forward_steps[dim_access_order[i]]);

                        move_tensor_coordinate(
                            src2_desc, src2_coord_, src2_forward_steps[dim_access_order[i]]);

                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            src0_desc, src0_coord_, src0_backward_steps[dim_access_order[i]]);

                        move_tensor_coordinate(
                            src1_desc, src1_coord_, src1_backward_steps[dim_access_order[i]]);

                        move_tensor_coordinate(
                            src2_desc, src2_coord_, src2_backward_steps[dim_access_order[i]]);

                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
                    }
                }
            });
            if constexpr(idx_1d.value != num_accesses - 1)
            {
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
                move_tensor_coordinate(
                    src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step));
                move_tensor_coordinate(
                    src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step));
                move_tensor_coordinate(
                    src2_desc, src2_coord_, make_tensor_coordinate_step(src2_desc, forward_step));
                move_tensor_coordinate(
                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
            }
        });

        // move coordinate back to slice origin (or not)

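Taken together, the two hunks above replace an nDim-dimensional static_ford loop (plus per-dimension forward/backward step tables and the move_on_dim/forward_sweep bookkeeping) with a single 1D loop over accesses that asks the curve for each step. Schematically, the new control flow looks like this (process() is only a placeholder for the per-access load/element-op/Update body, not a function in the diff):

    // Schematic sketch of the refactored Run() loop.
    static_for<0, num_accesses, 1>{}([&](auto idx_1d) {
        process(idx_1d); // read src0/src1/src2 vectors, combine, Update dst

        // between accesses, advance every coordinate by the curve's step
        if constexpr(idx_1d.value != num_accesses - 1)
        {
            constexpr auto step = SpaceFillingCurve::GetForwardStep(idx_1d);
            move_tensor_coordinate(dst_desc, dst_coord_,
                                   make_tensor_coordinate_step(dst_desc, step));
            // ... likewise for src0_coord_, src1_coord_, src2_coord_
        }
    });

Because consecutive accesses on the curve differ in exactly one dimension, each step is a cheap single-dimension coordinate update, and the sweep-direction bookkeeping disappears entirely.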
@@ -328,59 +216,18 @@ struct ThreadwiseTensorSliceTransfer_v6r3

    __device__ static constexpr auto GetCoordinateResetStep()
    {
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(scalar_per_access)>>;

        constexpr auto dim_access_order = DimAccessOrder{};
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();
        constexpr auto reset_step =
            SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, Number<0>{});

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        // judge move forward or move backward during the last iteration
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_access_lengths[I0] - 1;

                static_for<1, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate data index after the last iteration in Run(), if it has not been reset
        constexpr auto data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                   scalar_per_access;
        }();

        //
        constexpr auto reset_data_step = [&]() {
            Index reset_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_data_step_(i) = -data_idx[i]; });

            return reset_data_step_;
        }();

        return reset_data_step;
        return reset_step;
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason

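Both the removed and the new code compute the same thing: the step that moves a coordinate from wherever the last access of Run() left it back to the slice origin. The new one-liner reads it off the curve directly, since

    reset_step = GetStepBetween(num_accesses - 1, 0) = idx(0) - idx(num_accesses - 1),

i.e. the negated data index of the final access. As a small worked example (hypothetical sizes, not from this diff): for SliceLengths = <2, 4>, scalar_per_access = <1, 2> and in-order access, the access lengths are <2, 2>, the snake visits data indices (0, 0), (0, 2), (1, 2), (1, 0), and the reset step is (0, 0) - (1, 0) = (-1, 0).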
@@ -3,6 +3,7 @@

#include "amd_buffer_addressing.hpp"
#include "c_style_pointer_cast.hpp"
#include "config.hpp"
#include "enable_if.hpp"

namespace ck {
@@ -108,6 +109,30 @@ struct DynamicBuffer
        }
    }

    template <InMemoryDataOperationEnum_t Op,
              typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void Update(index_t i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == InMemoryDataOperationEnum_t::Set)
        {
            this->template Set<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum_t::AtomicAdd)
        {
            this->template AtomicAdd<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum_t::Add)
        {
            auto tmp = this->template Get<X>(i, is_valid_element);
            this->template Set<X>(i, is_valid_element, x + tmp);
            // tmp += x;
            // this->template Set<X>(i, is_valid_element, tmp);
        }
    }

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,

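The new Update member folds the three in-memory operations behind one compile-time switch, so callers such as ThreadwiseTensorSliceTransfer_v6r3::Run above no longer branch on DstInMemOp themselves. A standalone sketch of the same if-constexpr dispatch pattern, with a plain array standing in for the device buffer (hypothetical names; the real AtomicAdd maps to a hardware atomic that this host sketch only annotates):

    enum class MemOp { Set, AtomicAdd, Add };

    // Sketch of compile-time dispatch over the memory operation.
    template <MemOp Op, typename T>
    void update(T* buf, int i, bool is_valid, T x)
    {
        if(!is_valid)
            return;

        if constexpr(Op == MemOp::Set)
        {
            buf[i] = x; // overwrite
        }
        else if constexpr(Op == MemOp::AtomicAdd)
        {
            buf[i] += x; // placeholder for a hardware atomic add
        }
        else if constexpr(Op == MemOp::Add)
        {
            T tmp = buf[i];   // Get
            buf[i] = tmp + x; // Set of the accumulated value
        }
    }

Only the selected branch is instantiated, so a combination that would not compile (say, an atomic add on an unsupported type) is harmless as long as it is never chosen.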
@@ -1,5 +1,9 @@
#ifndef TENSOR_SPACE_FILLING_CURVE_HPP
#define TENSOR_SPACE_FILLING_CURVE_HPP

#include "math.hpp"
#include "sequence.hpp"
#include "sequence_helper.hpp"
#include "tensor_adaptor.hpp"
#include "statically_indexed_array_multi_index.hpp"
#include "tuple_helper.hpp"
@@ -37,13 +41,25 @@ struct SpaceFillingCurve
               ScalarPerVector;
    }

    template <index_t AccessIdx1dBegin, index_t AccessIdx1dEnd>
    static __device__ __host__ constexpr auto GetStepBetween(Number<AccessIdx1dBegin>,
                                                             Number<AccessIdx1dEnd>)
    {
        static_assert(AccessIdx1dBegin >= 0, "1D index should be non-negative");
        static_assert(AccessIdx1dBegin < GetNumOfAccess(),
                      "1D index should be smaller than the number of accesses");
        static_assert(AccessIdx1dEnd >= 0, "1D index should be non-negative");
        static_assert(AccessIdx1dEnd < GetNumOfAccess(),
                      "1D index should be smaller than the number of accesses");

        constexpr auto idx_begin = GetIndex(Number<AccessIdx1dBegin>{});
        constexpr auto idx_end = GetIndex(Number<AccessIdx1dEnd>{});
        return idx_end - idx_begin;
    }

    template <index_t AccessIdx1d>
    static __device__ __host__ constexpr auto GetForwardStep(Number<AccessIdx1d>)
    {

        constexpr auto idx_curr = GetIndex(Number<AccessIdx1d>{});
        constexpr auto idx_next = GetIndex(Number<AccessIdx1d + 1>{});
        return idx_next - idx_curr;
        static_assert(AccessIdx1d < GetNumOfAccess(),
                      "1D index should be smaller than the number of accesses");
        return GetStepBetween(Number<AccessIdx1d>{}, Number<AccessIdx1d + 1>{});
    }

    template <index_t AccessIdx1d>
@@ -51,9 +67,7 @@ struct SpaceFillingCurve
    {
        static_assert(AccessIdx1d > 0, "1D index should be larger than 0");

        constexpr auto idx_curr = GetIndex(Number<AccessIdx1d>{});
        constexpr auto idx_prev = GetIndex(Number<AccessIdx1d - 1>{});
        return idx_prev - idx_curr;
        return GetStepBetween(Number<AccessIdx1d>{}, Number<AccessIdx1d - 1>{});
    }

    template <index_t AccessIdx1d>
@@ -129,3 +143,4 @@ struct SpaceFillingCurve
};

} // namespace ck
#endif

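After this change, both directional helpers are thin wrappers over GetStepBetween: a step between two accesses is simply the difference of their multi-indices, so

    GetForwardStep(i)  = GetStepBetween(i, i + 1) = idx(i + 1) - idx(i)
    GetBackwardStep(i) = GetStepBetween(i, i - 1) = idx(i - 1) - idx(i)

and, by telescoping, the step across any span equals the sum of the per-access steps it covers: idx(b) - idx(a) = sum over k in [a, b) of (idx(k + 1) - idx(k)). This is what lets GetCoordinateResetStep above jump from the last access straight back to the first in a single move.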
@@ -41,8 +41,7 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization_t::MNPadding;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;

// clang-format off
#if 0

@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
    if(argc == 1)
    {
        init_method = 1;
        data_type = 0;
        data_type = 0;
    }
    else if(argc == 3)
    {

@@ -161,12 +161,11 @@ int main(int, char*[])
    if(pass)
    {
        std::cout << "test magic number division: Pass" << std::endl;
        return 0;
        return 0;
    }
    else
    {
        std::cout << "test magic number division: Fail" << std::endl;
        return -1;
        return -1;
    }
}

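For context, magic-number division replaces an integer division by a fixed divisor d with a multiply by a precomputed constant and a shift. A minimal self-contained illustration of the idea (not the library's implementation; this variant is exact for 16-bit numerators and divisors with a 32-bit magic constant):

    #include <cassert>
    #include <cstdint>

    // q = (n * m) >> 32 with m = ceil(2^32 / d).
    // For n, d < 2^16 the rounding error n * (m*d - 2^32) / 2^32 stays below 1,
    // so the shifted product equals n / d exactly.
    int main()
    {
        for(uint32_t d = 1; d < (1u << 16); d += 97)
        {
            const uint64_t m = (0xFFFFFFFFull / d) + 1; // ceil(2^32 / d)
            for(uint32_t n = 0; n < (1u << 16); n += 89)
            {
                const uint32_t q = static_cast<uint32_t>((n * m) >> 32);
                assert(q == n / d);
            }
        }
        return 0;
    }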
@@ -29,9 +29,9 @@ void traverse_using_space_filling_curve()
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    using TensorLengths = Sequence<4, 10, 9>;
    using TensorLengths = Sequence<16, 10, 9>;
    using DimAccessOrder = Sequence<2, 0, 1>;
    using ScalarsPerAccess = Sequence<1, 2, 3>;
    using ScalarsPerAccess = Sequence<4, 2, 3>;
    using SpaceFillingCurve = SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess>;

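The enlarged test keeps the same number of accesses: the access lengths are the tensor lengths divided elementwise by the scalars per access,

    access_lengths = <16, 10, 9> / <4, 2, 3> = <4, 5, 3>,  num_accesses = 4 * 5 * 3 = 60,

but each access now covers a 4 x 2 x 3 block of scalars, which is why the updated dim-0 entries in the expected indices below advance in strides of 4.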
    constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
@@ -39,36 +39,36 @@ void traverse_using_space_filling_curve()
                                         make_tuple(0, 4, 0),
                                         make_tuple(0, 6, 0),
                                         make_tuple(0, 8, 0),
                                         make_tuple(1, 8, 0),
                                         make_tuple(1, 6, 0),
                                         make_tuple(1, 4, 0),
                                         make_tuple(1, 2, 0),
                                         make_tuple(1, 0, 0),
                                         make_tuple(2, 0, 0),
                                         make_tuple(2, 2, 0),
                                         make_tuple(2, 4, 0),
                                         make_tuple(2, 6, 0),
                                         make_tuple(2, 8, 0),
                                         make_tuple(3, 8, 0),
                                         make_tuple(3, 6, 0),
                                         make_tuple(3, 4, 0),
                                         make_tuple(3, 2, 0),
                                         make_tuple(3, 0, 0),
                                         make_tuple(3, 0, 3),
                                         make_tuple(3, 2, 3),
                                         make_tuple(3, 4, 3),
                                         make_tuple(3, 6, 3),
                                         make_tuple(3, 8, 3),
                                         make_tuple(2, 8, 3),
                                         make_tuple(2, 6, 3),
                                         make_tuple(2, 4, 3),
                                         make_tuple(2, 2, 3),
                                         make_tuple(2, 0, 3),
                                         make_tuple(1, 0, 3),
                                         make_tuple(1, 2, 3),
                                         make_tuple(1, 4, 3),
                                         make_tuple(1, 6, 3),
                                         make_tuple(1, 8, 3),
                                         make_tuple(4, 8, 0),
                                         make_tuple(4, 6, 0),
                                         make_tuple(4, 4, 0),
                                         make_tuple(4, 2, 0),
                                         make_tuple(4, 0, 0),
                                         make_tuple(8, 0, 0),
                                         make_tuple(8, 2, 0),
                                         make_tuple(8, 4, 0),
                                         make_tuple(8, 6, 0),
                                         make_tuple(8, 8, 0),
                                         make_tuple(12, 8, 0),
                                         make_tuple(12, 6, 0),
                                         make_tuple(12, 4, 0),
                                         make_tuple(12, 2, 0),
                                         make_tuple(12, 0, 0),
                                         make_tuple(12, 0, 3),
                                         make_tuple(12, 2, 3),
                                         make_tuple(12, 4, 3),
                                         make_tuple(12, 6, 3),
                                         make_tuple(12, 8, 3),
                                         make_tuple(8, 8, 3),
                                         make_tuple(8, 6, 3),
                                         make_tuple(8, 4, 3),
                                         make_tuple(8, 2, 3),
                                         make_tuple(8, 0, 3),
                                         make_tuple(4, 0, 3),
                                         make_tuple(4, 2, 3),
                                         make_tuple(4, 4, 3),
                                         make_tuple(4, 6, 3),
                                         make_tuple(4, 8, 3),
                                         make_tuple(0, 8, 3),
                                         make_tuple(0, 6, 3),
                                         make_tuple(0, 4, 3),
@@ -79,21 +79,21 @@ void traverse_using_space_filling_curve()
                                         make_tuple(0, 4, 6),
                                         make_tuple(0, 6, 6),
                                         make_tuple(0, 8, 6),
                                         make_tuple(1, 8, 6),
                                         make_tuple(1, 6, 6),
                                         make_tuple(1, 4, 6),
                                         make_tuple(1, 2, 6),
                                         make_tuple(1, 0, 6),
                                         make_tuple(2, 0, 6),
                                         make_tuple(2, 2, 6),
                                         make_tuple(2, 4, 6),
                                         make_tuple(2, 6, 6),
                                         make_tuple(2, 8, 6),
                                         make_tuple(3, 8, 6),
                                         make_tuple(3, 6, 6),
                                         make_tuple(3, 4, 6),
                                         make_tuple(3, 2, 6),
                                         make_tuple(3, 0, 6));
                                         make_tuple(4, 8, 6),
                                         make_tuple(4, 6, 6),
                                         make_tuple(4, 4, 6),
                                         make_tuple(4, 2, 6),
                                         make_tuple(4, 0, 6),
                                         make_tuple(8, 0, 6),
                                         make_tuple(8, 2, 6),
                                         make_tuple(8, 4, 6),
                                         make_tuple(8, 6, 6),
                                         make_tuple(8, 8, 6),
                                         make_tuple(12, 8, 6),
                                         make_tuple(12, 6, 6),
                                         make_tuple(12, 4, 6),
                                         make_tuple(12, 2, 6),
                                         make_tuple(12, 0, 6));

    constexpr index_t num_accesses = SpaceFillingCurve::GetNumOfAccess();

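The expected indices trace a nested snake: with DimAccessOrder <2, 0, 1>, dim 2 is slowest and dim 1 fastest, and each dimension reverses direction whenever a slower one advances, so consecutive accesses differ in exactly one dimension. A small host-side sketch that generates this order for arbitrary ordered lengths (an illustrative helper, not the library's GetIndex; the access-order permutation and scaling by ScalarsPerAccess are omitted):

    #include <cstdio>
    #include <vector>

    // Multi-index of the k-th access of a snake traversal over 'lengths',
    // slowest dimension first.
    std::vector<int> snake_index(int k, const std::vector<int>& lengths)
    {
        const int n = static_cast<int>(lengths.size());
        std::vector<int> digit(n), idx(n);

        // peel off plain odometer digits, fastest dimension last
        for(int i = n - 1; i >= 0; --i)
        {
            digit[i] = k % lengths[i];
            k /= lengths[i];
        }

        // reverse a digit whenever the flattened slower digits are odd
        int prefix = 0;
        for(int i = 0; i < n; ++i)
        {
            idx[i] = (prefix % 2 == 0) ? digit[i] : lengths[i] - 1 - digit[i];
            prefix = prefix * lengths[i] + digit[i];
        }
        return idx;
    }

    int main()
    {
        for(int k = 0; k < 10; ++k)
        {
            const auto idx = snake_index(k, {3, 4, 5}); // ordered access lengths
            std::printf("(%d, %d, %d)\n", idx[0], idx[1], idx[2]);
        }
        return 0;
    }

Permuting each ordered index back through DimAccessOrder and multiplying by ScalarsPerAccess reproduces the data indices listed above, e.g. the sixth access (0, 1, 4) maps to (1*4, 4*2, 0*3) = (4, 8, 0).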
@@ -69,7 +69,6 @@ struct gemmArgs
    int KBatch;
};

int test_gemm(const gemmArgs& args)
{
    bool a_row_major, b_row_major, c_row_major;
@@ -115,8 +114,10 @@ int test_gemm(const gemmArgs& args)

    Tensor<float> a_m_k(f_host_tensor_descriptor(args.M, args.K, args.StrideA, a_row_major));
    Tensor<float> b_k_n(f_host_tensor_descriptor(args.K, args.N, args.StrideB, b_row_major));
    Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
    Tensor<float> c_m_n_device_result(f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
    Tensor<float> c_m_n_host_result(
        f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
    Tensor<float> c_m_n_device_result(
        f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));

    // init data
    std::size_t num_thread = std::thread::hardware_concurrency();
@@ -205,7 +206,7 @@ int test_gemm(const gemmArgs& args)
    else
    {
        std::cout << "test split k: Fail " << std::endl;
        error_code = -1; // test needs to report failure
        error_code = -1; // test needs to report failure
    }
    return error_code;
}
@@ -221,17 +222,17 @@ int main(int argc, char* argv[])
    }
    else if(argc == 9)
    {
        const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
        const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));

        const int M = std::stoi(argv[2]);
        const int N = std::stoi(argv[3]);
        const int K = std::stoi(argv[4]);
        const int M = std::stoi(argv[2]);
        const int N = std::stoi(argv[3]);
        const int K = std::stoi(argv[4]);

        const int StrideA = std::stoi(argv[5]);
        const int StrideB = std::stoi(argv[6]);
        const int StrideC = std::stoi(argv[7]);
        const int KBatch = std::stoi(argv[8]);
        test_cases = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}};
        const int StrideA = std::stoi(argv[5]);
        const int StrideB = std::stoi(argv[6]);
        const int StrideC = std::stoi(argv[7]);
        const int KBatch = std::stoi(argv[8]);
        test_cases = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}};
    }
    else
    {
@@ -242,12 +243,11 @@ int main(int argc, char* argv[])
        printf("arg2 to 8: M, N, K, StrideA, StrideB, StrideC, KBatch\n");
        return -1;
    }
    for(const auto& kinder: test_cases)
    for(const auto& kinder : test_cases)
    {
        const auto res = test_gemm(kinder);
        if(res != 0)
            return -1;
            return -1;
    }
    return 0;
}