mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
Reorganize project folders (#6)
This commit is contained in:
18
include/ck_tile/core/README.md
Normal file
18
include/ck_tile/core/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# ck_tile/core #
|
||||
|
||||
`ck_tile/core` contains all the basic functions and structures needed to create a GPU kernel using `ck_tile`. Users should only include the single header `ck_tile/core.hpp` to use all the functionality. Everything is under the `ck_tile` namespace. The coding style under this folder should be similar to `std` (`snake_case` for structures/functions, CamelCase for template types...)
|
||||
|
||||
```
|
||||
algorithm/
|
||||
coordinate transform and some other reusable algorithm
|
||||
arch/
|
||||
contains some basic device building blocks like mma, buffer addressing, etc...
|
||||
container/
|
||||
contains basic container data structure, array/sequence/tuple/...
|
||||
numeric/
|
||||
data type, and data type related math
|
||||
tensor/
|
||||
tensor descriptors and tile level API
|
||||
utility/
|
||||
other utility function for both host/device
|
||||
```
|
||||
38
include/ck_tile/core/algorithm/cluster_descriptor.hpp
Normal file
38
include/ck_tile/core/algorithm/cluster_descriptor.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Lengths,
|
||||
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_cluster_descriptor(
|
||||
const Lengths& lengths,
|
||||
ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type{})
|
||||
{
|
||||
constexpr index_t ndim_low = Lengths::size();
|
||||
|
||||
const auto reordered_lengths = container_reorder_given_new2old(lengths, order);
|
||||
|
||||
const auto low_lengths = generate_tuple(
|
||||
[&](auto idim_low) { return reordered_lengths[idim_low]; }, number<ndim_low>{});
|
||||
|
||||
const auto transform = make_merge_transform(low_lengths);
|
||||
|
||||
constexpr auto low_dim_old_top_ids = ArrangeOrder{};
|
||||
|
||||
constexpr auto up_dim_new_top_ids = sequence<0>{};
|
||||
|
||||
return make_single_stage_tensor_adaptor(
|
||||
make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
1752
include/ck_tile/core/algorithm/coordinate_transform.hpp
Normal file
1752
include/ck_tile/core/algorithm/coordinate_transform.hpp
Normal file
File diff suppressed because it is too large
Load Diff
60
include/ck_tile/core/algorithm/indexing_adaptor.hpp
Normal file
60
include/ck_tile/core/algorithm/indexing_adaptor.hpp
Normal file
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// pre-defined indexing adaptor used for indexing(scatter/gather)
|
||||
|
||||
// this version caches the index inside thread registers (which is also preferred in real scenarios)
|
||||
// however it's the user's responsibility that each thread only provides one index, which means
|
||||
// move coordinate will not change on this dim
|
||||
template <typename IndexingType>
|
||||
struct indexing_adaptor_onshot_cached
|
||||
{
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached(const IndexingType& idx)
|
||||
: cached_idx_(idx)
|
||||
{
|
||||
}
|
||||
IndexingType cached_idx_;
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
|
||||
const UpIdx& /*idx_up*/) const
|
||||
{
|
||||
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_low(number<0>{}) = cached_idx_;
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff& idx_diff_up,
|
||||
LowIdx& /*idx_low*/,
|
||||
const UpIdx& /*idx_up*/) const
|
||||
{
|
||||
// TODO: nonthing changed here
|
||||
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
|
||||
UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
|
||||
|
||||
// pass the diff to lower, but not changing the actually index
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<IndexingType>::value;
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
168
include/ck_tile/core/algorithm/space_filling_curve.hpp
Normal file
168
include/ck_tile/core/algorithm/space_filling_curve.hpp
Normal file
@@ -0,0 +1,168 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TensorLengths,
|
||||
typename DimAccessOrder,
|
||||
typename ScalarsPerAccess,
|
||||
bool SnakeCurved = true> // # of scalars per access in each dimension
|
||||
struct space_filling_curve
|
||||
{
|
||||
static constexpr index_t TensorSize =
|
||||
reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{});
|
||||
static_assert(0 < TensorSize,
|
||||
"space_filling_curve should be used to access a non-empty tensor");
|
||||
|
||||
static constexpr index_t nDim = TensorLengths::size();
|
||||
|
||||
using Index = multi_index<nDim>;
|
||||
|
||||
static constexpr index_t ScalarPerVector =
|
||||
reduce_on_sequence(ScalarsPerAccess{}, multiplies{}, number<1>{});
|
||||
|
||||
static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{};
|
||||
static constexpr auto dim_access_order = DimAccessOrder{};
|
||||
static constexpr auto ordered_access_lengths =
|
||||
container_reorder_given_new2old(access_lengths, dim_access_order);
|
||||
|
||||
static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(ordered_access_lengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
|
||||
{
|
||||
static_assert(TensorLengths::size() == ScalarsPerAccess::size());
|
||||
static_assert(TensorLengths{} % ScalarsPerAccess{} ==
|
||||
typename uniform_sequence_gen<TensorLengths::size(), 0>::type{});
|
||||
|
||||
return reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}) / ScalarPerVector;
|
||||
}
|
||||
|
||||
template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_step_between(number<AccessIdx1dHead>,
|
||||
number<AccessIdx1dTail>)
|
||||
{
|
||||
static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(),
|
||||
"1D index out of range");
|
||||
static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(),
|
||||
"1D index out of range");
|
||||
|
||||
constexpr auto idx_head = get_index(number<AccessIdx1dHead>{});
|
||||
constexpr auto idx_tail = get_index(number<AccessIdx1dTail>{});
|
||||
return idx_tail - idx_head;
|
||||
}
|
||||
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number<AccessIdx1d>)
|
||||
{
|
||||
static_assert(AccessIdx1d < get_num_of_access(), "1D index should be larger than 0");
|
||||
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
|
||||
}
|
||||
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number<AccessIdx1d>)
|
||||
{
|
||||
static_assert(AccessIdx1d > 0, "1D index should be larger than 0");
|
||||
|
||||
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
|
||||
}
|
||||
|
||||
// Do not use this function directly!
|
||||
// TODO: can refactor into generic lambda in the future
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr Index _get_index(number<AccessIdx1d>)
|
||||
{
|
||||
#if 0
|
||||
/*
|
||||
* \todo: tensor_adaptor::calculate_bottom_index does NOT return constexpr as expected.
|
||||
*/
|
||||
constexpr auto ordered_access_idx = to_index_adaptor.calculate_bottom_index(make_multi_index(number<AccessIdx1d>{}));
|
||||
#else
|
||||
|
||||
constexpr auto access_strides =
|
||||
container_reverse_exclusive_scan(ordered_access_lengths, multiplies{}, number<1>{});
|
||||
|
||||
constexpr auto idx_1d = number<AccessIdx1d>{};
|
||||
// Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
|
||||
// idim-th element of multidimensional index.
|
||||
// All constexpr variables have to be captured by VALUE.
|
||||
constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
|
||||
{
|
||||
constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
|
||||
{
|
||||
auto res = idx_1d.value;
|
||||
auto id = 0;
|
||||
|
||||
static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
|
||||
id = res / access_strides[kdim].value;
|
||||
res -= id * access_strides[kdim].value;
|
||||
});
|
||||
|
||||
return id;
|
||||
};
|
||||
|
||||
constexpr auto id = compute_index_impl(idim);
|
||||
return number<id>{};
|
||||
};
|
||||
|
||||
constexpr auto ordered_access_idx = generate_tuple(compute_index, number<nDim>{});
|
||||
#endif
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
statically_indexed_array<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto idim) {
|
||||
index_t tmp = ordered_access_idx[I0];
|
||||
|
||||
static_for<1, idim, 1>{}(
|
||||
[&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });
|
||||
|
||||
forward_sweep_(idim) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate multi-dim tensor index
|
||||
auto idx_md = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
ordered_idx(idim) =
|
||||
!SnakeCurved || forward_sweep[idim]
|
||||
? ordered_access_idx[idim]
|
||||
: ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
|
||||
ScalarsPerAccess{};
|
||||
}();
|
||||
return idx_md;
|
||||
}
|
||||
|
||||
// FIXME: return tuple of number<>, which is compile time only variable
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_index(number<AccessIdx1d>)
|
||||
{
|
||||
constexpr auto idx = _get_index(number<AccessIdx1d>{});
|
||||
|
||||
return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
213
include/ck_tile/core/algorithm/static_encoding_pattern.hpp
Normal file
213
include/ck_tile/core/algorithm/static_encoding_pattern.hpp
Normal file
@@ -0,0 +1,213 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
 * @brief Enumeration describing static tile distribution patterns.
 */
enum struct tile_distribution_pattern
{
    /** @brief Thread raked pattern. */
    thread_raked,
    /** @brief Warp raked pattern. */
    warp_raked,
    /** @brief Block raked pattern - aka linear. */
    block_raked,
};
|
||||
|
||||
/**
 * @brief Empty base type that every TileDistributionEncodingPattern2D variant derives from.
 */
struct TileDistributionEncodingPattern
{
};
|
||||
|
||||
/**
|
||||
* @brief Class creating 2D static tile distribution with different load/store patterns.
|
||||
*
|
||||
* @note We always assume that Tile is YPerTile x XPerTile where X dim (rightmost)
|
||||
* is contiguous and we can do vector load on this dimension.
|
||||
*
|
||||
* @tparam BlockSize Number of threads in a workgroup.
|
||||
* @tparam YPerTile The tile size of outer/leftmost dimension.
|
||||
* @tparam XPerTile The tile size of inner/rightmost dimension (contiguous).
|
||||
* @tparam VecSize The vector access size.
|
||||
* @tparam DistributionPattern The enumeration describing used access pattern.
|
||||
*/
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
tile_distribution_pattern DistributionPattern>
|
||||
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
|
||||
{
|
||||
};
|
||||
|
||||
// Thread raked
|
||||
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::thread_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
{
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
|
||||
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
|
||||
// # of rows in Y dim accessed by single wavefront in one iteration
|
||||
static constexpr index_t Y1 = warp_size / X0;
|
||||
static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
|
||||
|
||||
static constexpr index_t Y0 = num_warps;
|
||||
// YPerWarp = YPerTile / Y0;
|
||||
// Y2 = YPerWarp / Y1;
|
||||
static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
|
||||
|
||||
static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * warp_ys * Y0 must cover whole workgroup!");
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{});
|
||||
}
|
||||
};
|
||||
|
||||
// Warp raked
|
||||
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::warp_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
{
|
||||
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
|
||||
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
|
||||
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
|
||||
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
|
||||
|
||||
static constexpr index_t Y0 = num_warps;
|
||||
static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
|
||||
|
||||
static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
// Block raked
|
||||
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::block_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
{
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
|
||||
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
|
||||
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
|
||||
static constexpr index_t Y1 = num_warps;
|
||||
static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
|
||||
static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
2691
include/ck_tile/core/arch/amd_buffer_addressing.hpp
Normal file
2691
include/ck_tile/core/arch/amd_buffer_addressing.hpp
Normal file
File diff suppressed because it is too large
Load Diff
2559
include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Normal file
2559
include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Normal file
File diff suppressed because it is too large
Load Diff
157
include/ck_tile/core/arch/arch.hpp
Normal file
157
include/ck_tile/core/arch/arch.hpp
Normal file
@@ -0,0 +1,157 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// Address Space for AMDGCN
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#address-space
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// std::underlying_type that can also be named for non-enum types:
//   safe_underlying_type<T, true>::type  is std::underlying_type_t<T>
//   safe_underlying_type<T, false>::type is void
// The general case yields void; only the IsEnum == true specialization touches
// std::underlying_type, so it is never instantiated for a non-enum T.
template <typename T, bool IsEnum>
struct safe_underlying_type
{
    using type = void;
};

template <typename T>
struct safe_underlying_type<T, true>
{
    using type = std::underlying_type_t<T>;
};

// Underlying type of T when T is an enum, void otherwise.
template <typename T>
using safe_underlying_type_t = typename safe_underlying_type<T, std::is_enum<T>::value>::type;
|
||||
|
||||
// Address space tags; enumerator values are consecutive, starting from generic = 0.
enum struct address_space_enum : std::uint16_t
{
    generic  = 0,
    global   = 1,
    lds      = 2,
    sgpr     = 3,
    constant = 4,
    vgpr     = 5
};
|
||||
|
||||
// Memory update operation tags; enumerator values are consecutive, starting from set = 0.
enum struct memory_operation_enum : std::uint16_t
{
    set        = 0,
    atomic_add = 1,
    atomic_max = 2,
    add        = 3
};
|
||||
|
||||
// Returns the warp size as an index_t.
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
{
    // `warpSize` is not declared in this header; it is defined by HIP
    return warpSize;
}
|
||||
|
||||
// Returns gridDim.x (x dimension only).
CK_TILE_DEVICE index_t get_grid_size()
{
    return gridDim.x;
}
|
||||
|
||||
// Returns blockDim.x (x dimension only).
CK_TILE_DEVICE index_t get_block_size()
{
    return blockDim.x;
}
|
||||
|
||||
// TODO: deprecate these
|
||||
// Returns threadIdx.x, the x index of the calling thread inside its block.
CK_TILE_DEVICE index_t get_thread_local_1d_id()
{
    return threadIdx.x;
}
|
||||
|
||||
// Returns the flat x index of the calling thread across the whole grid:
// blockIdx.x * blockDim.x + threadIdx.x.
CK_TILE_DEVICE index_t get_thread_global_1d_id()
{
    return blockIdx.x * blockDim.x + threadIdx.x;
}
|
||||
|
||||
// Returns blockIdx.x, the x index of the calling thread's block.
CK_TILE_DEVICE index_t get_block_1d_id()
{
    return blockIdx.x;
}
|
||||
|
||||
// Use these instead
|
||||
// Returns the value of __lane_id() for the calling thread.
CK_TILE_DEVICE index_t get_lane_id()
{
    return __lane_id();
}
|
||||
|
||||
// Returns threadIdx.x / get_warp_size(), i.e. the index of the calling thread's warp inside
// its block (x dimension only).
CK_TILE_DEVICE index_t get_warp_id()
{
    // NOTE(review): the quotient is routed through readfirstlane, presumably so that the
    // result is treated as a wave-uniform value - confirm
    return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
}
|
||||
|
||||
// Returns threadIdx.x, the x index of the calling thread inside its block.
CK_TILE_DEVICE index_t get_thread_id()
{
    return threadIdx.x;
}
|
||||
|
||||
// Returns blockIdx.x, the x index of the calling thread's block.
CK_TILE_DEVICE index_t get_block_id()
{
    return blockIdx.x;
}
|
||||
|
||||
// Block-level barrier used around LDS accesses.
// With CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM enabled, an explicit
// s_waitcnt + s_barrier pair is issued; otherwise this is a plain __syncthreads().
CK_TILE_DEVICE void block_sync_lds()
{
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
    // builtin form of the inline asm sequence
    //     s_waitcnt lgkmcnt(0)
    //     s_barrier
    __builtin_amdgcn_s_waitcnt(0xc07f);
    __builtin_amdgcn_s_barrier();
#else
    __syncthreads();
#endif
}
|
||||
|
||||
// Waits on the load counter with threshold `cnt`, then executes a barrier.
//   gfx12  : s_wait_loadcnt cnt, s_barrier_signal -1, s_barrier_wait -1
//   others : s_waitcnt vmcnt(cnt), s_barrier
// NOTE(review): `cnt` is bound with the "n" (immediate) constraint, so this presumably relies
// on the call being inlined with a compile-time constant argument - confirm.
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
    asm volatile("s_wait_loadcnt %0 \n"
                 "s_barrier_signal -1 \n"
                 "s_barrier_wait -1"
                 :
                 : "n"(cnt)
                 : "memory");
#else
    asm volatile("s_waitcnt vmcnt(%0) \n"
                 "s_barrier"
                 :
                 : "n"(cnt)
                 : "memory");
#endif
}
|
||||
|
||||
// Issues, in order: s_waitcnt vmcnt(0), s_waitcnt lgkmcnt(0), s_barrier.
// The asm text is one string literal spliced with backslash-newlines, so whitespace at the
// start of the continued lines would become part of the string; it is kept exactly as is.
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
    asm volatile("\
s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
}
|
||||
|
||||
// Emits an `s_nop cnt` instruction.
// NOTE(review): `cnt` is bound with the "n" (immediate) constraint, so this presumably relies
// on the call being inlined with a compile-time constant argument - confirm.
CK_TILE_DEVICE void s_nop(index_t cnt = 0)
{
#if 1
    asm volatile("s_nop %0" : : "n"(cnt) :);
#else
    // disabled alternative
    __builtin_amdgcn_sched_barrier(cnt);
#endif
}
|
||||
|
||||
// Type attribute placing the pointee in the address space whose number is the underlying
// value of address_space_enum::constant.
#define CK_CONSTANT_ADDRESS_SPACE                                  \
    __attribute__((address_space(                                  \
        static_cast<safe_underlying_type_t<address_space_enum>>(address_space_enum::constant))))
|
||||
|
||||
template <typename T>
|
||||
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
|
||||
{
|
||||
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
|
||||
// only c-style pointer cast seems be able to be compiled
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
return (T*)(p); // NOLINT(old-style-cast)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
|
||||
{
|
||||
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
|
||||
// only c-style pointer cast seems be able to be compiled;
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
458
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
Normal file
458
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
Normal file
@@ -0,0 +1,458 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T, typename ComputeType>
|
||||
CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
|
||||
{
|
||||
return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
|
||||
}
|
||||
|
||||
// Element-wise sum of two bf16x2_t vectors; each lane is added through float.
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
{
    bf16x2_t sum;
    sum[0] = add<bf16_t, float>(a[0], b[0]);
    sum[1] = add<bf16_t, float>(a[1], b[1]);
    return sum;
}
|
||||
|
||||
// Element-wise sum of two bf16x4_t vectors; each lane is added through float.
CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
{
    bf16x4_t sum;
    sum[0] = add<bf16_t, float>(a[0], b[0]);
    sum[1] = add<bf16_t, float>(a[1], b[1]);
    sum[2] = add<bf16_t, float>(a[2], b[2]);
    sum[3] = add<bf16_t, float>(a[3], b[3]);
    return sum;
}
|
||||
|
||||
// Element-wise sum of two fp8x4_t vectors; each lane is added through float.
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
    fp8x4_t sum;
    sum[0] = add<fp8_t, float>(a[0], b[0]);
    sum[1] = add<fp8_t, float>(a[1], b[1]);
    sum[2] = add<fp8_t, float>(a[2], b[2]);
    sum[3] = add<fp8_t, float>(a[3], b[3]);
    return sum;
}
|
||||
|
||||
// Element-wise sum of two fp8x8_t vectors; each lane is added through float.
CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
{
    fp8x8_t sum;
    sum[0] = add<fp8_t, float>(a[0], b[0]);
    sum[1] = add<fp8_t, float>(a[1], b[1]);
    sum[2] = add<fp8_t, float>(a[2], b[2]);
    sum[3] = add<fp8_t, float>(a[3], b[3]);
    sum[4] = add<fp8_t, float>(a[4], b[4]);
    sum[5] = add<fp8_t, float>(a[5], b[5]);
    sum[6] = add<fp8_t, float>(a[6], b[6]);
    sum[7] = add<fp8_t, float>(a[7], b[7]);
    return sum;
}
|
||||
|
||||
// Element-wise sum of two bf8x4_t vectors; each lane is added through float.
CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
{
    bf8x4_t sum;
    sum[0] = add<bf8_t, float>(a[0], b[0]);
    sum[1] = add<bf8_t, float>(a[1], b[1]);
    sum[2] = add<bf8_t, float>(a[2], b[2]);
    sum[3] = add<bf8_t, float>(a[3], b[3]);
    return sum;
}
|
||||
|
||||
// Element-wise sum of two bf8x8_t vectors; each lane is added through float.
CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
{
    bf8x8_t sum;
    sum[0] = add<bf8_t, float>(a[0], b[0]);
    sum[1] = add<bf8_t, float>(a[1], b[1]);
    sum[2] = add<bf8_t, float>(a[2], b[2]);
    sum[3] = add<bf8_t, float>(a[3], b[3]);
    sum[4] = add<bf8_t, float>(a[4], b[4]);
    sum[5] = add<bf8_t, float>(a[5], b[5]);
    sum[6] = add<bf8_t, float>(a[6], b[6]);
    sum[7] = add<bf8_t, float>(a[7], b[7]);
    return sum;
}
|
||||
|
||||
// Caution: DO NOT REMOVE
|
||||
// intentionally have only declaration but no definition to cause compilation failure when trying to
|
||||
// instantiate this template. The purpose is to make the implementation of atomic_add explicit for
|
||||
// each datatype.
|
||||
template <typename X>
|
||||
CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
|
||||
{
|
||||
union U32BF162_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
bf16x2_t* bf162_a;
|
||||
};
|
||||
|
||||
union U32BF162
|
||||
{
|
||||
uint32_t u32;
|
||||
bf16x2_t bf162;
|
||||
};
|
||||
|
||||
U32BF162_ADDR dword_addr;
|
||||
U32BF162 cur_v;
|
||||
U32BF162 new_;
|
||||
uint32_t old_v, new_v;
|
||||
dword_addr.bf162_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
|
||||
{
|
||||
// Union to treat the pointer as either bf16x4_t* or uint64_t*:
|
||||
union U64BF164_ADDR
|
||||
{
|
||||
uint64_t* u64_a;
|
||||
bf16x4_t* bf164_a;
|
||||
};
|
||||
|
||||
// Union to treat the data as either bf16x4_t or 64-bit integer
|
||||
union U64BF164
|
||||
{
|
||||
uint64_t u64;
|
||||
bf16x4_t bf164;
|
||||
};
|
||||
|
||||
U64BF164_ADDR addr;
|
||||
addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location
|
||||
|
||||
// First read (non-atomic) of the old value
|
||||
U64BF164 cur_v;
|
||||
cur_v.u64 = *addr.u64_a;
|
||||
|
||||
U64BF164 new_v_union;
|
||||
uint64_t old_v, new_v;
|
||||
|
||||
do
|
||||
{
|
||||
// old 64 bits
|
||||
old_v = cur_v.u64;
|
||||
|
||||
// Add elementwise in bf16
|
||||
new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x);
|
||||
new_v = new_v_union.u64;
|
||||
|
||||
// Attempt the 64-bit CAS
|
||||
cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v);
|
||||
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp8x4_t>(fp8x4_t* p_dst, const fp8x4_t& x)
|
||||
{
|
||||
union U32FP84_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
fp8x4_t* fp84_a;
|
||||
};
|
||||
|
||||
union U32FP84
|
||||
{
|
||||
uint32_t u32;
|
||||
fp8x4_t fp84;
|
||||
};
|
||||
|
||||
U32FP84_ADDR dword_addr;
|
||||
U32FP84 cur_v;
|
||||
U32FP84 new_;
|
||||
uint32_t old_v, new_v;
|
||||
|
||||
dword_addr.fp84_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.fp84 = add_fp8x4_t(cur_v.fp84, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf8x4_t>(bf8x4_t* p_dst, const bf8x4_t& x)
|
||||
{
|
||||
union U32BF84_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
bf8x4_t* bf84_a;
|
||||
};
|
||||
|
||||
union U32BF84
|
||||
{
|
||||
uint32_t u32;
|
||||
bf8x4_t bf84;
|
||||
};
|
||||
|
||||
U32BF84_ADDR dword_addr;
|
||||
U32BF84 cur_v;
|
||||
U32BF84 new_;
|
||||
uint32_t old_v, new_v;
|
||||
|
||||
dword_addr.bf84_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.bf84 = add_bf8x4_t(cur_v.bf84, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic add for fp8x8_t
|
||||
//
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp8x8_t>(fp8x8_t* p_dst, fp8x8_t const& x)
|
||||
{
|
||||
// Union for addressing 64 bits as either "fp8x8_t" or a 64-bit integer.
|
||||
union U64FP88_ADDR
|
||||
{
|
||||
uint64_t* u64_a; // pointer to 64-bit integer
|
||||
fp8x8_t* fp88_a; // pointer to fp8x8_t
|
||||
};
|
||||
|
||||
union U64FP88
|
||||
{
|
||||
uint64_t u64;
|
||||
fp8x8_t fp88;
|
||||
};
|
||||
|
||||
U64FP88_ADDR dword_addr;
|
||||
U64FP88 cur_v;
|
||||
U64FP88 new_v_union;
|
||||
uint64_t old_v, new_v;
|
||||
|
||||
// Point to the destination as both fp8x8_t* and uint64_t*.
|
||||
dword_addr.fp88_a = p_dst;
|
||||
// Initial read of 64 bits from memory
|
||||
cur_v.u64 = *dword_addr.u64_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u64;
|
||||
// Add each fp8 element using your add_fp8x8_t(...) routine
|
||||
new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x);
|
||||
new_v = new_v_union.u64;
|
||||
|
||||
// Attempt 64-bit CAS
|
||||
cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic add for bf8x8_t
|
||||
//
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
|
||||
{
|
||||
union U64BF88_ADDR
|
||||
{
|
||||
uint64_t* u64_a;
|
||||
bf8x8_t* bf88_a;
|
||||
};
|
||||
|
||||
union U64BF88
|
||||
{
|
||||
uint64_t u64;
|
||||
bf8x8_t bf88;
|
||||
};
|
||||
|
||||
U64BF88_ADDR dword_addr;
|
||||
U64BF88 cur_v;
|
||||
U64BF88 new_v_union;
|
||||
uint64_t old_v, new_v;
|
||||
|
||||
dword_addr.bf88_a = p_dst;
|
||||
// Read the original 64 bits
|
||||
cur_v.u64 = *dword_addr.u64_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u64;
|
||||
// Add each bf8 element using your add_bf8x8_t(...) routine
|
||||
new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x);
|
||||
new_v = new_v_union.u64;
|
||||
|
||||
// 64-bit CAS loop
|
||||
cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
|
||||
"The granularity of the thread buffer is unsupported on the hardware!");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
if constexpr(std::is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<float>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, double>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return atomicAdd(p_dst, bit_cast<double>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicAdd(c_style_pointer_cast<double*>(p_dst), x.template get_as<double>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<double*>(p_dst) + 1, x.template get_as<double>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<int32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, uint32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<uint32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf16_t>::value)
|
||||
{
|
||||
if constexpr(N == 2)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst) + 1,
|
||||
x.template get_as<bf16x4_t>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, fp8_t>::value)
|
||||
{
|
||||
if constexpr(N == 4)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 8)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 16)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1, x.template get_as<fp8x8_t>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf8_t>::value)
|
||||
{
|
||||
if constexpr(N == 4)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 8)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 16)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void atomic_max_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, double>::value && (N == 1)),
|
||||
"wrong! not implemented");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
if constexpr(std::is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<float>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicMax(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicMax(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, double>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<double>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<int32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, uint32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<uint32_t>(x));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
129
include/ck_tile/core/arch/utility.hpp
Normal file
129
include/ck_tile/core/arch/utility.hpp
Normal file
@@ -0,0 +1,129 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// Address Space for AMDGCN
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#address-space
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: we have "memory" clobber here because this inline asm is used for async copy
|
||||
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
|
||||
{
|
||||
asm volatile("s_mov_b32 m0, %0" : : "s"(v) : "memory");
|
||||
}
|
||||
|
||||
// NOTE: this is an immediate value
|
||||
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
|
||||
{
|
||||
asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
|
||||
{
|
||||
#if 0
|
||||
return __shfl_up(v_local, lane_delta);
|
||||
#elif 1
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const uint32_t wrap_around_lane_delta = warpSize - lane_delta;
|
||||
|
||||
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
|
||||
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
|
||||
|
||||
return bit_cast<T>(v_remote_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
|
||||
{
|
||||
#if 0
|
||||
return __shfl_down(v_local, lane_delta);
|
||||
#elif 1
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
|
||||
(__lane_id() << 2) + (lane_delta << 2), bit_cast<int32_t>(v_local));
|
||||
|
||||
return bit_cast<T>(v_remote_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
|
||||
{
|
||||
#if 0
|
||||
return __shfl(v_local, src_lane);
|
||||
#elif 1
|
||||
if constexpr(sizeof(int32_t) > sizeof(T))
|
||||
{
|
||||
union packet
|
||||
{
|
||||
int32_t x;
|
||||
T v;
|
||||
};
|
||||
packet p;
|
||||
p.v = v_local;
|
||||
packet p_remote;
|
||||
p_remote.x = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(p));
|
||||
|
||||
return p_remote.v;
|
||||
}
|
||||
else if constexpr(sizeof(int32_t) == sizeof(T))
|
||||
{
|
||||
const int32_t v_remote_tmp =
|
||||
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
|
||||
|
||||
return bit_cast<T>(v_remote_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
|
||||
constexpr index_t elm = sizeof(T) / sizeof(int32_t);
|
||||
using vector_type = thread_buffer<int32_t, elm>;
|
||||
auto vs = bit_cast<vector_type>(v_local);
|
||||
auto vs_remote = vector_type{};
|
||||
static_for<0, elm, 1>{}([&](auto i_e) {
|
||||
int32_t tmp = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(vs[i_e]));
|
||||
vs_remote(i_e) = tmp;
|
||||
});
|
||||
return bit_cast<T>(vs_remote);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
// per-thread v_flag store into 2x sgpr
|
||||
uint32x2_t exec_flag;
|
||||
asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
|
||||
: [s_exec_flag] "=s"(exec_flag)
|
||||
: [v_flag] "v"(v_flag));
|
||||
return exec_flag;
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
|
||||
{
|
||||
static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
|
||||
// per-thread cmp store into 2x sgpr
|
||||
uint32x2_t exec_flag;
|
||||
asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
|
||||
: [s_exec_flag] "=s"(exec_flag)
|
||||
: [v_x] "v"(x), [v_y] "v"(y));
|
||||
return exec_flag;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
261
include/ck_tile/core/config.hpp
Normal file
261
include/ck_tile/core/config.hpp
Normal file
@@ -0,0 +1,261 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
|
||||
#define __gfx9__
|
||||
#endif
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \
|
||||
defined(__gfx10_3_generic__)
|
||||
#define __gfx103__
|
||||
#endif
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
|
||||
defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
|
||||
defined(__gfx1152__) || defined(__gfx11_generic__)
|
||||
#define __gfx11__
|
||||
#endif
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
|
||||
#define __gfx12__
|
||||
#endif
|
||||
|
||||
#include "hip/hip_version.h"
|
||||
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
#endif
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#define CK_TILE_HOST inline __host__
|
||||
#define CK_TILE_DEVICE inline __device__
|
||||
#define CK_TILE_HOST_DEVICE inline __host__ __device__
|
||||
#define CK_TILE_DEVICE_EXTERN __device__
|
||||
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
|
||||
#else
|
||||
#define CK_TILE_HOST inline
|
||||
#define CK_TILE_DEVICE inline
|
||||
#define CK_TILE_HOST_DEVICE inline
|
||||
#define CK_TILE_DEVICE_EXTERN
|
||||
#define CK_TILE_HOST_DEVICE_EXTERN
|
||||
#endif
|
||||
|
||||
// implementing the "memory address space" attribute
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
|
||||
// WA for https://github.com/ROCm/composable_kernel/issues/1946
|
||||
#if 0
|
||||
#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
|
||||
#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
|
||||
#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
|
||||
#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
|
||||
#else
|
||||
#define CK_TILE_GENERIC_ADDR
|
||||
#define CK_TILE_GLOBAL_ADDR
|
||||
#define CK_TILE_LDS_ADDR
|
||||
#define CK_TILE_BUF_RES_ADDR
|
||||
#endif
|
||||
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
|
||||
#endif
|
||||
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4
|
||||
|
||||
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
|
||||
#endif
|
||||
|
||||
#define CK_TILE_FLOAT_TO_FP8_STANDARD 0
|
||||
#define CK_TILE_FLOAT_TO_FP8_STOCHASTIC 1
|
||||
|
||||
#ifndef CK_TILE_FLOAT_TO_FP8_DEFAULT
|
||||
#define CK_TILE_FLOAT_TO_FP8_DEFAULT CK_TILE_FLOAT_TO_FP8_STANDARD
|
||||
#endif
|
||||
|
||||
// in older ROCm releases, we had to use the tuple-based array implementation for this,
|
||||
// so turn on _USE_TUPLE if you hit a compiler error; otherwise _USE_ARRAY is meant to be the default.
|
||||
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY 0
|
||||
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE 1
|
||||
#ifndef CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT
|
||||
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
|
||||
#endif
|
||||
|
||||
#define CK_TILE_THREAD_BUFFER_USE_ARRAY 0
|
||||
#define CK_TILE_THREAD_BUFFER_USE_TUPLE 1
|
||||
#ifndef CK_TILE_THREAD_BUFFER_DEFAULT
|
||||
#define CK_TILE_THREAD_BUFFER_DEFAULT CK_TILE_THREAD_BUFFER_USE_ARRAY
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
|
||||
#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
|
||||
// if using tuple-array as thread_buffer implementation, need to support {} brace init
|
||||
// ... with similar behavior to array
|
||||
#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 1
|
||||
#else
|
||||
#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_LAUNCH_BOUNDS
|
||||
#define CK_TILE_USE_LAUNCH_BOUNDS 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_TIME_KERNEL
|
||||
#define CK_TILE_TIME_KERNEL 1
|
||||
#endif
|
||||
|
||||
#define CK_TILE_MAX_THREAD_PER_BLOCK 256
|
||||
#define CK_TILE_MIN_BLOCK_PER_CU 2
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
|
||||
#define CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_AMD_BUFFER_LOAD
|
||||
#define CK_TILE_USE_AMD_BUFFER_LOAD 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_AMD_BUFFER_STORE
|
||||
#define CK_TILE_USE_AMD_BUFFER_STORE 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
|
||||
#define CK_TILE_USE_PK4_LAYOUT_SHUFFLE 1
|
||||
#endif
|
||||
|
||||
// buffer atomic add: floating point
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#elif defined(__gfx9__) // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#else // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
|
||||
#endif
|
||||
|
||||
#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
|
||||
#else
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
#define CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
|
||||
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
|
||||
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091
|
||||
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
|
||||
#else
|
||||
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// workaround for ROCm 6.2 and later
|
||||
#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
|
||||
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133) || \
|
||||
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 3 && HIP_VERSION_PATCH >= 42131) || \
|
||||
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR > 3)
|
||||
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
|
||||
#else
|
||||
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_DEBUG_LOG
|
||||
#define CK_TILE_DEBUG_LOG 0
|
||||
#endif
|
||||
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || \
|
||||
defined(__gfx9__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
|
||||
#elif defined(__gfx103__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
||||
#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_SUBDWORD_TILE_CAST
|
||||
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_PK_FP16_TILE_CAST
|
||||
#define CK_TILE_USE_PK_FP16_TILE_CAST 0
|
||||
#endif
|
||||
|
||||
// TODO: better solve this inside compiler
|
||||
#ifndef CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
|
||||
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
|
||||
#endif
|
||||
|
||||
// workaround: compiler not emitting reciprocal instruction from __frcp_rn()
|
||||
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
|
||||
#define CK_TILE_WORKAROUND_SWDEV_383542 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
|
||||
#endif
|
||||
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#ifdef CK_TILE_USE_OCP_FP8
|
||||
#define CK_TILE_USE_OCP_FP8 1
|
||||
#else
|
||||
#define CK_TILE_USE_OCP_FP8 0
|
||||
#endif
|
||||
#elif defined(__gfx950__) || defined(__gfx12__) // for GPU code
|
||||
#define CK_TILE_USE_OCP_FP8 1
|
||||
#else // for GPU code
|
||||
#define CK_TILE_USE_OCP_FP8 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
#if __clang_major__ == 20
|
||||
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 1
|
||||
#else
|
||||
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_WA_ISSUE_2028
|
||||
#define CK_TILE_WA_ISSUE_2028 1
|
||||
#endif
|
||||
262
include/ck_tile/core/container/array.hpp
Normal file
262
include/ck_tile/core/container/array.hpp
Normal file
@@ -0,0 +1,262 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// use aggregate initialization for this type
|
||||
// e.g. array<index_t, 4> buf {0}; => {0, 0, 0, 0}, clean
|
||||
// array<index_t, 4> buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0})
|
||||
// use make_array_with({...}) to construct an array with compatible behavior as old ck
|
||||
// TODO: manually added constructor same as old ck
|
||||
template <typename T_, index_t N_>
|
||||
struct array
|
||||
{
|
||||
using value_type = T_;
|
||||
static constexpr index_t N = N_;
|
||||
// TODO: do we need this?
|
||||
// using bulk_type = uint8_t __attribute__((ext_vector_type(N * sizeof(value_type))));
|
||||
// union {
|
||||
value_type data[N];
|
||||
// bulk_type __content;
|
||||
//};
|
||||
CK_TILE_HOST_DEVICE constexpr array() : data{} {}
|
||||
// TODO: will initialize the data[] with the last value repeatedly
|
||||
// behavior different from std
|
||||
CK_TILE_HOST_DEVICE constexpr array(std::initializer_list<value_type> ilist)
|
||||
{
|
||||
constexpr index_t list_size = std::initializer_list<value_type>{}.size();
|
||||
static_assert(list_size <= N, "out of bound");
|
||||
|
||||
index_t i = 0;
|
||||
value_type vlast = value_type{};
|
||||
|
||||
for(const value_type& val : ilist)
|
||||
{
|
||||
data[i] = val;
|
||||
vlast = val;
|
||||
++i;
|
||||
}
|
||||
for(; i < N; ++i)
|
||||
{
|
||||
data[i] = vlast;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Y,
|
||||
typename = std::enable_if_t<std::is_convertible_v<Y, value_type> ||
|
||||
std::is_constructible_v<Y, value_type>>>
|
||||
CK_TILE_HOST_DEVICE explicit constexpr array(Y c)
|
||||
{
|
||||
for(auto i = 0; i < size(); i++)
|
||||
data[i] = static_cast<value_type>(c);
|
||||
}
|
||||
|
||||
// template <typename Y>
|
||||
// CK_TILE_HOST_DEVICE constexpr array(const array& o)
|
||||
// {
|
||||
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
|
||||
// __content = o.__content;
|
||||
// }
|
||||
// CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o)
|
||||
// {
|
||||
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
|
||||
// __content = o.__content;
|
||||
// return *this;
|
||||
// }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
|
||||
|
||||
// clang-format off
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get(index_t i) { return data[i]; }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get(index_t i) const { return data[i]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get() { return data[I]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data[I]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get(number<I>) { return data[I]; }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get(number<I>) const { return data[I]; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const value_type& operator[](index_t i) const { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr value_type& operator[](index_t i) { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr value_type& operator()(index_t i) { return get(i); } // TODO: compatible
|
||||
#if 0
|
||||
template <typename ArrayLike>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator=(const ArrayLike& arr)
|
||||
{
|
||||
static_assert(ArrayLike::size() == size(), "wrong! size not the same");
|
||||
for(index_t i = 0; i < size(); ++i)
|
||||
{
|
||||
data[i] = arr[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
#endif
|
||||
// type punning (strict aliasing) member functions for read/write
|
||||
// aliasing this array of type "T", "N" elements
|
||||
// as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements
|
||||
#define AR_AS_COM_() \
|
||||
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
|
||||
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
|
||||
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as()
|
||||
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data); }
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as() const
|
||||
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data); }
|
||||
|
||||
// below index is for index *AFTER* type convert, not before
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as(index_t i)
|
||||
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(i); }
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as(index_t i) const
|
||||
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(i); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr auto& get_as(number<I>)
|
||||
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get_as(number<I>) const
|
||||
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(number<I>{}); }
|
||||
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
|
||||
{ AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(i) = x; }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
|
||||
{ AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}) = x; }
|
||||
#undef AR_AS_COM_
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
// empty Array
|
||||
|
||||
template <typename T>
|
||||
struct array<T, 0>
|
||||
{
|
||||
using value_type = T;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr array() {}
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
|
||||
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
|
||||
};
|
||||
|
||||
template <typename, typename>
|
||||
struct vector_traits;
|
||||
|
||||
// specialization for array
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<array<T, N>, void>
|
||||
{
|
||||
using scalar_type = T;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
namespace details {
|
||||
template <class>
|
||||
struct is_ref_wrapper : std::false_type
|
||||
{
|
||||
};
|
||||
template <class T>
|
||||
struct is_ref_wrapper<std::reference_wrapper<T>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <class T>
|
||||
using not_ref_wrapper = std::negation<is_ref_wrapper<std::decay_t<T>>>;
|
||||
|
||||
template <class D, class...>
|
||||
struct return_type_helper
|
||||
{
|
||||
using type = D;
|
||||
};
|
||||
template <class... Ts>
|
||||
struct return_type_helper<void, Ts...> : std::common_type<Ts...>
|
||||
{
|
||||
static_assert(std::conjunction_v<not_ref_wrapper<Ts>...>,
|
||||
"Ts cannot contain reference_wrappers when D is void");
|
||||
};
|
||||
|
||||
template <class D, class... Ts>
|
||||
using return_type = array<typename return_type_helper<D, Ts...>::type, sizeof...(Ts)>;
|
||||
} // namespace details
|
||||
|
||||
template <typename D = void, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr details::return_type<D, Ts...> make_array(Ts&&... ts)
|
||||
{
|
||||
return {std::forward<Ts>(ts)...};
|
||||
}
|
||||
|
||||
// // make empty array
|
||||
// template <typename T>
|
||||
// CK_TILE_HOST_DEVICE constexpr auto make_array()
|
||||
// {
|
||||
// return array<T, 0>{};
|
||||
// }
|
||||
|
||||
// compatible with old ck's initializer, make an array and fill it withe the last element from
|
||||
// initializer_list
|
||||
template <typename T, index_t Size>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_array_with(std::initializer_list<T> ilist)
|
||||
{
|
||||
return array<T, Size>(ilist);
|
||||
}
|
||||
|
||||
template <typename T, index_t Size>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator==(const array<T, Size>& a, const array<T, Size>& b)
|
||||
{
|
||||
bool same = true;
|
||||
|
||||
for(index_t i = 0; i < Size; ++i)
|
||||
{
|
||||
if(a[i] != b[i])
|
||||
{
|
||||
same = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return same;
|
||||
}
|
||||
|
||||
template <typename T, index_t Size>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const array<T, Size>& b)
|
||||
{
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_array(const std::vector<X>& x)
|
||||
{
|
||||
array<T, N> arr;
|
||||
|
||||
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
|
||||
|
||||
return arr;
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
|
||||
{
|
||||
static_assert(N <= X::size(), "");
|
||||
|
||||
array<T, N> arr;
|
||||
|
||||
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
|
||||
|
||||
return arr;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
499
include/ck_tile/core/container/container_helper.hpp
Normal file
499
include/ck_tile/core/container/container_helper.hpp
Normal file
@@ -0,0 +1,499 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/map.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TData, index_t NSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array<TData, NSize>& a, const TData& x)
|
||||
{
|
||||
array<TData, NSize + 1> r;
|
||||
static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
|
||||
r[number<NSize>{}] = x;
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Ts, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_push_front(const tuple<Ts...>& a, const T& x)
|
||||
{
|
||||
return container_concat(make_tuple(x), a);
|
||||
}
|
||||
|
||||
template <typename... Ts, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const tuple<Ts...>& a, const T& x)
|
||||
{
|
||||
return container_concat(a, make_tuple(x));
|
||||
}
|
||||
|
||||
// reorder array
|
||||
template <typename TData, index_t NSize, index_t... IRs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reorder_given_new2old(const array<TData, NSize>& old_array, sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
return make_array<remove_cvref_t<TData>>(old_array[IRs]...);
|
||||
}
|
||||
|
||||
template <typename TData, index_t NSize, index_t... IRs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reorder_given_old2new(const array<TData, NSize>& old_array, sequence<IRs...> old2new)
|
||||
{
|
||||
return container_reorder_given_new2old(
|
||||
old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
|
||||
}
|
||||
|
||||
// reorder array
|
||||
template <typename TData, index_t NSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reorder_given_new2old(const array<TData, NSize>& old_array,
|
||||
const map<index_t, index_t>& new2old)
|
||||
{
|
||||
array<TData, NSize> new_array;
|
||||
|
||||
for(const auto& [new_pos, old_pos] : new2old)
|
||||
{
|
||||
new_array(new_pos) = old_array[old_pos];
|
||||
}
|
||||
|
||||
return new_array;
|
||||
}
|
||||
|
||||
template <typename TData, index_t NSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reorder_given_old2new(const array<TData, NSize>& old_array,
|
||||
const map<index_t, index_t>& old2new)
|
||||
{
|
||||
array<TData, NSize> new_array;
|
||||
|
||||
for(const auto& [old_pos, new_pos] : old2new)
|
||||
{
|
||||
new_array(new_pos) = old_array[old_pos];
|
||||
}
|
||||
|
||||
return new_array;
|
||||
}
|
||||
|
||||
// reorder tuple
|
||||
template <typename... Ts, index_t... IRs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple<Ts...>& old_tuple,
|
||||
sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
|
||||
return make_tuple(old_tuple[number<IRs>{}]...);
|
||||
}
|
||||
|
||||
template <typename... Ts, index_t... IRs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(const tuple<Ts...>& old_tuple,
|
||||
sequence<IRs...> old2new)
|
||||
{
|
||||
return container_reorder_given_new2old(
|
||||
old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
|
||||
}
|
||||
|
||||
// reorder sequence
|
||||
template <index_t... Is, index_t... IRs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence<Is...> /* old_seq */,
|
||||
sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
|
||||
return sequence<sequence<Is...>::at(number<IRs>{})...>{};
|
||||
}
|
||||
|
||||
template <index_t... Is, index_t... IRs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence<Is...> old_seq,
|
||||
sequence<IRs...> /* old2new */)
|
||||
{
|
||||
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
|
||||
constexpr auto new2old = typename sequence_map_inverse<sequence<IRs...>>::type{};
|
||||
|
||||
return container_reorder_given_new2old(old_seq, new2old);
|
||||
}
|
||||
|
||||
#if 0
|
||||
// rocm-4.1 compiler would crash for recursive lambda
|
||||
template <typename Container,
|
||||
typename Reduce,
|
||||
typename Init,
|
||||
index_t IBegin = 0,
|
||||
index_t IEnd = Container::size(),
|
||||
index_t IStep = 1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
|
||||
Reduce reduce,
|
||||
Init init,
|
||||
number<IBegin> = number<0>{},
|
||||
number<IEnd> = number<Container::size()>{},
|
||||
number<IStep> = number<1>{})
|
||||
{
|
||||
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
|
||||
|
||||
// f is recursive function, fs is a dummy of f
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
auto f = [&](auto fs, auto i, auto r_old) {
|
||||
auto r_new = reduce(x[i], r_old);
|
||||
|
||||
if constexpr(i.value < IEnd - IStep)
|
||||
{
|
||||
// recursively call f/fs
|
||||
return fs(fs, i + number<IStep>{}, r_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return r_new;
|
||||
}
|
||||
};
|
||||
|
||||
// start recursion
|
||||
return f(f, number<IBegin>{}, init);
|
||||
}
|
||||
#else
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
template <typename Container,
|
||||
typename Reduce,
|
||||
typename ROld,
|
||||
index_t I,
|
||||
index_t IEnd,
|
||||
index_t IStep>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reduce_impl(
|
||||
const Container& x, Reduce reduce, ROld r_old, number<I> i, number<IEnd>, number<IStep>)
|
||||
{
|
||||
auto r_new = reduce(x[i], r_old);
|
||||
|
||||
if constexpr(i.value < IEnd - IStep)
|
||||
{
|
||||
return container_reduce_impl(
|
||||
x, reduce, r_new, i + number<IStep>{}, number<IEnd>{}, number<IStep>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return r_new;
|
||||
}
|
||||
}
|
||||
|
||||
// rocm-4.1 compiler would crash for recursive lambda
|
||||
// container reduce with initial value
|
||||
template <typename Container,
|
||||
typename Reduce,
|
||||
typename Init,
|
||||
index_t IBegin = 0,
|
||||
index_t IEnd = Container::size(),
|
||||
index_t IStep = 1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
|
||||
Reduce reduce,
|
||||
Init init,
|
||||
number<IBegin> = number<0>{},
|
||||
number<IEnd> = number<Container::size()>{},
|
||||
number<IStep> = number<1>{})
|
||||
{
|
||||
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
|
||||
|
||||
if constexpr(IEnd > IBegin)
|
||||
{
|
||||
return container_reduce_impl(
|
||||
x, reduce, init, number<IBegin>{}, number<IEnd>{}, number<IStep>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return init;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename TData, index_t NSize, typename Reduce>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reverse_inclusive_scan(const array<TData, NSize>& x, Reduce f, TData init)
|
||||
{
|
||||
array<TData, NSize> y;
|
||||
|
||||
TData r = init;
|
||||
|
||||
static_for<NSize - 1, 0, -1>{}([&](auto i) {
|
||||
r = f(r, x[i]);
|
||||
y(i) = r;
|
||||
});
|
||||
|
||||
r = f(r, x[number<0>{}]);
|
||||
y(number<0>{}) = r;
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename TData, index_t NSize, typename Reduce, typename Init>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reverse_exclusive_scan(const array<TData, NSize>& x, Reduce f, Init init)
|
||||
{
|
||||
#if 0
|
||||
array<TData, NSize> y;
|
||||
|
||||
TData r = init;
|
||||
|
||||
static_for<NSize - 1, 0, -1>{}([&](auto i) {
|
||||
y(i) = r;
|
||||
r = f(r, x[i]);
|
||||
});
|
||||
|
||||
y(number<0>{}) = r;
|
||||
|
||||
return y;
|
||||
#else
|
||||
array<TData, NSize> y;
|
||||
|
||||
TData r = init;
|
||||
|
||||
for(index_t i = NSize - 1; i > 0; --i)
|
||||
{
|
||||
y(i) = r;
|
||||
r = f(r, x[i]);
|
||||
}
|
||||
|
||||
y(0) = r;
|
||||
|
||||
return y;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <index_t... Is, typename Reduce, index_t Init>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reverse_exclusive_scan(const sequence<Is...>& seq, Reduce f, number<Init>)
|
||||
{
|
||||
return reverse_exclusive_scan_sequence(seq, f, number<Init>{});
|
||||
}
|
||||
|
||||
#if 0
|
||||
// rocm4.1 compiler would crash with recursive lambda
|
||||
template <typename... Xs, typename Reduce, typename Init>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
// f is recursive function, fs is a dummy of f
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
|
||||
auto r_new = reduce(x[i], r_old);
|
||||
|
||||
auto y_new = container_push_front(y_old, r_new);
|
||||
|
||||
if constexpr(i.value > 1)
|
||||
{
|
||||
// recursively call f/fs
|
||||
return fs(fs, i - number<1>{}, y_new, r_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return y_new;
|
||||
}
|
||||
};
|
||||
|
||||
// start recursion
|
||||
return f(f, number<NSize - 1>{}, make_tuple(init), init);
|
||||
}
|
||||
#else
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
template <typename... Xs, typename Reduce, index_t I, typename YOld, typename ROld>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan_impl(
|
||||
const tuple<Xs...>& x, Reduce reduce, number<I> i, YOld y_old, ROld r_old)
|
||||
{
|
||||
auto r_new = reduce(x[i], r_old);
|
||||
|
||||
auto y_new = container_push_front(y_old, r_new);
|
||||
|
||||
if constexpr(i.value > 1)
|
||||
{
|
||||
// recursively call f/fs
|
||||
return container_reverse_exclusive_scan_impl(x, reduce, i - number<1>{}, y_new, r_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return y_new;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Xs, typename Reduce, typename Init>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
return container_reverse_exclusive_scan_impl(
|
||||
x, reduce, number<NSize - 1>{}, make_tuple(init), init);
|
||||
}
|
||||
#endif
|
||||
|
||||
// TODO: update to like container_reverse_exclusive_scan to deal with tuple of Numebr<>
|
||||
template <typename... Xs, typename Reduce, typename TData>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reverse_inclusive_scan(const tuple<Xs...>& x, Reduce f, TData init)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> y;
|
||||
|
||||
TData r = init;
|
||||
|
||||
static_for<NSize - 1, 0, -1>{}([&](auto i) {
|
||||
r = f(r, x[i]);
|
||||
y(i) = r;
|
||||
});
|
||||
|
||||
r = f(r, x[number<0>{}]);
|
||||
y(number<0>{}) = r;
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename X, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X& x, const Ys&... ys)
|
||||
{
|
||||
return container_concat(x, container_concat(ys...));
|
||||
}
|
||||
|
||||
template <typename T, index_t NX, index_t NY>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_concat(const array<T, NX>& ax, const array<T, NY>& ay)
|
||||
{
|
||||
return unpack2(
|
||||
[&](auto&&... zs) { return make_array<T>(std::forward<decltype(zs)>(zs)...); }, ax, ay);
|
||||
}
|
||||
|
||||
template <typename... X, typename... Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_concat(const tuple<X...>& tx, const tuple<Y...>& ty)
|
||||
{
|
||||
return unpack2(
|
||||
[&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_concat(const Container& x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename T, index_t N, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array<T, N>& arr, sequence<Is...>)
|
||||
{
|
||||
static_assert(N >= sizeof...(Is), "wrong! size");
|
||||
|
||||
if constexpr(sizeof...(Is) > 0)
|
||||
{
|
||||
return make_array<T>(arr[Is]...);
|
||||
}
|
||||
else
|
||||
{
|
||||
return array<T, 0>{};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Ts, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const tuple<Ts...>& tup, sequence<Is...>)
|
||||
{
|
||||
static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");
|
||||
|
||||
if constexpr(sizeof...(Is) > 0)
|
||||
{
|
||||
return make_tuple(tup[number<Is>{}]...);
|
||||
}
|
||||
else
|
||||
{
|
||||
return tuple<>{};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_container_subset(array<T, N>& y, sequence<Is...> picks, const array<T, sizeof...(Is)>& x)
|
||||
{
|
||||
static_assert(N >= sizeof...(Is), "wrong! size");
|
||||
|
||||
if constexpr(sizeof...(Is) > 0)
|
||||
{
|
||||
for(index_t i = 0; i < picks.size(); ++i)
|
||||
{
|
||||
y(picks[i]) = x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Y, typename X, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr void set_container_subset(Y& y, sequence<Is...> picks, const X& x)
|
||||
{
|
||||
static_assert(Y::size() >= sizeof...(Is) && X::size() == sizeof...(Is), "wrong! size");
|
||||
|
||||
if constexpr(sizeof...(Is) > 0)
|
||||
{
|
||||
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
|
||||
}
|
||||
}
|
||||
|
||||
// return the index of first occurance in the sequence.
|
||||
// return seq.size(), if not found
|
||||
template <index_t... Is>
|
||||
constexpr index_t container_find(sequence<Is...> seq, index_t value)
|
||||
{
|
||||
for(auto i = 0; i < seq.size(); i++)
|
||||
{
|
||||
if(seq[i] == value)
|
||||
return i;
|
||||
}
|
||||
|
||||
return seq.size();
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
|
||||
{
|
||||
using Seq = sequence<Is...>;
|
||||
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr index_t tmp = Seq::at(i);
|
||||
return number<tmp>{};
|
||||
},
|
||||
number<Seq::size()>{});
|
||||
}
|
||||
|
||||
#if 0
|
||||
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
|
||||
[a_of_b_impl, a_size, bs_sizes] { \
|
||||
return ck_tile::generate_tuple( \
|
||||
[=](auto i) { \
|
||||
constexpr auto b_impl = a_of_b_impl[i]; \
|
||||
constexpr index_t b_size = bs_sizes[i]; \
|
||||
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
|
||||
return b; \
|
||||
}, \
|
||||
ck_tile::number<a_size>{}); \
|
||||
}()
|
||||
#else
|
||||
// constexpr index_t can't be captured "-Wunused-lambda-capture"
|
||||
// TODO: this is ugly
|
||||
// Expands to an immediately-invoked lambda that calls ck_tile::generate_tuple with a_size
// elements; element i is TO_SEQUENCE(a_of_b_impl[i], bs_sizes[i]).
// a_of_b_impl and bs_sizes are captured by copy; a_size is only used as the template
// argument of ck_tile::number<> and is deliberately not captured (see the note above).
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
|
||||
[a_of_b_impl, bs_sizes] { \
|
||||
return ck_tile::generate_tuple( \
|
||||
[=](auto i) { \
|
||||
constexpr auto b_impl = a_of_b_impl[i]; \
|
||||
constexpr index_t b_size = bs_sizes[i]; \
|
||||
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
|
||||
return b; \
|
||||
}, \
|
||||
ck_tile::number<a_size>{}); \
|
||||
}()
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
164
include/ck_tile/core/container/map.hpp
Normal file
164
include/ck_tile/core/container/map.hpp
Normal file
@@ -0,0 +1,164 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// naive map
|
||||
template <typename key, typename data, index_t max_size = 128>
|
||||
struct map
|
||||
{
|
||||
using pair_type = tuple<key, data>;
|
||||
using impl_type = array<pair_type, max_size>;
|
||||
|
||||
impl_type impl_;
|
||||
index_t size_;
|
||||
|
||||
struct iterator
|
||||
{
|
||||
impl_type& impl_;
|
||||
index_t pos_;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr iterator(impl_type& impl, index_t pos)
|
||||
: impl_{impl}, pos_{pos}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr iterator& operator++()
|
||||
{
|
||||
pos_++;
|
||||
return *this;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator!=(const iterator& other) const
|
||||
{
|
||||
return other.pos_ != pos_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr pair_type& operator*() { return impl_.at(pos_); }
|
||||
};
|
||||
|
||||
struct const_iterator
|
||||
{
|
||||
const impl_type& impl_;
|
||||
index_t pos_;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator(const impl_type& impl, index_t pos)
|
||||
: impl_{impl}, pos_{pos}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator& operator++()
|
||||
{
|
||||
pos_++;
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator!=(const const_iterator& other) const
|
||||
{
|
||||
return other.pos_ != pos_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const pair_type& operator*() const { return impl_.at(pos_); }
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr map() : impl_{}, size_{0} {}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t size() const { return size_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void clear() { size_ = 0; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& k) const
|
||||
{
|
||||
for(index_t i = 0; i < size(); i++)
|
||||
{
|
||||
if(impl_[i].template at<0>() == k)
|
||||
{
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
return size_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& k) const
|
||||
{
|
||||
return const_iterator{impl_, find_position(k)};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr iterator find(const key& k)
|
||||
{
|
||||
return iterator{impl_, find_position(k)};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& k) const
|
||||
{
|
||||
const auto it = find(k);
|
||||
|
||||
// FIXME
|
||||
// assert(it.pos_ < size());
|
||||
|
||||
return impl_[it.pos_].template at<1>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr data& operator()(const key& k)
|
||||
{
|
||||
auto it = find(k);
|
||||
|
||||
// if entry not found
|
||||
if(it.pos_ == size())
|
||||
{
|
||||
impl_(it.pos_).template at<0>() = k;
|
||||
size_++;
|
||||
}
|
||||
|
||||
// FIXME
|
||||
// assert(size_ <= max_size);
|
||||
|
||||
return impl_(it.pos_).template at<1>();
|
||||
}
|
||||
|
||||
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator begin() const { return const_iterator{impl_, 0}; }
|
||||
|
||||
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
|
||||
CK_TILE_HOST_DEVICE constexpr const_iterator end() const
|
||||
{
|
||||
return const_iterator{impl_, size_};
|
||||
}
|
||||
|
||||
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
|
||||
CK_TILE_HOST_DEVICE constexpr iterator begin() { return iterator{impl_, 0}; }
|
||||
|
||||
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
|
||||
CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("map{size_: %d, ", size_);
|
||||
//
|
||||
printf("impl_: [");
|
||||
//
|
||||
for(const auto& [k, d] : *this)
|
||||
{
|
||||
printf("{key: ");
|
||||
print(k);
|
||||
printf(", data: ");
|
||||
print(d);
|
||||
printf("}, ");
|
||||
}
|
||||
//
|
||||
printf("]");
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
99
include/ck_tile/core/container/meta_data_buffer.hpp
Normal file
99
include/ck_tile/core/container/meta_data_buffer.hpp
Normal file
@@ -0,0 +1,99 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include <cstddef>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: this structure is not intented to be used by user
|
||||
template <index_t MaxSize>
|
||||
struct meta_data_buffer
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr meta_data_buffer() : buffer_{}, size_{0} {}
|
||||
|
||||
template <typename X, typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr meta_data_buffer(const X& x, const Xs&... xs)
|
||||
: buffer_{}, size_{0}
|
||||
{
|
||||
push(x, xs...);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr void push(const T& data)
|
||||
{
|
||||
if constexpr(!std::is_empty_v<T>)
|
||||
{
|
||||
constexpr index_t size = sizeof(T);
|
||||
|
||||
auto tmp = ck_tile::bit_cast<array<std::byte, size>>(data);
|
||||
|
||||
for(int i = 0; i < size; i++)
|
||||
{
|
||||
buffer_(size_) = tmp[i];
|
||||
|
||||
size_++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X, typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr void push(const X& x, const Xs&... xs)
|
||||
{
|
||||
push(x);
|
||||
push(xs...);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T pop(index_t& pos) const
|
||||
{
|
||||
T data;
|
||||
|
||||
if constexpr(!std::is_empty_v<T>)
|
||||
{
|
||||
constexpr index_t size = sizeof(T);
|
||||
|
||||
array<std::byte, size> tmp;
|
||||
|
||||
for(int i = 0; i < size; i++)
|
||||
{
|
||||
tmp(i) = buffer_[pos];
|
||||
|
||||
pos++;
|
||||
}
|
||||
|
||||
data = ck_tile::bit_cast<T>(tmp);
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T get(index_t pos) const
|
||||
{
|
||||
constexpr index_t size = sizeof(T);
|
||||
|
||||
array<std::byte, size> tmp;
|
||||
|
||||
for(int i = 0; i < size; i++)
|
||||
{
|
||||
tmp(i) = buffer_[pos];
|
||||
|
||||
pos++;
|
||||
}
|
||||
|
||||
auto data = ck_tile::bit_cast<T>(tmp);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
//
|
||||
array<std::byte, MaxSize> buffer_;
|
||||
index_t size_ = 0;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
100
include/ck_tile/core/container/multi_index.hpp
Normal file
100
include/ck_tile/core/container/multi_index.hpp
Normal file
@@ -0,0 +1,100 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Don't use this directly. This is for old CK's internal usage;
|
||||
// in the future always use array instead.
|
||||
// multi_index<N> is simply an alias for array<index_t, N>.
template <index_t N>
|
||||
using multi_index = array<index_t, N>;
|
||||
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_multi_index(Xs&&... xs)
|
||||
{
|
||||
return make_array<index_t>(index_t{xs}...);
|
||||
}
|
||||
|
||||
template <index_t NSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_zero_multi_index()
|
||||
{
|
||||
return unpack([](auto... xs) { return make_multi_index(xs...); },
|
||||
typename uniform_sequence_gen<NSize, 0>::type{});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_multi_index(const T& x)
|
||||
{
|
||||
return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
|
||||
}
|
||||
|
||||
template <index_t NSize, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+=(multi_index<NSize>& y, const X& x)
|
||||
{
|
||||
static_assert(X::size() == NSize, "wrong! size not the same");
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <index_t NSize, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-=(multi_index<NSize>& y, const X& x)
|
||||
{
|
||||
static_assert(X::size() == NSize, "wrong! size not the same");
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <index_t NSize, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const multi_index<NSize>& a, const T& b)
|
||||
{
|
||||
using type = multi_index<NSize>;
|
||||
static_assert(T::size() == NSize, "wrong! size not the same");
|
||||
type r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] + b[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <index_t NSize, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const multi_index<NSize>& a, const T& b)
|
||||
{
|
||||
using type = multi_index<NSize>;
|
||||
static_assert(T::size() == NSize, "wrong! size not the same");
|
||||
type r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] - b[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <index_t NSize, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& a, const T& b)
|
||||
{
|
||||
using type = multi_index<NSize>;
|
||||
static_assert(T::size() == NSize, "wrong! size not the same");
|
||||
type r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] * b[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// multi_index = index_t * multi_index
|
||||
template <index_t NSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(index_t a, const multi_index<NSize>& x)
|
||||
{
|
||||
multi_index<NSize> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// multi_index = multi_index * index_t
|
||||
template <index_t NSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& x, index_t a)
|
||||
{
|
||||
return a * x;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
1236
include/ck_tile/core/container/sequence.hpp
Normal file
1236
include/ck_tile/core/container/sequence.hpp
Normal file
File diff suppressed because it is too large
Load Diff
78
include/ck_tile/core/container/span.hpp
Normal file
78
include/ck_tile/core/container/span.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <cstddef>
|
||||
#include <array>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Implements a subset of the C++20 std::span: a lightweight, non-owning reference to a
// contiguous sequence, whether its length is dynamic or static. It can also be seen as a view
// of a contiguous sequence. The span stores only a pointer and an element count.
// TODO: do we need in device consider this is pointer?
template <typename T>
class span
{
    public:
    using element_type    = T;
    using value_type      = std::remove_cv_t<element_type>;
    using size_type       = std::size_t;
    using difference_type = std::ptrdiff_t;
    using pointer         = element_type*;
    using const_pointer   = const element_type*;
    using reference       = element_type&;
    using const_reference = const element_type&;
    using iterator        = pointer;
    // note: aliases the same pointer type as iterator
    using const_iterator = pointer;

    // empty span: null pointer, zero elements
    CK_TILE_HOST_DEVICE constexpr span() : span(nullptr, size_type{0}) {}

    // view over `count` elements starting at `first`
    CK_TILE_HOST_DEVICE constexpr span(pointer first, size_type count) : ptr_(first), size_(count)
    {
    }

    // view over the half-open range [first, last)
    CK_TILE_HOST_DEVICE constexpr span(pointer first, pointer last) : span(first, last - first) {}

    // view over a built-in array
    template <std::size_t N>
    CK_TILE_HOST_DEVICE constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
    {
    }

    // view over a std::array
    template <std::size_t N>
    CK_TILE_HOST_DEVICE constexpr span(std::array<value_type, N>& arr) noexcept
        : span(arr.data(), N)
    {
    }

    // view over any container exposing data() and size()
    template <typename Container>
    CK_TILE_HOST_DEVICE constexpr span(const Container& container)
        : span(container.data(), container.size())
    {
    }

    CK_TILE_HOST_DEVICE constexpr iterator begin() const noexcept { return ptr_; }
    CK_TILE_HOST_DEVICE constexpr const_iterator cbegin() const noexcept { return begin(); }

    CK_TILE_HOST_DEVICE constexpr iterator end() const noexcept { return begin() + size(); }
    CK_TILE_HOST_DEVICE constexpr const_iterator cend() const noexcept { return end(); }

    // first element; the span must not be empty
    CK_TILE_HOST_DEVICE constexpr reference front() const { return *begin(); }

    // last element; the span must not be empty.
    // end() returns a pointer by value, which cannot be decremented in place, so the last
    // element is addressed with pointer arithmetic instead.
    CK_TILE_HOST_DEVICE constexpr reference back() const { return *(end() - 1); }

    // unchecked element access
    CK_TILE_HOST_DEVICE constexpr reference operator[](size_type idx) const
    {
        return *(begin() + idx);
    }
    CK_TILE_HOST_DEVICE constexpr pointer data() const noexcept { return ptr_; }

    CK_TILE_HOST_DEVICE constexpr size_type size() const noexcept { return size_; }

    private:
    pointer ptr_;
    size_type size_;
};
|
||||
|
||||
} // namespace ck_tile
|
||||
41
include/ck_tile/core/container/statically_indexed_array.hpp
Normal file
41
include/ck_tile/core/container/statically_indexed_array.hpp
Normal file
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
|
||||
|
||||
template <typename T, index_t N>
|
||||
using statically_indexed_array = tuple_array<T, N>;
|
||||
|
||||
#else
|
||||
|
||||
// consider marking this array-backed alias as deprecated
|
||||
template <typename T, index_t N>
|
||||
using statically_indexed_array = array<T, N>;
|
||||
|
||||
#endif
|
||||
|
||||
// consider always using ck_tile::array for this purpose
|
||||
#if 0
|
||||
template <typename X, typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
|
||||
{
|
||||
return statically_indexed_array<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
|
||||
}
|
||||
|
||||
// make empty statically_indexed_array
|
||||
template <typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array()
|
||||
{
|
||||
return statically_indexed_array<X, 0>();
|
||||
}
|
||||
#endif
|
||||
} // namespace ck_tile
|
||||
172
include/ck_tile/core/container/thread_buffer.hpp
Normal file
172
include/ck_tile/core/container/thread_buffer.hpp
Normal file
@@ -0,0 +1,172 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
|
||||
template <typename T, index_t N>
|
||||
using thread_buffer = tuple_array<T, N>;
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
|
||||
{
|
||||
return make_tuple(ts...);
|
||||
}
|
||||
#else
|
||||
|
||||
#if 0
|
||||
template <typename T, index_t N>
|
||||
using thread_buffer = array<T, N>;
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
|
||||
{
|
||||
return make_array(ts...);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
template<typename T_, index_t N_>
|
||||
struct thread_buffer {
|
||||
using value_type = remove_cvref_t<T_>;
|
||||
static constexpr index_t N = N_;
|
||||
|
||||
value_type data[N];
|
||||
|
||||
// TODO: this ctor can't ignore
|
||||
CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
|
||||
CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{o} {}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
|
||||
CK_TILE_HOST_DEVICE auto & get() {return data; }
|
||||
CK_TILE_HOST_DEVICE const auto & get() const {return data; }
|
||||
CK_TILE_HOST_DEVICE auto & get(index_t i) {return data[i]; }
|
||||
CK_TILE_HOST_DEVICE const auto & get(index_t i) const {return data[i]; }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& operator[](index_t i) const { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr auto& operator[](index_t i) { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr auto& operator()(index_t i) { return get(i); } // TODO: compatible
|
||||
CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
|
||||
|
||||
template <typename X_,
|
||||
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto _get_as() const
|
||||
{
|
||||
using X = remove_cvref_t<X_>;
|
||||
|
||||
constexpr index_t kSPerX = vector_traits<X>::vector_size;
|
||||
static_assert(N % kSPerX == 0);
|
||||
|
||||
union {
|
||||
thread_buffer<X_, N / kSPerX> data {};
|
||||
// tuple_array<value_type, kSPerX> sub_data;
|
||||
value_type sub_data[N];
|
||||
} vx;
|
||||
static_for<0, N, 1>{}(
|
||||
[&](auto j) { vx.sub_data[j] = data[j]; });
|
||||
return vx.data;
|
||||
}
|
||||
|
||||
template <typename X_,
|
||||
index_t Is,
|
||||
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE const constexpr remove_reference_t<X_> _get_as(number<Is> is) const
|
||||
{
|
||||
using X = remove_cvref_t<X_>;
|
||||
|
||||
constexpr index_t kSPerX = vector_traits<X>::vector_size;
|
||||
|
||||
union {
|
||||
X_ data {};
|
||||
tuple_array<value_type, kSPerX> sub_data;
|
||||
} vx;
|
||||
static_for<0, kSPerX, 1>{}(
|
||||
[&](auto j) { vx.sub_data(j) = operator[]((is * number<sizeof(X_)/sizeof(value_type)>{}) + j); });
|
||||
return vx.data;
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <typename X_,
|
||||
index_t Is,
|
||||
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void _set_as(number<Is> is, X_ x)
|
||||
{
|
||||
using X = remove_cvref_t<X_>;
|
||||
|
||||
constexpr index_t kSPerX = vector_traits<X>::vector_size;
|
||||
|
||||
union {
|
||||
X_ data;
|
||||
tuple_array<value_type, kSPerX> sub_data;
|
||||
} vx {x};
|
||||
|
||||
static_for<0, kSPerX, 1>{}(
|
||||
[&](auto j) { operator()((is * number<sizeof(X_)/sizeof(value_type)>{}) + j) = vx.sub_data[j]; });
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#define TB_COMMON_AS() \
|
||||
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
|
||||
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
|
||||
|
||||
template<typename Tx>
|
||||
CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS();
|
||||
return reinterpret_cast<thread_buffer<Tx, vx>&>(data);}
|
||||
template<typename Tx>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_as() const {TB_COMMON_AS();
|
||||
if constexpr(sizeof(value_type) <= 1 )
|
||||
return _get_as<Tx>(); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future
|
||||
else
|
||||
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data);}
|
||||
template<typename Tx, index_t I>
|
||||
CK_TILE_HOST_DEVICE auto & get_as(number<I>) {TB_COMMON_AS();
|
||||
return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(number<I>{});}
|
||||
template<typename Tx, index_t I>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_as(number<I>) const {TB_COMMON_AS();
|
||||
if constexpr(sizeof(value_type) <= 1 )
|
||||
return _get_as<Tx>(number<I>{}); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future
|
||||
else
|
||||
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(number<I>{});}
|
||||
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
|
||||
{ TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(i) = x; }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
|
||||
{ TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(number<I>{}) = x; }
|
||||
|
||||
#undef TB_COMMON_AS
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
template <typename, typename>
|
||||
struct vector_traits;
|
||||
|
||||
// specializations for thread_buffer
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<thread_buffer<T, N>, std::enable_if_t<!std::is_class_v<T>>>
|
||||
{
|
||||
using scalar_type = T;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<thread_buffer<T, N>, std::enable_if_t<std::is_class_v<T>>>
|
||||
{
|
||||
using scalar_type = typename T::type;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
830
include/ck_tile/core/container/tuple.hpp
Normal file
830
include/ck_tile/core/container/tuple.hpp
Normal file
@@ -0,0 +1,830 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include <utility>
|
||||
#include <initializer_list>
|
||||
|
||||
#ifndef CK_TILE_TUPLE_IMPL
|
||||
#define CK_TILE_TUPLE_IMPL 1
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace impl {
|
||||
template <typename T, index_t N>
|
||||
struct tuple_array_impl;
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
using tuple_array = typename impl::tuple_array_impl<T, N>::type;
|
||||
|
||||
namespace impl {
|
||||
|
||||
// the place where content is stored
|
||||
template <index_t idx, typename T, bool is_empty = std::is_empty_v<T>>
|
||||
struct tuple_object
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t idx, typename T>
|
||||
struct tuple_object<idx, T, true>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object() {}
|
||||
#if CK_TILE_TUPLE_IMPL == 0
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
|
||||
{
|
||||
}
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(const U&)
|
||||
{
|
||||
}
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&)
|
||||
{
|
||||
}
|
||||
#elif CK_TILE_TUPLE_IMPL == 1
|
||||
template <typename U,
|
||||
typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
|
||||
{
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
template <index_t idx, typename T>
|
||||
struct tuple_object<idx, T, false>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object() : element{} {}
|
||||
#if CK_TILE_TUPLE_IMPL == 0
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
|
||||
{
|
||||
}
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(const U& e) : element(e)
|
||||
{
|
||||
}
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U& e) : element(e)
|
||||
{
|
||||
}
|
||||
#elif CK_TILE_TUPLE_IMPL == 1
|
||||
template <typename U,
|
||||
typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
|
||||
{
|
||||
}
|
||||
#endif
|
||||
T element;
|
||||
};
|
||||
|
||||
// NOTE: we return a instance(not a reference) if content is empty
|
||||
template <index_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object<I, T, true>&)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
template <index_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object<I, T, false>& x)
|
||||
{
|
||||
return x.element;
|
||||
}
|
||||
|
||||
template <index_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object<I, T, false>& x)
|
||||
{
|
||||
return x.element;
|
||||
}
|
||||
|
||||
template <index_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_object<I, T, false>&& x)
|
||||
{
|
||||
return static_cast<T&&>(x.element);
|
||||
}
|
||||
|
||||
template <typename index_seq, typename... T>
|
||||
struct tuple_base;
|
||||
|
||||
template <index_t... I, typename... T>
|
||||
struct tuple_base<sequence<I...>, T...> : tuple_object<I, T>...
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base() = default;
|
||||
|
||||
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
|
||||
#define _ILE() (std::initializer_list<U>{}.size() - 1)
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(std::initializer_list<U> us)
|
||||
: tuple_object<I, T>(static_cast<T>(*(us.begin() + (I >= _ILE() ? _ILE() : I))))...
|
||||
{
|
||||
}
|
||||
#undef _ILE
|
||||
#endif
|
||||
|
||||
#if CK_TILE_TUPLE_IMPL == 0
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&&... u)
|
||||
: tuple_object<I, T>(std::forward<U>(u))...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(const U&... u) : tuple_object<I, T>(u)...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&... u) : tuple_object<I, T>(u)...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>&& u)
|
||||
: tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&&>(u)))...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(const tuple_base<sequence<I...>, U...>& u)
|
||||
: tuple_object<I, T>(getv(static_cast<const tuple_object<I, U>&>(u)))...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>& u)
|
||||
: tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&>(u)))...
|
||||
{
|
||||
}
|
||||
#elif CK_TILE_TUPLE_IMPL == 1
|
||||
template <class U,
|
||||
typename std::enable_if<sizeof...(I) == 1 && sizeof...(T) == 1 &&
|
||||
!std::is_same<remove_cvref_t<U>, tuple_base>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(U&& u) : tuple_object<I, T>(std::forward<U>(u))...
|
||||
{
|
||||
}
|
||||
|
||||
template <typename... U, typename std::enable_if<sizeof...(U) >= 2, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(U&&... u) : tuple_object<I, T>(std::forward<U>(u))...
|
||||
{
|
||||
static_assert(sizeof...(I) == sizeof...(T) && sizeof...(I) == sizeof...(U),
|
||||
"wrong! inconsistent size");
|
||||
}
|
||||
|
||||
#endif
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <class... T>
|
||||
struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE
|
||||
static constexpr auto size() { return sizeof...(T); }
|
||||
using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;
|
||||
CK_TILE_HOST_DEVICE constexpr tuple() = default;
|
||||
|
||||
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
|
||||
template <typename U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list<U> us) : base(us)
|
||||
{
|
||||
}
|
||||
#endif
|
||||
|
||||
#if CK_TILE_TUPLE_IMPL == 0
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(const U&... u) : base(u...)
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(U&... u) : base(u...)
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>&& u)
|
||||
: base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&&>(u))
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(const tuple<U...>& u)
|
||||
: base(static_cast<const impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>& u)
|
||||
: base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
|
||||
{
|
||||
}
|
||||
#elif CK_TILE_TUPLE_IMPL == 1
|
||||
template <
|
||||
typename U,
|
||||
typename std::enable_if<sizeof...(T) == 1 && !std::is_same<remove_cvref_t<U>, tuple>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(U&& u) : base(std::forward<U>(u))
|
||||
{
|
||||
}
|
||||
|
||||
template <typename... U,
|
||||
typename std::enable_if<sizeof...(U) == sizeof...(T) && sizeof...(U) >= 2,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
|
||||
{
|
||||
}
|
||||
#endif
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static()
|
||||
{
|
||||
bool flag = true;
|
||||
|
||||
static_for<0, sizeof...(T), 1>{}([&flag](auto i) {
|
||||
flag &= is_static_v<remove_cvref_t<__type_pack_element<i.value, T...>>>;
|
||||
});
|
||||
|
||||
return flag;
|
||||
}
|
||||
|
||||
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
|
||||
// clang-format off
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) { TP_COM_(); return get<I>(); }
|
||||
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) { TP_COM_(); return get<I>(); }
|
||||
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: compatible
|
||||
|
||||
// below function should be used under tuple_array<> type, no extra check will perform here
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast<tuple_array<Tx, size()>&>(*this); }
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { return reinterpret_cast<const tuple_array<Tx, size()>&>(*this); }
|
||||
// below index is for index *AFTER* type convert, not before
|
||||
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i); }
|
||||
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(i); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
|
||||
|
||||
// template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i) = x; }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}) = x; }
|
||||
|
||||
// clang-format on
|
||||
#undef TP_COM_
|
||||
};
|
||||
|
||||
template <typename, typename = void>
|
||||
struct vector_traits;
|
||||
|
||||
// specialization for tuple
|
||||
template <typename... T>
|
||||
struct vector_traits<tuple<T...>>
|
||||
{
|
||||
using scalar_type = __type_pack_element<0, T...>;
|
||||
static constexpr index_t vector_size = sizeof...(T);
|
||||
};
|
||||
|
||||
// template <class... T>
|
||||
// CK_TILE_HOST_DEVICE constexpr
|
||||
// tuple<T...>
|
||||
// make_tuple(T const&... t)
|
||||
// {
|
||||
// return {t...};
|
||||
// }
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple<Xs...>& a, const tuple<Xs...>& b)
|
||||
{
|
||||
bool same = true;
|
||||
|
||||
static_for<0, sizeof...(Xs), 1>{}([&](auto i) {
|
||||
if(a[i] != b[i])
|
||||
{
|
||||
same = false;
|
||||
}
|
||||
});
|
||||
|
||||
return same;
|
||||
}
|
||||
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple<Xs...>& a, const tuple<Xs...>& b)
|
||||
{
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
template <typename... Xs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
|
||||
{
|
||||
// here xs is always a lvalue as function arg
|
||||
// Xs may deduced as (e.g try to pass in a integer in following cases)
|
||||
// 1). if pass in a rvalue (like function return or int{}) -> Xs is "int"
|
||||
// 2). if pass in a const lvalue -> Xs is "const int &"
|
||||
// 3). if pass in a non-const lvalue -> Xs is "int &"
|
||||
// so the return type of std::forward will dependes on Xs
|
||||
// 1). std::forward -> int&&
|
||||
// 2). std::forward -> const int&
|
||||
// 3). std::forward -> int&
|
||||
return tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
|
||||
}
|
||||
|
||||
// https://en.cppreference.com/w/cpp/utility/tuple/tie
|
||||
template <typename... Args>
|
||||
constexpr tuple<Args&...> tie(Args&... args) noexcept
|
||||
{
|
||||
return {args...};
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
|
||||
{
|
||||
using type = tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
// be very careful using this type (because we want the internal type)
|
||||
// template deduction will fail if inferring the inner type
|
||||
// e.g.
|
||||
// template<typename T, index_t N> using some_wrapper = typename tuple_array_impl<T, N>::type;
|
||||
// template<typename T, index_t N> void foo(const some_wrapper<T, N>&) {}
|
||||
// -> compiler will fail to deduce this type, because this is under non-deduced context
|
||||
// (https://en.cppreference.com/w/cpp/language/template_argument_deduction, "Non-deduced
|
||||
// contexts")
|
||||
//
|
||||
// -> use this instead
|
||||
// template<typename Tup> void foo(const Tup&) {}
|
||||
template <typename T, index_t N>
|
||||
struct tuple_array_impl
|
||||
{
|
||||
using type = typename tuple_concat<typename tuple_array_impl<T, N / 2>::type,
|
||||
typename tuple_array_impl<T, N - N / 2>::type>::type;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct tuple_array_impl<T, 0>
|
||||
{
|
||||
using type = tuple<>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct tuple_array_impl<T, 1>
|
||||
{
|
||||
using type = tuple<T>;
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename F, index_t... ids>
|
||||
CK_TILE_HOST_DEVICE constexpr auto generate_tuple_for(F&& f, sequence<ids...>)
|
||||
{
|
||||
return make_tuple(f(number<ids>{})...);
|
||||
}
|
||||
|
||||
template <typename F, index_t N>
|
||||
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number<N>)
|
||||
{
|
||||
return generate_tuple_for(f, make_index_sequence<N>{});
|
||||
}
|
||||
|
||||
template <typename F, index_t N>
|
||||
CK_TILE_HOST_DEVICE constexpr auto generate_tie(F&& f, number<N>)
|
||||
{
|
||||
return unpack([&f](auto&&... is) { return tie(f(is)...); },
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
}
|
||||
|
||||
// tx and ty are tuples of references; the returned tuple also holds references (not rvalues)
|
||||
template <typename... X, typename... Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto concat_tuple_of_reference(const tuple<X&...>& tx,
|
||||
const tuple<Y&...>& ty)
|
||||
{
|
||||
return unpack2(
|
||||
[&](auto&&... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
|
||||
tx,
|
||||
ty);
|
||||
}
|
||||
|
||||
template <typename... X, typename... Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const tuple<Y...>& ty)
|
||||
{
|
||||
return unpack2(
|
||||
[&](auto... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
|
||||
tx,
|
||||
ty);
|
||||
}
|
||||
|
||||
// Support any number of tuples to concat (also 1)
|
||||
template <typename... X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx)
|
||||
{
|
||||
return tx;
|
||||
}
|
||||
|
||||
template <typename... X, typename... Tuples>
|
||||
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const Tuples&... tuples)
|
||||
{
|
||||
return concat_tuple(tx, concat_tuple(tuples...));
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename F, typename X, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples_impl(F f, const X& x, sequence<Is...>)
|
||||
{
|
||||
return make_tuple(f(x.at(number<Is>{}))...);
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
transform_tuples_impl(F f, const X& x, const Y& y, sequence<Is...>)
|
||||
{
|
||||
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}))...);
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y, typename Z, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
|
||||
{
|
||||
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y, typename Z>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename F, typename X, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto embed_tuples_impl(F f, const X& x, sequence<Is...>)
|
||||
{
|
||||
return concat_tuple(f(x.at(number<Is>{}))...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// make sure F return at least a tuple
|
||||
// e.g. x : tuple<X, Y>, f will return tuple<Z, W>
|
||||
// this function will return the concatenation of the tuples f returns for each element of x
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto embed_tuples(F f, const X& x)
|
||||
{
|
||||
return detail::embed_tuples_impl(
|
||||
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
// By default unroll to the flatten
|
||||
template <index_t Depth = 0, index_t MaxDepth = -1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
|
||||
{
|
||||
return t;
|
||||
}
|
||||
|
||||
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& t)
|
||||
{
|
||||
return make_tuple(t);
|
||||
}
|
||||
|
||||
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<Ts...>& t)
|
||||
{
|
||||
if constexpr(Depth == MaxDepth)
|
||||
{
|
||||
return t;
|
||||
}
|
||||
else
|
||||
{
|
||||
return unpack(
|
||||
[&](auto&&... ts) {
|
||||
return concat_tuple(unroll_nested_tuple<Depth + 1, MaxDepth>(ts)...);
|
||||
},
|
||||
t);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using Idx = number<tuple<Ts...>::size() - i - 1>;
|
||||
return t.at(Idx{});
|
||||
},
|
||||
number<tuple<Ts...>::size()>{});
|
||||
}
|
||||
|
||||
// Reduce tuple values in specific range using Function
|
||||
template <index_t Idx, index_t End, typename F, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple<Ts...>& t)
|
||||
{
|
||||
static_assert(Idx < End, "Wrong parameters for tuple_reduce");
|
||||
if constexpr(Idx + 1 == End)
|
||||
{
|
||||
return t.at(number<Idx>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return f(t.at(number<Idx>{}), tuple_reduce<Idx + 1, End>(f, t));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto is_nested_tuple(const tuple<Ts...>&)
|
||||
{
|
||||
return (is_detected<is_tuple, Ts>::value || ...);
|
||||
}
|
||||
|
||||
template <index_t depth = 0, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const T&)
|
||||
{
|
||||
return depth;
|
||||
}
|
||||
|
||||
template <index_t depth = 0, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const tuple<Ts...>&)
|
||||
{
|
||||
return max(tuple_depth<depth + 1>(Ts{})...);
|
||||
}
|
||||
|
||||
template <typename... Seqs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple<Seqs...> t_of_s)
|
||||
{
|
||||
constexpr index_t n0 = sizeof...(Seqs);
|
||||
|
||||
constexpr index_t max_n1 = [&] {
|
||||
index_t max_n1_ = 0;
|
||||
|
||||
static_for<0, n0, 1>{}([&](auto i0) {
|
||||
constexpr index_t n1 = t_of_s[i0].size();
|
||||
|
||||
max_n1_ = max_n1_ < n1 ? n1 : max_n1_;
|
||||
});
|
||||
|
||||
return max_n1_;
|
||||
}();
|
||||
|
||||
array<array<index_t, max_n1>, n0> a_of_a{{-1}};
|
||||
|
||||
static_for<0, n0, 1>{}([&](auto i0) {
|
||||
constexpr index_t n1 = t_of_s[i0].size();
|
||||
|
||||
static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; });
|
||||
});
|
||||
|
||||
return a_of_a;
|
||||
}
|
||||
|
||||
// Here should use MultiIndex<NSize>, instead of tuple<Ys...>, although the former
|
||||
// is the alias of the latter. This is because compiler cannot infer the NSize if
|
||||
// using MultiIndex<NSize>
|
||||
// TODO: how to fix this?
|
||||
template <typename... Ys,
|
||||
typename X,
|
||||
std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename... Ys,
|
||||
typename X,
|
||||
std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] + y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] + y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] - y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] - y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] * y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// MultiIndex = scalar * MultiIndex
|
||||
template <
|
||||
typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(Y a, const tuple<Xs...>& x)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// MultiIndex = MultiIndex * scalar
|
||||
template <
|
||||
typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
|
||||
{
|
||||
return a * x;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] * y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] / y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include <tuple>
|
||||
// WARNING: needed by compiler for C++ structured binding support only, don't use this
|
||||
namespace std {
|
||||
|
||||
template <typename... Ts>
|
||||
struct tuple_size<ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
|
||||
{
|
||||
};
|
||||
|
||||
template <std::size_t I, typename... Ts>
|
||||
struct tuple_element<I, ck_tile::tuple<Ts...>> : std::tuple_element<I, std::tuple<Ts...>>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
struct tuple_size<const ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
|
||||
{
|
||||
};
|
||||
|
||||
template <std::size_t I, typename... Ts>
|
||||
struct tuple_element<I, const ck_tile::tuple<Ts...>>
|
||||
: std::tuple_element<I, const std::tuple<Ts...>>
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
#if 1
|
||||
#define TO_TUPLE_OF_NUMBER(a, n) \
|
||||
_Pragma("clang diagnostic push") _Pragma( \
|
||||
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
|
||||
ck_tile::sequence<IDX_IDX_...>) \
|
||||
{ \
|
||||
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
|
||||
} \
|
||||
(ck_tile::make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
|
||||
#else
|
||||
#define TO_TUPLE_OF_NUMBER(arr, n_) \
|
||||
[&arr, n_] { \
|
||||
static_assert(arr.size() >= n_, "wrong! out of bound"); \
|
||||
\
|
||||
static_assert(n_ < 7, "not implemented"); \
|
||||
\
|
||||
if constexpr(n_ == 0) \
|
||||
{ \
|
||||
return ck_tile::tuple<>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 1) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 2) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 3) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 4) \
|
||||
{ \
|
||||
return ck_tile:: \
|
||||
tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 5) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, \
|
||||
number<arr[1]>, \
|
||||
number<arr[2]>, \
|
||||
number<arr[3]>, \
|
||||
number<arr[4]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 6) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, \
|
||||
number<arr[1]>, \
|
||||
number<arr[2]>, \
|
||||
number<arr[3]>, \
|
||||
number<arr[4]>, \
|
||||
number<arr[5]>>{}; \
|
||||
} \
|
||||
}()
|
||||
#endif
|
||||
423
include/ck_tile/core/numeric/bfloat16.hpp
Normal file
423
include/ck_tile/core/numeric/bfloat16.hpp
Normal file
@@ -0,0 +1,423 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include <stdint.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum class bf16_rounding_mode
|
||||
{
|
||||
standard = 0, // rtn
|
||||
truncate_with_nan,
|
||||
truncate,
|
||||
standard_asm,
|
||||
rta_asm, // round to nearest away
|
||||
};
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float bf16_to_float_raw(uint16_t x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr double bf16_to_double_raw(uint16_t x);
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// HIP use __hip_bfloat16 as struct
|
||||
struct alignas(2) bfloat16_t
|
||||
{
|
||||
using raw_type = uint16_t;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static constexpr bfloat16_t bit_cast(raw_type x)
|
||||
{
|
||||
bfloat16_t y;
|
||||
y.data = x;
|
||||
return y;
|
||||
}
|
||||
|
||||
// constructor
|
||||
constexpr bfloat16_t() : data() {}
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
|
||||
|
||||
// construct from double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr bfloat16_t(const unsigned int& x)
|
||||
: data(float_to_bf16_raw(static_cast<float>(x)))
|
||||
{
|
||||
}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator float() const { return bf16_to_float_raw(data); }
|
||||
|
||||
// cast to double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator double() const { return bf16_to_double_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr raw_type get() const { return data; }
|
||||
};
|
||||
template <typename>
|
||||
struct native_t;
|
||||
|
||||
template <>
|
||||
struct native_t<bfloat16_t>
|
||||
{
|
||||
using type = ushort;
|
||||
};
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = typename bf16_t::raw_type;
|
||||
#else
|
||||
using bfloat16_t = ushort;
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = uint16_t;
|
||||
#endif
|
||||
// round to nearest
|
||||
// Convert an fp32 value to the raw 16-bit bfloat16 pattern using
// round-to-nearest, ties-to-even. Inf is passed through unchanged and NaN
// stays NaN (see the second branch below).
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_rtn_raw(float f)
{
    // reinterpret the float as its 32-bit pattern
    union
    {
        float fp32;
        uint32_t int32;
    } pun = {f};

    constexpr uint32_t exponent_mask = 0x7f800000;

    const bool exponent_all_ones = (pun.int32 & exponent_mask) == exponent_mask;

    if(!exponent_all_ones)
    {
        // Zero, subnormal or normal input. Adding 0x7fff plus the current
        // least significant bit of the future bf16 mantissa bumps the upper
        // 16 bits when the discarded lower 16 bits are above 0x8000, or are
        // exactly 0x8000 and the kept mantissa is odd (ties go to even).
        // A carry out of the mantissa propagates into the exponent, which
        // yields the next representable value: a subnormal may become the
        // smallest normal, and the largest finite value may become Inf.
        const uint32_t kept_lsb = (pun.int32 >> 16) & 1;
        pun.int32 += 0x7fff + kept_lsb;
    }
    else if((pun.int32 & 0xffff) != 0)
    {
        // Inf or NaN with non-zero bits in the half that is about to be
        // dropped, so this is a NaN. Set the lowest kept mantissa bit so the
        // result cannot collapse to Inf when all kept mantissa bits are zero
        // (this preserves a signaling NaN).
        pun.int32 |= 0x10000;
    }

    return static_cast<uint16_t>(pun.int32 >> 16);
}
|
||||
|
||||
CK_TILE_HOST
|
||||
constexpr uint16_t float_to_bf16_rtn_asm(float f) { return float_to_bf16_rtn_raw(f); }
|
||||
|
||||
// Device-side fp32 -> raw bfloat16 conversion implemented with inline
// assembly. The sequence tests the input for NaN, adds 0x7fff plus bit 16 of
// the input (the same rounding bias float_to_bf16_rtn_raw applies), selects
// the constant 0x7fff0000 instead of the rounded value when the input is NaN,
// and shifts the result right by 16.
// NOTE(review): instruction semantics are taken from the mnemonics; confirm
// against the ISA manual of the targeted architecture.
CK_TILE_DEVICE
uint16_t float_to_bf16_rtn_asm(float f)
{
    using sgpr_pair_t = uint32_t __attribute__((ext_vector_type(2)));

    static constexpr uint32_t nan_pattern = 0x7fff0000;
    static constexpr uint32_t round_bias  = 0x7fff;

    union
    {
        float fp32;
        uint32_t int32;
    } bits = {f};

    sgpr_pair_t is_nan; // receives the NaN compare mask (%0)
    uint32_t scratch;   // temporary written by the asm (%1)

    // the asm text below is kept exactly as in the original
    asm volatile("\n \
v_cmp_u_f32 %0, %2, %2 \n \
v_bfe_u32 %1, %2, 16, 1 \n \
v_add3_u32 %1, %2, %1, %3 \n \
v_cndmask_b32 %2, %1, %4, %0 \n \
v_lshrrev_b32 %2, 16, %2 \n \
"
                 : "=s"(is_nan), "+v"(scratch), "+v"(bits.fp32)
                 : "v"(round_bias), "v"(nan_pattern));

    return uint16_t(bits.int32);
}
|
||||
|
||||
// TODO: do we need this on host?
|
||||
CK_TILE_HOST
|
||||
uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }
|
||||
|
||||
// Device-side fp32 -> raw bfloat16 conversion with round-to-nearest-away,
// implemented with inline assembly. The sequence tests the input for NaN,
// adds 0x7fff + 1 (= 0x8000, half of the discarded range) to the bit pattern,
// and substitutes 0x7fff0000 when the input is NaN. The bfloat16 result is
// the upper 16 bits of the register.
// NOTE(review): instruction semantics are taken from the mnemonics; confirm
// against the ISA manual of the targeted architecture.
CK_TILE_DEVICE
uint16_t float_to_bf16_rta_asm(float f)
{
    union
    {
        float fp32;
        struct
        {
            uint16_t lo;
            uint16_t hi;
        };
    } u = {f};

    const uint32_t low_nan = 0x7fff;     // rounding bias added together with the literal 1
    const uint32_t hi_nan  = 0x7fff0000; // pattern selected when the input is NaN

    using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
    uint32x2_t check_nan;

    // check_nan is only ever written by the asm (the first instruction defines
    // it before the last one reads it), so it is bound as a pure output "=s".
    // The previous "+s" binding made the compiler read the uninitialized
    // variable as an input. All inputs are VGPR operands, so the SGPR output
    // cannot alias any of them.
    asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
                 "v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
                 "v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
                 : [s_cnan] "=s"(check_nan), [v_x] "+v"(u.fp32)
                 : [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));

    // Note: the asm above leaves the result in the high 16 bits
    return u.hi;
}
|
||||
|
||||
// Truncate instead of rounding, preserving SNaN
|
||||
// Convert fp32 to the raw bfloat16 pattern by truncation (the lower 16 bits
// are dropped). When the input is a NaN whose dropped half is non-zero, the
// lowest bit of the result is forced to 1 so the value stays a NaN.
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } pun = {f};

    const uint16_t upper_half    = uint16_t(pun.int32 >> 16);
    const bool exponent_all_ones = (pun.int32 & 0x7f800000) == 0x7f800000;
    const bool dropped_bits_set  = (pun.int32 & 0xffff) != 0;

    // a bool converts to 0 or 1, so this ORs in the lowest bit when needed
    return upper_half | (exponent_all_ones && dropped_bits_set);
}
|
||||
|
||||
// Fast truncate instead of rounding, RTZ
|
||||
// Convert fp32 to the raw bfloat16 pattern by plain truncation: the result is
// the upper 16 bits of the fp32 bit pattern, with no special NaN handling.
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } pun = {f};

    const uint32_t upper_half = pun.int32 >> 16;
    return uint16_t(upper_half);
}
|
||||
|
||||
template <bf16_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding>)
|
||||
{
|
||||
if constexpr(rounding == bf16_rounding_mode::standard)
|
||||
return float_to_bf16_rtn_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::standard_asm)
|
||||
return float_to_bf16_rtn_asm(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
|
||||
return float_to_bf16_truc_nan_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::rta_asm)
|
||||
return float_to_bf16_rta_asm(f);
|
||||
else
|
||||
return float_to_bf16_truc_raw(f);
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding>)
|
||||
{
|
||||
return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
|
||||
}
|
||||
|
||||
// Expand a raw bfloat16 pattern to fp32: the 16 bits become the upper half of
// the fp32 bit pattern and the lower half is zero.
CK_TILE_HOST_DEVICE
constexpr float bf16_to_float_raw(uint16_t x)
{
    const uint32_t widened = uint32_t(x) << 16;

    union
    {
        uint32_t int32;
        float fp32;
    } pun = {widened};

    return pun.fp32;
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr double bf16_to_double_raw(uint16_t x)
|
||||
{
|
||||
return static_cast<double>(bf16_to_float_raw(x));
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
|
||||
{
|
||||
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t double_to_bf16(double f, constant<rounding> = {})
|
||||
{
|
||||
return bit_cast<bfloat16_t>(double_to_bf16_raw(f, constant<rounding>{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast<uint16_t>(x)); }
|
||||
|
||||
// Convert a bfloat16 value to double by first expanding it to fp32.
// The raw bits are extracted with bit_cast (as bf16_to_float does) so this
// also compiles when bfloat16_t is the custom struct, which has no implicit
// conversion to uint16_t.
CK_TILE_HOST_DEVICE
constexpr double bf16_to_double(bfloat16_t x)
{
    return static_cast<double>(bf16_to_float_raw(bit_cast<uint16_t>(x)));
}
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t constexpr fp16_to_bf16(half_t f, constant<rounding> = {})
|
||||
{
|
||||
return bit_cast<bfloat16_t>(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
|
||||
}
|
||||
|
||||
// Convert a bfloat16 value to fp16 by going through fp32.
// bf16_to_float decodes the bit pattern; a plain static_cast<float>(x) would
// instead convert the raw integer value when bfloat16_t is an alias of ushort.
CK_TILE_HOST_DEVICE
constexpr half_t bf16_to_fp16(bfloat16_t x) { return float_to_fp16(bf16_to_float(x)); }
|
||||
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<bfloat16_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t min()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
|
||||
}
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
|
||||
}
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t max()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f));
|
||||
}
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x1000));
|
||||
}
|
||||
|
||||
// maximum rounding error
|
||||
// maximum rounding error
|
||||
// bin : f edcba 9876543210
|
||||
// bits: s eeeeeeee mmmmmmm
|
||||
// 0 01111110 0000000 (0.5)
|
||||
//
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00));
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80));
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001));
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t zero()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_traits<bfloat16_t>
|
||||
{
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 7;
|
||||
static constexpr int PackedSize = 1;
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
|
||||
#endif
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
bfloat16_t abs(const bfloat16_t& x)
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
|
||||
}
|
||||
|
||||
// Return true when x is a NaN.
// A bfloat16 NaN has all 8 exponent bits set and a non-zero mantissa, i.e.
// its magnitude bits are strictly greater than the infinity pattern 0x7F80.
// (0x7C00 is the fp16 infinity pattern; using it here classified large finite
// bfloat16 values and infinity as NaN.)
CK_TILE_HOST_DEVICE
bool isnan(const bfloat16_t& x)
{
    const uint16_t magnitude = bit_cast<bf16_raw_t>(x) & 0x7FFF; // drop the sign bit
    return magnitude > 0x7F80;
}
|
||||
|
||||
// Device-only square root of a bfloat16 value, evaluated in fp32.
// The conversions go through bf16_to_float / float_to_bf16 so the bit pattern
// is decoded and re-encoded; a static_cast would convert the raw integer
// value when bfloat16_t is an alias of ushort.
CK_TILE_DEVICE
bfloat16_t sqrt(bfloat16_t x)
{
    return float_to_bf16(__builtin_amdgcn_sqrtf(bf16_to_float(x)));
}
|
||||
|
||||
// Device-only exponential of a bfloat16 value, evaluated in fp32.
// The conversions go through bf16_to_float / float_to_bf16 so the bit pattern
// is decoded and re-encoded; a static_cast would convert the raw integer
// value when bfloat16_t is an alias of ushort.
CK_TILE_DEVICE
bfloat16_t exp(bfloat16_t x)
{
    return float_to_bf16(__ocml_exp_f32(bf16_to_float(x)));
}
|
||||
|
||||
// Device-only base-2 exponential of a bfloat16 value, evaluated in fp32.
// The conversions go through bf16_to_float / float_to_bf16 so the bit pattern
// is decoded and re-encoded; a static_cast would convert the raw integer
// value when bfloat16_t is an alias of ushort.
CK_TILE_DEVICE
bfloat16_t exp2(bfloat16_t x) { return float_to_bf16(exp2f(bf16_to_float(x))); }
|
||||
|
||||
// Device-only logarithm of a bfloat16 value, evaluated in fp32 with __logf.
// The conversions go through bf16_to_float / float_to_bf16 so the bit pattern
// is decoded and re-encoded; a static_cast would convert the raw integer
// value when bfloat16_t is an alias of ushort.
CK_TILE_DEVICE
bfloat16_t log(bfloat16_t x) { return float_to_bf16(__logf(bf16_to_float(x))); }
|
||||
|
||||
} // namespace ck_tile
|
||||
1123
include/ck_tile/core/numeric/float8.hpp
Normal file
1123
include/ck_tile/core/numeric/float8.hpp
Normal file
File diff suppressed because it is too large
Load Diff
404
include/ck_tile/core/numeric/half.hpp
Normal file
404
include/ck_tile/core/numeric/half.hpp
Normal file
@@ -0,0 +1,404 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using fp16_hip_t = _Float16; // most of hip internal function use this type
|
||||
using fp16_raw_t = uint16_t;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float fp16_to_float_hip(const fp16_hip_t& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr double fp16_to_double_hip(const fp16_hip_t& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr fp16_hip_t float_to_fp16_hip(const float& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr fp16_hip_t double_to_fp16_hip(const double& x);
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// HIP uses fp16_hip_t as the interchangeable data type for float16
|
||||
struct alignas(2) half_t
|
||||
{
|
||||
using raw_type = fp16_raw_t;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
static constexpr half_t bit_cast(raw_type x)
|
||||
{
|
||||
half_t y;
|
||||
y.data = x;
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr fp16_hip_t to_fp16() const { return ck_tile::bit_cast<fp16_hip_t>(data); }
|
||||
|
||||
// constructor
|
||||
constexpr half_t() : data{} {}
|
||||
|
||||
// construct from HIP half
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const fp16_hip_t& x) : data(ck_tile::bit_cast<raw_type>(x)) {}
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
|
||||
|
||||
// construct from double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const int& x) : half_t(static_cast<fp16_hip_t>(__int2half_rn(x))) {}
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr half_t(const unsigned int& x)
|
||||
: half_t(static_cast<fp16_hip_t>(__uint2half_rn(x)))
|
||||
{
|
||||
}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); }
|
||||
|
||||
// cast to double
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator int() const
|
||||
{
|
||||
return static_cast<int>(fp16_to_float_hip(to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator fp16_hip_t() const { return ck_tile::bit_cast<fp16_hip_t>(data); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
template <typename>
|
||||
struct native_t;
|
||||
|
||||
template <>
|
||||
struct native_t<half_t>
|
||||
{
|
||||
using type = _Float16;
|
||||
};
|
||||
|
||||
using fp16_t = half_t;
|
||||
using fp16_raw_t = typename half_t::raw_type;
|
||||
#else
|
||||
using fp16_t = _Float16;
|
||||
using half_t = _Float16;
|
||||
using fp16_raw_t = ushort;
|
||||
#endif
|
||||
|
||||
// conversions
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float fp16_to_float_hip(const fp16_hip_t& x)
|
||||
{
|
||||
// return __half2float(x);
|
||||
return static_cast<float>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr double fp16_to_double_hip(const fp16_hip_t& x)
|
||||
{
|
||||
return static_cast<double>(fp16_to_float_hip(x));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr fp16_hip_t float_to_fp16_hip(const float& x)
|
||||
{
|
||||
// return __float2half(x);
|
||||
return static_cast<fp16_hip_t>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr fp16_hip_t double_to_fp16_hip(const double& x)
|
||||
{
|
||||
// return __float2half(x);
|
||||
return static_cast<fp16_hip_t>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
|
||||
|
||||
// Convert an fp16 value to double.
// The function previously returned float (and cast to float) despite its
// name; every fp16 value is exactly representable in both types, so existing
// callers observe the same numeric value.
CK_TILE_HOST_DEVICE
constexpr double fp16_to_double(const half_t& x) { return static_cast<double>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr half_t float_to_fp16(const float& x) { return static_cast<half_t>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr half_t double_to_fp16(const double& x) { return static_cast<half_t>(x); }
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<half_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t min()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0400));
|
||||
}
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t lowest()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0xFBFF));
|
||||
}
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t max()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7BFF));
|
||||
}
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t epsilon()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x1800));
|
||||
}
|
||||
|
||||
// maximum rounding error
|
||||
// bin : f edcba 9876543210
|
||||
// bits: s eeeee mmmmmmmmmm
|
||||
// 0 01110 0000000000 (0.5)
|
||||
//
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t round_error()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x3800));
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t infinity()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7C00));
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t denorm_min()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0001));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t zero()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_traits<half_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 10;
|
||||
static constexpr int bias = 15;
|
||||
static constexpr uint16_t nan_mask = 0x7C00;
|
||||
static constexpr uint16_t head_mask = 0xFC00;
|
||||
static constexpr uint16_t mant_mask = 0x3FF;
|
||||
static constexpr uint16_t exp_mask = 0x1F;
|
||||
static constexpr uint16_t abs_mask = 0x7FFF;
|
||||
static constexpr uint16_t Inf = 0x7C00;
|
||||
static constexpr uint16_t NegInf = 0xFC00;
|
||||
static constexpr uint16_t NaN = 0x7C01;
|
||||
static constexpr uint16_t Neg0 = 0x8000;
|
||||
static constexpr int PackedSize = 1;
|
||||
using bitwise_type = uint16_t;
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// arithmetic
|
||||
CK_TILE_DEVICE bool operator==(const half_t& x, const half_t& y)
|
||||
{
|
||||
return __heq(x.to_fp16(), y.to_fp16());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
#if 0
|
||||
CK_TILE_DEVICE
|
||||
half_t operator+(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator-(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator*(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator/(const half_t& x, const half_t& y)
|
||||
{
|
||||
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator+=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator-=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator*=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator/=(half_t& x, const half_t& y)
|
||||
{
|
||||
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator++(half_t& x)
|
||||
{
|
||||
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t& operator--(half_t& x)
|
||||
{
|
||||
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return x;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator++(half_t& x, int)
|
||||
{
|
||||
half_t y(x);
|
||||
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t operator--(half_t& x, int)
|
||||
{
|
||||
half_t y(x);
|
||||
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return y;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
|
||||
#endif
|
||||
|
||||
// math
|
||||
// Absolute value of a half_t, obtained by clearing the top (sign) bit of the
// raw 16-bit pattern returned by x.get().
CK_TILE_HOST_DEVICE
half_t abs(const half_t& x)
{
    // `x.get() & 0x7fff` is promoted to int by the usual arithmetic conversions;
    // narrow it back to 16 bits so the bit_cast source is no wider than the raw
    // storage (assumes half_t is 16 bits wide, as the uint16_t use in isnan
    // suggests -- TODO confirm).
    const uint16_t magnitude_bits = static_cast<uint16_t>(x.get() & 0x7fff);
    return bit_cast<half_t>(magnitude_bits);
}
|
||||
|
||||
// Returns true when x is NaN: with the sign bit masked off, the raw pattern is
// strictly greater than 0x7C00 (all exponent bits set, mantissa non-zero).
CK_TILE_HOST_DEVICE
bool isnan(const half_t& x)
{
    const uint16_t raw_bits = x.get();
    const auto magnitude    = raw_bits & 0x7FFF;
    return magnitude > 0x7C00;
}
|
||||
|
||||
// Device-side square root: converts x to float, applies
// __builtin_amdgcn_sqrtf, and converts the result back to half_t.
CK_TILE_DEVICE
half_t sqrt(half_t x)
{
    const float x_f32 = static_cast<float>(x);
    return static_cast<half_t>(__builtin_amdgcn_sqrtf(x_f32));
}
|
||||
|
||||
// Device-side natural exponential: evaluated in float through __ocml_exp_f32
// and converted back to half_t.
CK_TILE_DEVICE
half_t exp(half_t x)
{
    const float x_f32 = static_cast<float>(x);
    return static_cast<half_t>(__ocml_exp_f32(x_f32));
}
|
||||
|
||||
// Device-side base-2 exponential: evaluated in float through exp2f and
// converted back to half_t.
CK_TILE_DEVICE
half_t exp2(half_t x)
{
    const float x_f32 = static_cast<float>(x);
    return static_cast<half_t>(exp2f(x_f32));
}
|
||||
|
||||
// Device-side natural logarithm: evaluated in float through __logf and
// converted back to half_t.
CK_TILE_DEVICE
half_t log(half_t x)
{
    const float x_f32 = static_cast<float>(x);
    return static_cast<half_t>(__logf(x_f32));
}
|
||||
#endif
|
||||
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
|
||||
// Host implementation of the packed fp16 add: adds the two lanes of x and y
// independently and returns the pair.
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
    fp16x2_t lane_sums;

    lane_sums.y = x.y + y.y;
    lane_sums.x = x.x + y.x;

    return lane_sums;
}
|
||||
|
||||
// Device implementation of the packed fp16 add: issues a single v_pk_add_f16
// instruction with x and y as inputs and returns its output.
CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
    fp16x2_t packed_sum;
    asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(packed_sum) : "v"(x), "v"(y));
    return packed_sum;
}
|
||||
|
||||
} // namespace ck_tile
|
||||
103
include/ck_tile/core/numeric/int8.hpp
Normal file
103
include/ck_tile/core/numeric/int8.hpp
Normal file
@@ -0,0 +1,103 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include <stdint.h>
|
||||
#include <type_traits>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// use int8_t directly for int8 arithmetic
|
||||
// here one can use ck_tile::int8_t to access original int8_t
|
||||
using int8_t = int8_t;
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<int8_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t min() { return int8_t(-128); }
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t lowest() { return int8_t(-128); }
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t max() { return int8_t(127); }
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t epsilon()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t round_error()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t infinity()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t quiet_NaN()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t signaling_NaN()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t denorm_min()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return 0; }
|
||||
};
|
||||
|
||||
// Disabled on purpose. The values below mirror a 16-bit floating-point layout
// (5 exponent bits, 10 mantissa bits, bias 15) and do not describe int8_t, so
// this block must be revisited before it is ever enabled.
#if 0

template <>
struct numeric_traits<int8_t>
{
    using bitwise_type = uint16_t;

    static constexpr int exp  = 5;
    static constexpr int mant = 10;
    static constexpr int bias = 15;

    static constexpr uint16_t nan_mask  = 0x7C00;
    static constexpr uint16_t head_mask = 0xFC00;
    static constexpr uint16_t mant_mask = 0x3FF;
    static constexpr uint16_t exp_mask  = 0x1F;

    static constexpr uint32_t Inf    = 0x7C00;
    static constexpr uint32_t NegInf = 0xFC00;
    static constexpr uint32_t NaN    = 0x7C01;
    static constexpr uint32_t Neg0   = 0x8000;

    static constexpr int PackedSize = 1;
};
#endif
|
||||
|
||||
// Converts an int8_t to float with a plain static_cast.
CK_TILE_HOST_DEVICE
constexpr float int8_to_float(const int8_t& x)
{
    const float as_float = static_cast<float>(x);
    return as_float;
}
|
||||
|
||||
// Converts a float to int8_t with a plain static_cast (no saturation or
// rounding step is applied here).
CK_TILE_HOST_DEVICE
constexpr int8_t float_to_int8(const float& x)
{
    const int8_t as_int8 = static_cast<int8_t>(x);
    return as_int8;
}
|
||||
|
||||
} // namespace ck_tile
|
||||
13
include/ck_tile/core/numeric/integer.hpp
Normal file
13
include/ck_tile/core/numeric/integer.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// 32-bit signed index type
using index_t = int32_t;
|
||||
// 64-bit signed index type
using long_index_t = int64_t;
|
||||
// makes the global int8_t reachable as ck_tile::int8_t
using int8_t = int8_t;
|
||||
|
||||
} // namespace ck_tile
|
||||
82
include/ck_tile/core/numeric/integral_constant.hpp
Normal file
82
include/ck_tile/core/numeric/integral_constant.hpp
Normal file
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <auto v>
|
||||
struct constant
|
||||
{
|
||||
using value_type = decltype(v);
|
||||
using type = constant; // using injected-class-name
|
||||
static constexpr value_type value = v;
|
||||
CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
|
||||
CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
|
||||
};
|
||||
|
||||
template <typename T, T v>
|
||||
struct integral_constant : constant<v>
|
||||
{
|
||||
using value_type = T;
|
||||
using type = integral_constant; // using injected-class-name
|
||||
static constexpr T value = v;
|
||||
// constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
|
||||
// constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
|
||||
};
|
||||
|
||||
template <index_t v>
|
||||
using number = constant<v>;
|
||||
|
||||
template <long_index_t v>
|
||||
using long_number = constant<v>;
|
||||
|
||||
template <bool b>
|
||||
using bool_constant = constant<b>;
|
||||
|
||||
#define CK_TILE_LEFT_UNARY_OP(OP) \
|
||||
template <auto x> \
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>) \
|
||||
{ \
|
||||
return constant<(OP x)>{}; \
|
||||
}
|
||||
|
||||
#define CK_TILE_BINARY_OP(OP) \
|
||||
template <auto x, auto y> \
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>, constant<y>) \
|
||||
{ \
|
||||
return constant<(x OP y)>{}; \
|
||||
}
|
||||
|
||||
CK_TILE_LEFT_UNARY_OP(+)
|
||||
CK_TILE_LEFT_UNARY_OP(-)
|
||||
CK_TILE_LEFT_UNARY_OP(~)
|
||||
CK_TILE_LEFT_UNARY_OP(!)
|
||||
|
||||
CK_TILE_BINARY_OP(+)
|
||||
CK_TILE_BINARY_OP(-)
|
||||
CK_TILE_BINARY_OP(*)
|
||||
CK_TILE_BINARY_OP(/)
|
||||
CK_TILE_BINARY_OP(%)
|
||||
CK_TILE_BINARY_OP(&)
|
||||
CK_TILE_BINARY_OP(|)
|
||||
CK_TILE_BINARY_OP(^)
|
||||
CK_TILE_BINARY_OP(<<)
|
||||
CK_TILE_BINARY_OP(>>)
|
||||
CK_TILE_BINARY_OP(&&)
|
||||
CK_TILE_BINARY_OP(||)
|
||||
CK_TILE_BINARY_OP(==)
|
||||
CK_TILE_BINARY_OP(!=)
|
||||
CK_TILE_BINARY_OP(>)
|
||||
CK_TILE_BINARY_OP(<)
|
||||
CK_TILE_BINARY_OP(>=)
|
||||
CK_TILE_BINARY_OP(<=)
|
||||
|
||||
#undef CK_TILE_LEFT_UNARY_OP
|
||||
#undef CK_TILE_BINARY_OP
|
||||
|
||||
} // namespace ck_tile
|
||||
1443
include/ck_tile/core/numeric/math.hpp
Normal file
1443
include/ck_tile/core/numeric/math.hpp
Normal file
File diff suppressed because it is too large
Load Diff
13
include/ck_tile/core/numeric/null_type.hpp
Normal file
13
include/ck_tile/core/numeric/null_type.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Empty tag type with no members.
struct null_type {};
|
||||
|
||||
} // namespace ck_tile
|
||||
196
include/ck_tile/core/numeric/numeric.hpp
Normal file
196
include/ck_tile/core/numeric/numeric.hpp
Normal file
@@ -0,0 +1,196 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
#include <limits>
#include <stdint.h>
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// this struct has the information of
|
||||
// 1. limit of a certain type, similar to std::numeric_limits
|
||||
// 2. some pre-defined value, zero, one...
|
||||
//
|
||||
template <typename T>
|
||||
struct numeric
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr T round_error()
|
||||
{
|
||||
return std::numeric_limits<T>::round_error();
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits<T>::infinity(); }
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
|
||||
{
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
|
||||
{
|
||||
return std::numeric_limits<T>::signaling_NaN();
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr T denorm_min()
|
||||
{
|
||||
return std::numeric_limits<T>::denorm_min();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr T zero() { return static_cast<T>(0); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr T one() { return static_cast<T>(1); }
|
||||
|
||||
#ifndef C_LOG2E
|
||||
#define C_LOG2E 1.44269504088896340736 // log2(e)
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr T log2e()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, float> || std::is_same_v<T, double>)
|
||||
{
|
||||
return static_cast<T>(C_LOG2E);
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0; // TODO: integer?
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Default traits: one logical element is stored per object of T.
template <typename T>
struct numeric_traits
{
    static constexpr int PackedSize{1};
};
|
||||
|
||||
template <>
|
||||
struct numeric_traits<float>
|
||||
{
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 23;
|
||||
static constexpr int bias = 127;
|
||||
static constexpr uint32_t nan_mask = 0x7F800000;
|
||||
static constexpr uint32_t head_mask = 0xFF800000;
|
||||
static constexpr uint32_t mant_mask = 0x7FFFFF;
|
||||
static constexpr uint32_t exp_mask = 0xFF;
|
||||
static constexpr uint32_t abs_mask = 0x7FFFFFFF;
|
||||
static constexpr uint32_t Inf = 0x7F800000;
|
||||
static constexpr uint32_t NegInf = 0xFF800000;
|
||||
static constexpr uint32_t NaN = 0x7F800001;
|
||||
static constexpr uint32_t Neg0 = 0x80000000;
|
||||
static constexpr int PackedSize = 1;
|
||||
using bitwise_type = uint32_t;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
// Generates comparison, arithmetic, compound-assignment and increment/decrement
// operators for `type_` by converting the operands to float, computing in float
// and constructing a `type_` from the result.
//   attr_ : function attribute placed in front of every generated operator
//   type_ : the type to generate operators for; it must convert to float, be
//           constructible from float, and expose `raw_type` and `data`
//           (both are used by the unary minus)
// Unary minus does not go through float: it flips the top bit of `data`.
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
    attr_ bool operator==(const type_& x, const type_& y) \
    { \
        return static_cast<float>(x) == static_cast<float>(y); \
    } \
    attr_ bool operator!=(const type_& x, const type_& y) \
    { \
        return static_cast<float>(x) != static_cast<float>(y); \
    } \
    attr_ bool operator<(const type_& x, const type_& y) \
    { \
        return static_cast<float>(x) < static_cast<float>(y); \
    } \
    attr_ bool operator<=(const type_& x, const type_& y) \
    { \
        return static_cast<float>(x) <= static_cast<float>(y); \
    } \
    attr_ bool operator>(const type_& x, const type_& y) \
    { \
        return static_cast<float>(x) > static_cast<float>(y); \
    } \
    attr_ bool operator>=(const type_& x, const type_& y) \
    { \
        return static_cast<float>(x) >= static_cast<float>(y); \
    } \
    attr_ type_ operator+(const type_& x, const type_& y) \
    { \
        return type_(static_cast<float>(x) + static_cast<float>(y)); \
    } \
    attr_ type_ operator-(const type_& x) \
    { \
        constexpr uint32_t bits = sizeof(type_) * 8; \
        /* unsigned literal: no shift into the sign bit of int for 32-bit types */ \
        constexpr uint32_t mask = 1u << (bits - 1); \
        type_ y = x; \
        y.data ^= static_cast<typename type_::raw_type>(mask); \
        return y; \
    } \
    attr_ type_ operator-(const type_& x, const type_& y) \
    { \
        return type_(static_cast<float>(x) - static_cast<float>(y)); \
    } \
    attr_ type_ operator*(const type_& x, const type_& y) \
    { \
        return type_(static_cast<float>(x) * static_cast<float>(y)); \
    } \
    attr_ type_ operator/(const type_& x, const type_& y) \
    { \
        return type_(static_cast<float>(x) / static_cast<float>(y)); \
    } \
    attr_ type_& operator+=(type_& x, const type_& y) \
    { \
        x = type_(static_cast<float>(x) + static_cast<float>(y)); \
        return x; \
    } \
    attr_ type_& operator-=(type_& x, const type_& y) \
    { \
        x = type_(static_cast<float>(x) - static_cast<float>(y)); \
        return x; \
    } \
    attr_ type_& operator*=(type_& x, const type_& y) \
    { \
        x = type_(static_cast<float>(x) * static_cast<float>(y)); \
        return x; \
    } \
    attr_ type_& operator/=(type_& x, const type_& y) \
    { \
        x = type_(static_cast<float>(x) / static_cast<float>(y)); \
        return x; \
    } \
    attr_ type_& operator++(type_& x) \
    { \
        x = type_(static_cast<float>(x) + 1.f); \
        return x; \
    } \
    attr_ type_& operator--(type_& x) \
    { \
        x = type_(static_cast<float>(x) - 1.f); \
        return x; \
    } \
    attr_ type_ operator++(type_& x, int) \
    { \
        type_ y(x); \
        x = type_(static_cast<float>(x) + 1.f); \
        return y; \
    } \
    attr_ type_ operator--(type_& x, int) \
    { \
        type_ y(x); \
        x = type_(static_cast<float>(x) - 1.f); \
        return y; \
    }
|
||||
150
include/ck_tile/core/numeric/pk_int4.hpp
Normal file
150
include/ck_tile/core/numeric/pk_int4.hpp
Normal file
@@ -0,0 +1,150 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include <stdint.h>
|
||||
#include <type_traits>
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Packed 2xint4
|
||||
struct pk_int4_t
|
||||
{
|
||||
using type = int8_t;
|
||||
type data;
|
||||
CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<pk_int4_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t min()
|
||||
{
|
||||
constexpr uint8_t val = 0b10001000;
|
||||
return pk_int4_t(bit_cast<int8_t>(val));
|
||||
}
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t lowest()
|
||||
{
|
||||
constexpr uint8_t val = 0b10001000;
|
||||
return pk_int4_t(bit_cast<int8_t>(val));
|
||||
}
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t max()
|
||||
{
|
||||
constexpr uint8_t val = 0b01110111;
|
||||
return pk_int4_t(bit_cast<int8_t>(val));
|
||||
}
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t epsilon()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t round_error()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t infinity()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t quiet_NaN()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t signaling_NaN()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t denorm_min()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_traits<pk_int4_t>
|
||||
{
|
||||
static constexpr int PackedSize = 2;
|
||||
};
|
||||
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
// Unpacks the two 4-bit fields of x into a pair of floats. Each nibble is read
// as an unsigned value in [0, 15] and shifted down by 8.
// With CK_TILE_USE_PK4_LAYOUT_SHUFFLE defined the result is {high, low},
// otherwise {low, high}.
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);

    float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

    // `#else`, not a bare `#elif`: an `#elif` without a condition fails to
    // preprocess as soon as the macro is not defined
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    fp32x2_t res = {x_h, x_l};
#else
    fp32x2_t res = {x_l, x_h};
#endif
    return res;
}
|
||||
|
||||
// Unpacks the two 4-bit fields of x into a pair of fp16 values without an
// integer-to-float conversion: each nibble is OR-ed into the mantissa of the
// fp16 constant 0x6400 (1024.0), giving 1024 + nibble, and 0xE408 (-1032.0) is
// then added with pk_add_f16, leaving nibble - 8 in each lane.
// With CK_TILE_USE_PK4_LAYOUT_SHUFFLE defined the low nibble goes to the upper
// 16 bits and the high nibble to the lower 16 bits; otherwise the low nibble
// stays in the lower 16 bits.
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);

    // `#else`, not a bare `#elif`: an `#elif` without a condition fails to
    // preprocess as soon as the macro is not defined
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#else
    uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif

    // unsigned so that 0xE408E408 is stored without a narrowing conversion to int
    constexpr uint32_t EX  = 0x64006400; // 1024.0 in both fp16 lanes
    constexpr uint32_t SUB = 0xE408E408; // -(1024 + 8) in both fp16 lanes

    const uint32_t lo = i4s | EX;

    return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
}
|
||||
|
||||
// Unpacks the two 4-bit fields of x into a pair of bf16 values. Each nibble is
// read as an unsigned value in [0, 15], shifted down by 8 in float, and then
// converted with type_convert<bf16_t>.
// With CK_TILE_USE_PK4_LAYOUT_SHUFFLE defined the result is {high, low},
// otherwise {low, high}.
CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);

    float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

    // `#else`, not a bare `#elif`: an `#elif` without a condition fails to
    // preprocess as soon as the macro is not defined
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)};
#else
    bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)};
#endif
    return res;
}
|
||||
|
||||
} // namespace ck_tile
|
||||
70
include/ck_tile/core/numeric/type_convert.hpp
Normal file
70
include/ck_tile/core/numeric/type_convert.hpp
Normal file
@@ -0,0 +1,70 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<Y> type_convert(const X& x)
|
||||
{
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
#else
|
||||
// Convert X to Y, both X and Y are non-const data types.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
|
||||
// Convert X to Y, either X or Y is a const data type.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
using non_const_y = std::remove_const_t<Y>;
|
||||
using non_const_x = std::remove_const_t<X>;
|
||||
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
|
||||
}
|
||||
|
||||
#define CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return sname_##_to_##dname_(x); \
|
||||
}
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, float, fp16_t, fp16)
|
||||
CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
|
||||
CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
|
||||
CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
|
||||
CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
|
||||
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
|
||||
CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
|
||||
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
|
||||
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
240
include/ck_tile/core/numeric/vector_type.hpp
Normal file
240
include/ck_tile/core/numeric/vector_type.hpp
Normal file
@@ -0,0 +1,240 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// this structure is used to pick up the <base> type inside
|
||||
// using xxx = <base> __attribute__((ext_vector_type(N)));
|
||||
// because clang only allow native type + bool in this term (custom type will fail)
|
||||
// overload this structure to let proper <base> type
|
||||
|
||||
template <typename T>
|
||||
struct native_t
|
||||
{
|
||||
using type = remove_cvref_t<T>;
|
||||
};
|
||||
|
||||
// We name this ext_vector on purpose: the clang ext_vector_type extension only accepts literal
// basic types as the element type, so be very careful when using it, or you will get a lot of
// compiler errors. E.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will
// have compiler error
|
||||
namespace impl {
|
||||
|
||||
template <typename T_, index_t N_, typename = void>
|
||||
struct ext_vector;
|
||||
|
||||
template <typename T_, index_t N_>
|
||||
struct ext_vector<T_, N_, std::enable_if_t<!std::is_class_v<typename native_t<T_>::type>>>
|
||||
{
|
||||
static constexpr index_t N = N_;
|
||||
// struct type is not supported for ext_vector
|
||||
using value_type = typename native_t<T_>::type;
|
||||
static_assert(!std::is_class_v<value_type>);
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
};
|
||||
|
||||
template <typename T_, index_t N_>
|
||||
struct ext_vector<T_, N_, std::enable_if_t<std::is_class_v<typename native_t<T_>::type>>>
|
||||
{
|
||||
static constexpr index_t N = N_;
|
||||
// struct type is not supported for ext_vector
|
||||
using value_type = typename native_t<T_>::type::type;
|
||||
static_assert(!std::is_class_v<value_type>);
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
};
|
||||
|
||||
template <typename V_, index_t Vs_, index_t N_>
|
||||
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
|
||||
N_,
|
||||
std::enable_if_t<!std::is_class_v<typename native_t<V_>::type>>>
|
||||
{
|
||||
static constexpr index_t N = Vs_ * N_;
|
||||
using value_type = typename native_t<remove_cvref_t<V_>>::type;
|
||||
static_assert(!std::is_class_v<value_type>);
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
};
|
||||
|
||||
template <typename V_, index_t Vs_, index_t N_>
|
||||
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
|
||||
N_,
|
||||
std::enable_if_t<std::is_class_v<typename native_t<V_>::type>>>
|
||||
{
|
||||
static constexpr index_t N = Vs_ * N_;
|
||||
using value_type = typename native_t<remove_cvref_t<V_>>::type::type;
|
||||
static_assert(!std::is_class_v<value_type>);
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
|
||||
template <typename T, index_t N>
|
||||
using ext_vector_t = typename impl::ext_vector<T, N>::type;
|
||||
|
||||
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
|
||||
// ... unless we have other vector_traits specialization
|
||||
template <typename T, typename>
|
||||
struct vector_traits
|
||||
{
|
||||
using scalar_type =
|
||||
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>, int8_t, remove_cvref_t<T>>;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
// specialization for ext_vector_type()
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<T __attribute__((ext_vector_type(N)))>
|
||||
{
|
||||
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <typename X, typename Y>
|
||||
using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<Y>>::scalar_type>;
|
||||
|
||||
// below are some pre-defines of ext_vector_type
|
||||
// attention! 2 vector type could be just the same type
|
||||
// fp64
|
||||
using fp64_t = double;
|
||||
using fp64x2_t = double __attribute__((ext_vector_type(2)));
|
||||
using fp64x4_t = double __attribute__((ext_vector_type(4)));
|
||||
|
||||
// fp32
|
||||
using fp32_t = float;
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp32x4_t = float __attribute__((ext_vector_type(4)));
|
||||
using fp32x8_t = float __attribute__((ext_vector_type(8)));
|
||||
using fp32x16_t = float __attribute__((ext_vector_type(16)));
|
||||
using fp32x32_t = float __attribute__((ext_vector_type(32)));
|
||||
using fp32x64_t = float __attribute__((ext_vector_type(64)));
|
||||
|
||||
// fp16
|
||||
// using fp16_t = ...
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
|
||||
using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
|
||||
using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
|
||||
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
|
||||
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bf16
|
||||
// using bf16_t = ...
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
|
||||
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
|
||||
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
|
||||
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
|
||||
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i32
|
||||
// using int32_t = ...
|
||||
using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
|
||||
using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
|
||||
using int32x8_t = int32_t __attribute__((ext_vector_type(8)));
|
||||
using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
|
||||
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
|
||||
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// u32
|
||||
// using uint32_t = ...
|
||||
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
|
||||
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
|
||||
using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
|
||||
using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
|
||||
using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
|
||||
using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i16
|
||||
// using int16_t = ...
|
||||
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
|
||||
using int16x4_t = int16_t __attribute__((ext_vector_type(4)));
|
||||
using int16x8_t = int16_t __attribute__((ext_vector_type(8)));
|
||||
using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
|
||||
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
|
||||
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// u16
|
||||
// using uint16_t
|
||||
using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
|
||||
using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
|
||||
using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
|
||||
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
|
||||
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
|
||||
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i8
|
||||
// using int8_t
|
||||
// int8_t vectors of 2..64 elements; spelled with __attribute__ like the other
// alias tables in this file
using int8x2_t  = int8_t __attribute__((ext_vector_type(2)));
using int8x4_t  = int8_t __attribute__((ext_vector_type(4)));
using int8x8_t  = int8_t __attribute__((ext_vector_type(8)));
using int8x16_t = int8_t __attribute__((ext_vector_type(16)));
using int8x32_t = int8_t __attribute__((ext_vector_type(32)));
using int8x64_t = int8_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// ui8
|
||||
// using uint8_t
|
||||
// uint8_t vectors of 2..64 elements; spelled with __attribute__ like the other
// alias tables in this file
using uint8x2_t  = uint8_t __attribute__((ext_vector_type(2)));
using uint8x4_t  = uint8_t __attribute__((ext_vector_type(4)));
using uint8x8_t  = uint8_t __attribute__((ext_vector_type(8)));
using uint8x16_t = uint8_t __attribute__((ext_vector_type(16)));
using uint8x32_t = uint8_t __attribute__((ext_vector_type(32)));
using uint8x64_t = uint8_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// f8
|
||||
// using fp8_t
|
||||
using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
|
||||
using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4)));
|
||||
using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8)));
|
||||
using fp8x16_t = fp8_raw_t __attribute((ext_vector_type(16)));
|
||||
using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32)));
|
||||
using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
|
||||
|
||||
// bf8
|
||||
// using bf8_t
|
||||
using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2)));
|
||||
using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4)));
|
||||
using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));
|
||||
using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16)));
|
||||
using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32)));
|
||||
using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
|
||||
#else
|
||||
// f8
|
||||
// using fp8_t
|
||||
using fp8x2_t = fp8_t __attribute((ext_vector_type(2)));
|
||||
using fp8x4_t = fp8_t __attribute((ext_vector_type(4)));
|
||||
using fp8x8_t = fp8_t __attribute((ext_vector_type(8)));
|
||||
using fp8x16_t = fp8_t __attribute((ext_vector_type(16)));
|
||||
using fp8x32_t = fp8_t __attribute((ext_vector_type(32)));
|
||||
using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
|
||||
|
||||
// bf8
|
||||
// using bf8_t
|
||||
using bf8x2_t = bf8_t __attribute((ext_vector_type(2)));
|
||||
using bf8x4_t = bf8_t __attribute((ext_vector_type(4)));
|
||||
using bf8x8_t = bf8_t __attribute((ext_vector_type(8)));
|
||||
using bf8x16_t = bf8_t __attribute((ext_vector_type(16)));
|
||||
using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
|
||||
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
|
||||
#endif
|
||||
|
||||
// ---- pk_int4 vectors ----
// Each alias is a vector of int8_t elements.
using pk_int4x2_t  = int8_t __attribute__((ext_vector_type(2)));
using pk_int4x4_t  = int8_t __attribute__((ext_vector_type(4)));
using pk_int4x8_t  = int8_t __attribute__((ext_vector_type(8)));
using pk_int4x16_t = int8_t __attribute__((ext_vector_type(16)));
using pk_int4x32_t = int8_t __attribute__((ext_vector_type(32)));
|
||||
} // namespace ck_tile
|
||||
1273
include/ck_tile/core/tensor/buffer_view.hpp
Normal file
1273
include/ck_tile/core/tensor/buffer_view.hpp
Normal file
File diff suppressed because it is too large
Load Diff
203
include/ck_tile/core/tensor/load_tile.hpp
Normal file
203
include/ck_tile/core/tensor/load_tile.hpp
Normal file
@@ -0,0 +1,203 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/tensor/null_tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/null_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Loads a tile of data using inline assembly.
|
||||
*
|
||||
 * @note Bear in mind that when loading data this way, you have to manually initialize your
|
||||
* thread buffer and synchronize load afterwards in order to make sure it's done before
|
||||
* using loaded data from registers
|
||||
* @see `tile_window_with_static_distribution::init_raw()` and `buffer_view.hpp`
|
||||
* @see `buffer_load_fence()`
|
||||
*/
|
||||
template <typename T,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.load_raw(
|
||||
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.load_raw(
|
||||
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto
|
||||
async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
return tile_window.async_load_raw(lds_tile,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
return tile_window.async_load_raw(lds_tile,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// Emits an `s_waitcnt vmcnt(cnt)` instruction with a "memory" clobber.
// `cnt` is bound with the "n" (immediate) asm constraint, so it has to be
// resolvable to a compile-time constant at the point the asm is emitted.
// NOTE(review): presumably used to wait for outstanding async loads issued via
// async_load_tile_raw - confirm against the target ISA documentation.
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)"
                 :
                 : "n"(cnt)
                 : "memory");
}
|
||||
|
||||
template <typename WindowLengths>
|
||||
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
|
||||
{
|
||||
return null_tensor{};
|
||||
}
|
||||
|
||||
template <typename T, typename WindowLengths>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<WindowLengths>&)
|
||||
{
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
12
include/ck_tile/core/tensor/null_tensor.hpp
Normal file
12
include/ck_tile/core/tensor/null_tensor.hpp
Normal file
@@ -0,0 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Empty tag type; carries no data and no members.
struct null_tensor
{
};
|
||||
|
||||
} // namespace ck_tile
|
||||
97
include/ck_tile/core/tensor/null_tile_window.hpp
Normal file
97
include/ck_tile/core/tensor/null_tile_window.hpp
Normal file
@@ -0,0 +1,97 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_view.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// placeholder type if we want to opt-out a tile window parameter
|
||||
template <typename WindowLengths_>
|
||||
struct null_tile_window
|
||||
{
|
||||
using BottomTensorView = null_tensor_view;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
|
||||
using BottomTensorIndex = array<index_t, WindowLengths::size()>;
|
||||
|
||||
CK_TILE_DEVICE constexpr null_tile_window() = default;
|
||||
|
||||
CK_TILE_DEVICE constexpr null_tile_window(const WindowLengths& window_lengths)
|
||||
: window_lengths_{window_lengths}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return null_tensor_view{}; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }
|
||||
|
||||
CK_TILE_DEVICE void init_raw() {}
|
||||
|
||||
WindowLengths window_lengths_;
|
||||
};
|
||||
|
||||
// utility to check if this is a Null Tile Window
|
||||
namespace impl {
|
||||
template <typename>
|
||||
struct is_null_tile_window : public std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_null_tile_window<null_tile_window<T>> : public std::true_type
|
||||
{
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
|
||||
{
|
||||
return impl::is_null_tile_window<remove_cvref_t<T>>::value;
|
||||
}
|
||||
|
||||
template <typename WindowLengths>
|
||||
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths& window_lengths)
|
||||
{
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
|
||||
return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
|
||||
}
|
||||
|
||||
template <typename WindowLengths, typename... Ts>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const multi_index<WindowLengths::size()>& /*origin*/,
|
||||
Ts&&...)
|
||||
{
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
|
||||
return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
|
||||
}
|
||||
|
||||
template <typename WindowLengths, typename StaticTileDistribution>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_window(const null_tile_window<WindowLengths>& t,
|
||||
const StaticTileDistribution&)
|
||||
{
|
||||
return t;
|
||||
}
|
||||
|
||||
template <typename WindowLengths>
|
||||
CK_TILE_DEVICE void
|
||||
move_tile_window(null_tile_window<WindowLengths>&,
|
||||
const typename null_tile_window<WindowLengths>::BottomTensorIndex&)
|
||||
{
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
177
include/ck_tile/core/tensor/shuffle_tile.hpp
Normal file
177
include/ck_tile/core/tensor/shuffle_tile.hpp
Normal file
@@ -0,0 +1,177 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace detail {
|
||||
|
||||
template <typename OutTensor, typename InTensor>
|
||||
CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor)
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
using DataType = typename InTensor::DataType;
|
||||
|
||||
constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
|
||||
constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
|
||||
|
||||
// y_dim_out_to_in
|
||||
constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) {
|
||||
using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
|
||||
|
||||
map<array<index_t, 2>, index_t> rh_major_minor_to_y_;
|
||||
|
||||
static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
|
||||
constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i];
|
||||
constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
|
||||
|
||||
rh_major_minor_to_y_({rh_major, rh_minor}) = i;
|
||||
});
|
||||
|
||||
return rh_major_minor_to_y_;
|
||||
};
|
||||
|
||||
constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
|
||||
constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});
|
||||
|
||||
constexpr auto y_dim_out_to_in = [&] {
|
||||
map<index_t, index_t> y_dim_out_to_in_;
|
||||
|
||||
for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
|
||||
{
|
||||
y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
|
||||
}
|
||||
|
||||
return y_dim_out_to_in_;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
|
||||
|
||||
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
|
||||
|
||||
// input and output vector dim in the order of input Y dims
|
||||
constexpr index_t y_dim_vec_in = NDimY - 1;
|
||||
constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
|
||||
|
||||
// vector lengths
|
||||
constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
|
||||
constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
|
||||
|
||||
// # of vectors
|
||||
constexpr index_t num_vec_in = vec_length_out;
|
||||
constexpr index_t num_vec_out = vec_length_in;
|
||||
|
||||
using InVec = array<DataType, vec_length_in>;
|
||||
using OutVec = array<DataType, vec_length_out>;
|
||||
|
||||
// using InVec = typename InVec::type;
|
||||
// using OutVec = typename OutVec::type;
|
||||
|
||||
// SFC
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
|
||||
|
||||
using SFC_Y = space_filling_curve<decltype(y_lengths),
|
||||
typename arithmetic_sequence_gen<0, NDimY, 1>::type,
|
||||
decltype(scalars_per_access)>;
|
||||
|
||||
constexpr index_t num_access = SFC_Y::get_num_of_access();
|
||||
|
||||
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
|
||||
|
||||
// in/out vectors to be transposed
|
||||
thread_buffer<InVec, num_vec_in> in_vectors;
|
||||
thread_buffer<OutVec, num_vec_out> out_vectors;
|
||||
|
||||
// loop over SFC and do transpose
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
// data index [y0, y1, ...] in the order of input tensor
|
||||
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
|
||||
|
||||
// get input vectors
|
||||
static_for<0, num_vec_in, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_in = generate_tuple(
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
|
||||
static_assert(in_offset % vec_length_in == 0);
|
||||
|
||||
in_vectors(i).template get_as<InVec>()(I0) =
|
||||
in_tensor.get_thread_buffer()
|
||||
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
|
||||
});
|
||||
|
||||
// transpose
|
||||
transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
|
||||
|
||||
// set output vectors
|
||||
static_for<0, num_vec_out, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_out_tmp = generate_array(
|
||||
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_y_out =
|
||||
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
|
||||
|
||||
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
|
||||
static_assert(out_offset % vec_length_out == 0);
|
||||
|
||||
out_tensor.get_thread_buffer().template set_as<OutVec>(
|
||||
number<out_offset / vec_length_out>{},
|
||||
out_vectors[i].template get_as<OutVec>()[I0]);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename OutTensor, typename InTensor>
|
||||
CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
|
||||
{
|
||||
using InDataType = typename InTensor::DataType;
|
||||
using OutDataType = typename OutTensor::DataType;
|
||||
|
||||
using InDstrEncode = typename InTensor::StaticTileDistribution::DstrEncode;
|
||||
using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;
|
||||
|
||||
// type convert
|
||||
const auto in_tmp = tile_elementwise_in(type_convert<OutDataType, InDataType>, in);
|
||||
|
||||
// shuffle
|
||||
if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
|
||||
InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ &&
|
||||
InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
|
||||
InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
|
||||
InDstrEncode::NDimY == OutDstrEncode::NDimY)
|
||||
{
|
||||
detail::shuffle_tile_impl_in_thread(out, in_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "The shuffle should always happen!");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
92
include/ck_tile/core/tensor/slice_tile.hpp
Normal file
92
include/ck_tile/core/tensor/slice_tile.hpp
Normal file
@@ -0,0 +1,92 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
index_t... SliceBegins,
|
||||
index_t... SliceEnds>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get_slice_tile(const tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile,
|
||||
sequence<SliceBegins...> slice_begins,
|
||||
sequence<SliceEnds...> slice_ends)
|
||||
{
|
||||
using TileWindow = tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>;
|
||||
// NOTE: This API will override the origin of the tile window!
|
||||
static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds));
|
||||
static_assert(sizeof...(SliceBegins) == TileWindow::get_num_of_dimension());
|
||||
|
||||
constexpr auto slice_lengths = slice_ends - slice_begins;
|
||||
|
||||
return make_tile_window(tile.get_bottom_tensor_view(),
|
||||
sequence_to_tuple_of_number(slice_lengths),
|
||||
to_multi_index(slice_begins));
|
||||
}
|
||||
|
||||
template <typename DataType_,
|
||||
typename StaticTileDistribution_,
|
||||
index_t... SliceBegins,
|
||||
index_t... SliceEnds>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get_slice_tile(const static_distributed_tensor<DataType_, StaticTileDistribution_>& tile,
|
||||
sequence<SliceBegins...> slice_begins,
|
||||
sequence<SliceEnds...> slice_ends)
|
||||
{
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
using Distribution = remove_cvref_t<StaticTileDistribution_>;
|
||||
|
||||
constexpr auto sliced_dstr_yidx_ylen =
|
||||
detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends);
|
||||
|
||||
constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
|
||||
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
|
||||
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
|
||||
|
||||
auto sliced_tensor = make_static_distributed_tensor<DataType>(sliced_dstr);
|
||||
|
||||
sliced_tensor.get_thread_buffer() =
|
||||
tile.get_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths);
|
||||
|
||||
return sliced_tensor;
|
||||
}
|
||||
|
||||
template <typename DstDataType_,
|
||||
typename DstStaticTileDistribution_,
|
||||
typename SrcDataType_,
|
||||
typename SrcStaticTileDistribution_,
|
||||
index_t... SliceBegins,
|
||||
index_t... SliceEnds>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution_>& dst_tile,
|
||||
const static_distributed_tensor<SrcDataType_, SrcStaticTileDistribution_>& src_tile,
|
||||
sequence<SliceBegins...> slice_begins,
|
||||
sequence<SliceEnds...> slice_ends)
|
||||
{
|
||||
using DstDistribution = remove_cvref_t<DstStaticTileDistribution_>;
|
||||
|
||||
constexpr auto sliced_dstr_yidx_ylen =
|
||||
detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);
|
||||
|
||||
constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
|
||||
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
|
||||
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
|
||||
|
||||
static_assert(std::is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
|
||||
|
||||
dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
235
include/ck_tile/core/tensor/static_distributed_tensor.hpp
Normal file
235
include/ck_tile/core/tensor/static_distributed_tensor.hpp
Normal file
@@ -0,0 +1,235 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename DataType_, typename StaticTileDistribution_>
|
||||
struct static_distributed_tensor
|
||||
{
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;
|
||||
|
||||
static_assert(StaticTileDistribution::is_static(),
|
||||
"wrong! StaticTileDistribution should be known at compile tile");
|
||||
|
||||
using ThreadTensorDesc =
|
||||
remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
|
||||
static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
|
||||
{
|
||||
return StaticTileDistribution::get_num_of_dimension_x();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
|
||||
{
|
||||
return StaticTileDistribution::get_lengths();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_tile_distribution()
|
||||
{
|
||||
return StaticTileDistribution{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
|
||||
{
|
||||
return StaticTileDistribution::get_distributed_spans();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void initialize(const DataType& x) { thread_buf_.initialize(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_thread_buffer() const { return thread_buf_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get_thread_buffer() { return thread_buf_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size()
|
||||
{
|
||||
return kThreadElementSpaceSize / PackedSize;
|
||||
}
|
||||
|
||||
template <index_t... YSliceOrigins, index_t... YSliceLengths>
|
||||
CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence<YSliceOrigins...>,
|
||||
sequence<YSliceLengths...>) const
|
||||
{
|
||||
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
|
||||
sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto sliced_thread_tensor_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
|
||||
|
||||
thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
|
||||
sliced_thread_data;
|
||||
|
||||
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
|
||||
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
|
||||
|
||||
sliced_thread_data(
|
||||
number<sliced_thread_tensor_desc.calculate_offset(idx) / PackedSize>{}) =
|
||||
thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}];
|
||||
});
|
||||
|
||||
return sliced_thread_data;
|
||||
}
|
||||
|
||||
template <index_t... YSliceOrigins, index_t... YSliceLengths, typename SlicedThreadData>
|
||||
CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence<YSliceOrigins...>,
|
||||
sequence<YSliceLengths...>,
|
||||
const SlicedThreadData& sliced_thread_data)
|
||||
{
|
||||
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
|
||||
sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto sliced_thread_tensor_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
|
||||
|
||||
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
|
||||
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
|
||||
|
||||
thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}) =
|
||||
sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx) /
|
||||
PackedSize>{}];
|
||||
});
|
||||
}
|
||||
|
||||
template <typename TileDistributedIndices>
|
||||
CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const
|
||||
{
|
||||
static_assert(is_static_v<TileDistributedIndices>,
|
||||
"wrong! Tile Distributed Indices should be static");
|
||||
|
||||
constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
|
||||
TileDistributedIndices{});
|
||||
|
||||
return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{}];
|
||||
}
|
||||
|
||||
template <typename TileDistributedIndices>
|
||||
CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices)
|
||||
{
|
||||
static_assert(is_static_v<TileDistributedIndices>,
|
||||
"wrong! Tile Distributed Indices should be static");
|
||||
|
||||
constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
|
||||
TileDistributedIndices{});
|
||||
|
||||
return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{});
|
||||
}
|
||||
|
||||
//
|
||||
thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
|
||||
};
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&)
|
||||
{
|
||||
return static_distributed_tensor<remove_cvref_t<DataType>,
|
||||
remove_cvref_t<StaticTileDistribution>>{};
|
||||
}
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution, typename ThreadBuffer>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&,
|
||||
ThreadBuffer&& thread_buffer_)
|
||||
{
|
||||
return static_distributed_tensor<remove_cvref_t<DataType>,
|
||||
remove_cvref_t<StaticTileDistribution>>{thread_buffer_};
|
||||
}
|
||||
|
||||
// get X indices from tuple of tile_distributed_index<>
|
||||
template <typename StaticTileDistribution, typename DistributedIndices>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
|
||||
DistributedIndices distributed_indices)
|
||||
{
|
||||
const auto partition_index = detail::get_partition_index(tile_distribution);
|
||||
constexpr auto y_indices =
|
||||
tile_distribution.get_y_indices_from_distributed_indices(distributed_indices);
|
||||
|
||||
const auto x_coord = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(partition_index, to_array<ck_tile::index_t, y_indices.size()>(y_indices)));
|
||||
|
||||
return x_coord.get_bottom_index();
|
||||
}
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_tensor,
|
||||
DataType value,
|
||||
XIndicesPredicate predicate)
|
||||
{
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<DataType, StaticTileDistribution>::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(StaticTileDistribution{},
|
||||
distributed_indices);
|
||||
|
||||
if(predicate(x_indices))
|
||||
{
|
||||
out_tensor(distributed_indices) = value;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// this function used inside span loop over
|
||||
template <typename YLengths, index_t XUnpacks>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number<XUnpacks>)
|
||||
{
|
||||
constexpr auto y_size = reduce_on_sequence(YLengths{}, multiplies{}, number<1>{});
|
||||
constexpr auto y_packs = number<XUnpacks>{};
|
||||
static_assert(y_size % y_packs == 0);
|
||||
constexpr auto y_slice_size = y_size / y_packs;
|
||||
|
||||
constexpr auto slice_info = slice_sequence(YLengths{}, number<y_slice_size>{});
|
||||
constexpr auto unpacks = slice_info[number<1>{}];
|
||||
return unpacks;
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
// check if 2 static_distributed_tensor has same data type and size of element
|
||||
// but only difference in distribution
|
||||
template <typename X, typename Y>
|
||||
struct is_similiar_distributed_tensor
|
||||
{
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <typename TypeX, typename DistX, typename TypeY, typename DistY>
|
||||
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
|
||||
static_distributed_tensor<TypeY, DistY>>
|
||||
{
|
||||
using Tx = static_distributed_tensor<TypeX, DistX>;
|
||||
using Ty = static_distributed_tensor<TypeY, DistY>;
|
||||
static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
|
||||
Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
|
||||
};
|
||||
|
||||
template <typename X, typename Y>
|
||||
inline constexpr bool is_similiar_distributed_tensor_v =
|
||||
is_similiar_distributed_tensor<X, Y>::value;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace ck_tile
|
||||
120
include/ck_tile/core/tensor/store_tile.hpp
Normal file
120
include/ck_tile/core/tensor/store_tile.hpp
Normal file
@@ -0,0 +1,120 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
tile_dstr);
|
||||
|
||||
tile_window.store(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
tile_dstr);
|
||||
|
||||
tile_window.store_raw(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
store_tile(tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store_raw(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void store_tile(
|
||||
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
|
||||
tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void store_tile_raw(
|
||||
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
|
||||
tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store_raw(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
308
include/ck_tile/core/tensor/sweep_tile.hpp
Normal file
308
include/ck_tile/core/tensor/sweep_tile.hpp
Normal file
@@ -0,0 +1,308 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/functional_with_tuple.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// sweep over a span of a distribted tile and apply lambda function F
|
||||
template <typename TileDistributedSpan_, // tile_distributed_span<...>
|
||||
typename F // signature: F(tile_distributed_index<...>)
|
||||
>
|
||||
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
|
||||
{
|
||||
using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
|
||||
|
||||
static_ford<typename DstrSpan::Impl>{}([&](auto dstr_idx_impl) {
|
||||
constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl);
|
||||
|
||||
f(dstr_idx);
|
||||
});
|
||||
}
|
||||
|
||||
// unpacked span, this version support span with unpack(multi-arg) functor
|
||||
//
|
||||
template <
|
||||
typename TileDistributedSpan_, // tile_distributed_span<...>
|
||||
typename F, // signature: F(tile_distributed_index<...>)
|
||||
typename Unpacks = typename uniform_sequence_gen<TileDistributedSpan_::Impl::size(), 1>::type>
|
||||
CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F& f, Unpacks = {})
|
||||
{
|
||||
using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
|
||||
|
||||
static_uford<typename DstrSpan::Impl, Unpacks>{}(
|
||||
[&](auto... dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)...); });
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
|
||||
template <typename, typename, typename>
|
||||
struct sweep_tile_impl;
|
||||
|
||||
template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
|
||||
struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
|
||||
constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
|
||||
constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
|
||||
return y_unpacks;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto u =
|
||||
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
|
||||
return u.get_num_of_access() *
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
|
||||
.get_num_of_access();
|
||||
}
|
||||
template <typename F, typename SpanIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
|
||||
sweep_tile_uspan(
|
||||
spans[number<I>{}],
|
||||
[&](auto... i_idx) {
|
||||
const auto next_span_idx = embed_tuples(
|
||||
[&](auto si) { return make_tuple(concat_tuple(si, make_tuple(i_idx))...); },
|
||||
span_idx);
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
|
||||
f, next_span_idx);
|
||||
},
|
||||
get_y_unpacks());
|
||||
}
|
||||
template <typename F, typename SpanIdx, index_t i_access>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto u =
|
||||
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
|
||||
constexpr auto access_stride =
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
|
||||
.get_num_of_access();
|
||||
constexpr auto curr_i_access = number<i_access / access_stride>{};
|
||||
constexpr auto next_i_access = number<i_access % access_stride>{};
|
||||
u(
|
||||
[&](auto... i_idx) {
|
||||
const auto next_span_idx = embed_tuples(
|
||||
[&](auto si) {
|
||||
return make_tuple(concat_tuple(
|
||||
si, make_tuple(detail::make_tile_distributed_index(i_idx)))...);
|
||||
},
|
||||
span_idx);
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
|
||||
f, next_span_idx, next_i_access);
|
||||
},
|
||||
curr_i_access);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DistributedTensor, typename UnpacksPerXDim>
|
||||
struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<>>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const { return 1; }
|
||||
template <typename F, typename SpanIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
|
||||
{
|
||||
unpack(f, span_idx);
|
||||
}
|
||||
template <typename F, typename SpanIdx, index_t i_access>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
|
||||
{
|
||||
unpack(f, span_idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename, typename, typename>
|
||||
struct sweep_tile_impl_0;
|
||||
|
||||
// TODO: support empty tuple to remove this "entry-point" like function
|
||||
template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
|
||||
struct sweep_tile_impl_0<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
|
||||
constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
|
||||
constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
|
||||
return y_unpacks;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto u =
|
||||
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
|
||||
return u.get_num_of_access() *
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
|
||||
.get_num_of_access();
|
||||
}
|
||||
template <typename F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
sweep_tile_uspan(
|
||||
spans[number<I>{}],
|
||||
[&](auto... i_idx) {
|
||||
constexpr auto next_span_idx = make_tuple(make_tuple(i_idx)...);
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
|
||||
f, next_span_idx);
|
||||
},
|
||||
get_y_unpacks());
|
||||
}
|
||||
template <typename F, index_t i_access>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, number<i_access>) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto u =
|
||||
static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
|
||||
constexpr auto access_stride =
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
|
||||
.get_num_of_access();
|
||||
constexpr auto curr_i_access = number<i_access / access_stride>{};
|
||||
constexpr auto next_i_access = number<i_access % access_stride>{};
|
||||
u(
|
||||
[&](auto... i_idx) {
|
||||
constexpr auto next_span_idx =
|
||||
make_tuple(make_tuple(detail::make_tile_distributed_index(i_idx))...);
|
||||
sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
|
||||
f, next_span_idx, next_i_access);
|
||||
},
|
||||
curr_i_access);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
|
||||
/*
|
||||
* Enhanced sweep-tile utility, can control unpacks along each X-dim
|
||||
* the lambda function argument is the distributed-idx, which can directly
|
||||
* plugged into the distributed tensor as setter/getter
|
||||
*
|
||||
* e.g. below function, y with the type DistributedTensor, r is row scale
|
||||
*
|
||||
* // sweep tile 1 by 1
|
||||
* sweep_tile<DistributedTensor>([&](auto idx) {
|
||||
* constexpr auto row_id = make_tuple(idx[number<0>{}]);
|
||||
* y(idx) = y(idx) * r(row_id);
|
||||
* });
|
||||
*
|
||||
* // sweep tile with 2 pixel from last dim each function call
|
||||
* sweep_tile<DistributedTensor>(
|
||||
* [&](auto idx_0, auto idx_1) {
|
||||
* constexpr auto row_id = make_tuple(idx_0[number<0>{}]);
|
||||
* y(idx_0) = y(idx_0) * r(row_id);
|
||||
* y(idx_1) = y(idx_1) * r(row_id);
|
||||
* },
|
||||
* sequence<1, 2>{});
|
||||
*
|
||||
* // sweep tile with 2x2 pixel each function call
|
||||
* sweep_tile<DistributedTensor>(
|
||||
* [&](auto idx_00, auto idx_01, auto idx_10, auto idx_11) {
|
||||
* constexpr auto row_id0 = make_tuple(idx_00[number<0>{}]);
|
||||
* constexpr auto row_id1 = make_tuple(idx_10[number<0>{}]);
|
||||
* y(idx_00) = y(idx_00) * r(row_id0);
|
||||
* y(idx_01) = y(idx_01) * r(row_id0);
|
||||
* y(idx_10) = y(idx_10) * r(row_id1);
|
||||
* y(idx_11) = y(idx_11) * r(row_id1);
|
||||
* },
|
||||
* sequence<2, 2>{});
|
||||
*
|
||||
* TODO: do we need constexpr? lambda function could be non-constexpr
|
||||
*/
|
||||
template <typename DistributedTensor,
|
||||
typename F,
|
||||
typename UnpacksPerXDim =
|
||||
typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
|
||||
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F& f, UnpacksPerXDim = {})
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
|
||||
impl::sweep_tile_impl_0<DistributedTensor,
|
||||
UnpacksPerXDim,
|
||||
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(f);
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
typename F,
|
||||
typename UnpacksPerXDim =
|
||||
typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
sweep_tile(const DistributedTensor&, const F& f, UnpacksPerXDim = {})
|
||||
{
|
||||
sweep_tile<DistributedTensor, F, UnpacksPerXDim>(f, UnpacksPerXDim{});
|
||||
}
|
||||
|
||||
/*
|
||||
* construct a sweep tile instance, which support issue the lambda one by one
|
||||
* Note that this struct will hold the lambda functor, but will not hold the distributed tensor
|
||||
* the functionality is the same as sweep_tile()
|
||||
*/
|
||||
template <typename DistributedTensor_,
|
||||
typename F_,
|
||||
typename UnpacksPerXDim_ =
|
||||
typename uniform_sequence_gen<DistributedTensor_::get_num_of_dimension(), 1>::type>
|
||||
struct tile_sweeper
|
||||
{
|
||||
using DistributedTensor = remove_cvref_t<DistributedTensor_>;
|
||||
using F = remove_cvref_t<F_>;
|
||||
using UnpacksPerXDim = remove_cvref_t<UnpacksPerXDim_>;
|
||||
|
||||
CK_TILE_HOST_DEVICE tile_sweeper(const F& f_, UnpacksPerXDim = {}) : f(f_) {}
|
||||
CK_TILE_HOST_DEVICE tile_sweeper(const DistributedTensor&, const F& f_, UnpacksPerXDim = {})
|
||||
: f(f_)
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
constexpr auto tmp =
|
||||
impl::sweep_tile_impl_0<DistributedTensor,
|
||||
UnpacksPerXDim,
|
||||
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{};
|
||||
return tmp.get_num_of_access();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void operator()() const
|
||||
{
|
||||
sweep_tile<DistributedTensor>(f, UnpacksPerXDim{});
|
||||
}
|
||||
|
||||
template <index_t i_access>
|
||||
CK_TILE_HOST_DEVICE void operator()(number<i_access>) const
|
||||
{
|
||||
constexpr auto spans = DistributedTensor::get_distributed_spans();
|
||||
|
||||
impl::sweep_tile_impl_0<DistributedTensor,
|
||||
UnpacksPerXDim,
|
||||
typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(
|
||||
f, number<i_access>{});
|
||||
}
|
||||
F f;
|
||||
};
|
||||
|
||||
// partial deduction is not allowed
|
||||
// template <typename T, typename F, typename U>
|
||||
// CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
|
||||
|
||||
// deduction guide
|
||||
template <typename T,
|
||||
typename F,
|
||||
typename U = typename uniform_sequence_gen<T::get_num_of_dimension(), 1>::type>
|
||||
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {})->tile_sweeper<T, F, U>;
|
||||
|
||||
} // namespace ck_tile
|
||||
945
include/ck_tile/core/tensor/tensor_adaptor.hpp
Normal file
945
include/ck_tile/core/tensor/tensor_adaptor.hpp
Normal file
@@ -0,0 +1,945 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Transforms: Tuple<transforms...>
|
||||
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
|
||||
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
|
||||
// BottomDimensionHiddenIds : Sequence<...>
|
||||
// TopDimensionHiddenIds : Sequence<...>
|
||||
template <typename Transforms,
|
||||
typename LowerDimensionHiddenIdss,
|
||||
typename UpperDimensionHiddenIdss,
|
||||
typename BottomDimensionHiddenIds,
|
||||
typename TopDimensionHiddenIds>
|
||||
struct tensor_adaptor
|
||||
{
|
||||
// Number of transforms in this adaptor (size of the Transforms tuple).
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_transform()
{
    constexpr index_t num_transforms = Transforms::size();
    return num_transforms;
}
|
||||
|
||||
// Read-only access to the stored tuple of transforms.
CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const
{
    return transforms_;
}
|
||||
|
||||
// Per-transform lower-dimension hidden ids (an instance of LowerDimensionHiddenIdss).
CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
{
    constexpr auto lower_idss = LowerDimensionHiddenIdss{};
    return lower_idss;
}
|
||||
|
||||
// Per-transform upper-dimension hidden ids (an instance of UpperDimensionHiddenIdss).
CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
{
    constexpr auto upper_idss = UpperDimensionHiddenIdss{};
    return upper_idss;
}
|
||||
|
||||
// Hidden ids of the bottom dimensions (an instance of BottomDimensionHiddenIds).
CK_TILE_HOST_DEVICE static constexpr auto get_bottom_dimension_hidden_ids()
{
    constexpr auto bottom_ids = BottomDimensionHiddenIds{};
    return bottom_ids;
}
|
||||
|
||||
// Hidden ids of the top dimensions (an instance of TopDimensionHiddenIds).
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
{
    constexpr auto top_ids = TopDimensionHiddenIds{};
    return top_ids;
}
|
||||
|
||||
// Compute the element size of the adaptor: the product of all top-dimension lengths.
// For each top dimension, the transform producing it and its position among that
// transform's upper dimensions are located at compile time, and the length is read
// from that transform's upper lengths.
CK_TILE_HOST_DEVICE static constexpr auto initialize_element_size(const Transforms& transforms)
{
    const auto top_lengths = generate_tuple(
        [&](auto idim_top) {
            constexpr index_t hidden_id = TopDimensionHiddenIds::at(idim_top);

            constexpr auto lookup = get_transform_and_its_upper_dimension(number<hidden_id>{});

            constexpr index_t transform_id = lookup[number<0>{}];
            constexpr index_t upper_dim_id = lookup[number<1>{}];
            constexpr bool is_found        = lookup[number<2>{}];

            static_assert(is_found,
                          "wrong! not found matching transformation and upper-dimension");

            const auto top_length =
                transforms[number<transform_id>{}].get_upper_lengths()[number<upper_dim_id>{}];

            return top_length;
        },
        number<ndim_top_>{});

    // TODO: make container_reduce support tuple of number and index_t
    return container_reduce(top_lengths, multiplies{}, number<1>{});
}
|
||||
|
||||
template <index_t IDimHidden>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
get_transform_and_its_upper_dimension(number<IDimHidden>)
|
||||
{
|
||||
// FIXME: length of bottom dimension is not known, since info about lower dim length are not
|
||||
// saved in transformation
|
||||
static_assert(IDimHidden >= ndim_bottom_, "wrong! not implemented");
|
||||
|
||||
index_t itran_found = 0;
|
||||
index_t idim_up_found = 0;
|
||||
bool found = false;
|
||||
|
||||
static_for<0, ntransform_, 1>{}([&](auto itran) {
|
||||
constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
|
||||
|
||||
static_for<0, up_dim_ids.size(), 1>{}([&](auto idim_up) {
|
||||
if constexpr(up_dim_ids[idim_up] == IDimHidden)
|
||||
{
|
||||
itran_found = itran;
|
||||
idim_up_found = idim_up;
|
||||
found = true;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return make_tuple(itran_found, idim_up_found, found);
|
||||
}
|
||||
|
||||
// Number of bottom dimensions (size of BottomDimensionHiddenIds).
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_bottom_dimension()
{
    constexpr index_t num_bottom_dims = BottomDimensionHiddenIds::size();
    return num_bottom_dims;
}
|
||||
|
||||
// Number of top dimensions (size of TopDimensionHiddenIds).
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_top_dimension()
{
    constexpr index_t num_top_dims = TopDimensionHiddenIds::size();
    return num_top_dims;
}
|
||||
|
||||
// Number of distinct hidden dimensions: all lower and upper hidden ids of every
// transform are merged into one sequence, sorted with duplicates removed, and the
// size of the result is returned.
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension()
{
    // flattens a tuple of sequences into one sequence
    constexpr auto merge_all = [](auto&&... xs) constexpr { return merge_sequences(xs...); };

    constexpr auto lower_ids = unpack(merge_all, LowerDimensionHiddenIdss{});
    constexpr auto upper_ids = unpack(merge_all, UpperDimensionHiddenIdss{});

    constexpr auto every_id = merge_sequences(lower_ids, upper_ids);

    using distinct_ids =
        typename sequence_unique_sort<decltype(every_id), less<index_t>, equal<index_t>>::type;

    return distinct_ids::size();
}
|
||||
|
||||
constexpr static index_t ntransform_ = get_num_of_transform();
|
||||
constexpr static index_t ndim_hidden_ = get_num_of_hidden_dimension();
|
||||
constexpr static index_t ndim_bottom_ = get_num_of_bottom_dimension();
|
||||
constexpr static index_t ndim_top_ = get_num_of_top_dimension();
|
||||
|
||||
using HiddenIndex = multi_index<ndim_hidden_>;
|
||||
using BottomIndex = multi_index<ndim_bottom_>;
|
||||
using TopIndex = multi_index<ndim_top_>;
|
||||
|
||||
// may be index_t or number<>
|
||||
using ElementSize = remove_cv_t<decltype(initialize_element_size(Transforms{}))>;
|
||||
|
||||
public:
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_adaptor() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_adaptor(const Transforms& transforms)
|
||||
: transforms_{transforms}, element_size_{initialize_element_size(transforms)}
|
||||
{
|
||||
static_assert(Transforms::size() == ntransform_ &&
|
||||
LowerDimensionHiddenIdss::size() == ntransform_ &&
|
||||
UpperDimensionHiddenIdss::size() == ntransform_,
|
||||
"wrong! inconsistent # of transformations");
|
||||
|
||||
// TODO check dependency of dimensions is valid
|
||||
}
|
||||
|
||||
// Element size stored at construction (product of the top-dimension lengths).
CK_TILE_HOST_DEVICE constexpr auto get_element_size() const
{
    return element_size_;
}
|
||||
|
||||
// FIXME: this logic is wrong when getting bottome dimension lengths
|
||||
template <index_t IDimHidden>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_hidden_dimension_length(number<IDimHidden>) const
|
||||
{
|
||||
static_assert(IDimHidden >= 0 && IDimHidden < ndim_hidden_, "wrong! out of range");
|
||||
|
||||
constexpr auto tmp = get_transform_and_its_upper_dimension(number<IDimHidden>{});
|
||||
|
||||
constexpr index_t itran = tmp[number<0>{}];
|
||||
constexpr index_t idim_up = tmp[number<1>{}];
|
||||
constexpr bool found = tmp[number<2>{}];
|
||||
|
||||
static_assert(found == true,
|
||||
"wrong! not found matching transformation and upper-dimension");
|
||||
|
||||
return transforms_[number<itran>{}].get_upper_lengths()[number<idim_up>{}];
|
||||
}
|
||||
|
||||
template <index_t IDimTop>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_length(number<IDimTop> idim_top) const
|
||||
{
|
||||
return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_top));
|
||||
}
|
||||
|
||||
#if 0
// FIXME: get_hidden_dimension_length is wrong when getting bottom dimension lengths,
// which is why this accessor stays disabled.
//
// Length of bottom dimension idim_bottom, looked up through its hidden dimension id.
template <index_t IDimBottom>
CK_TILE_HOST_DEVICE constexpr index_t
get_bottom_dimension_length(number<IDimBottom> idim_bottom) const
{
    // a bottom-dimension index must be mapped through BottomDimensionHiddenIds; the
    // previous code indexed TopDimensionHiddenIds here (copy-paste from the top getter)
    return get_hidden_dimension_length(BottomDimensionHiddenIds::at(idim_bottom));
}
#endif
|
||||
|
||||
// Tuple with the length of every top dimension, in top-dimension order.
CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_lengths() const
{
    return generate_tuple([&](auto idim_top) { return get_top_dimension_length(idim_top); },
                          number<ndim_top_>{});
}
|
||||
|
||||
#if 0
// FIXME: get_hidden_dimension_length is wrong when getting bottom dimension lengths,
// so this accessor is disabled.
//
// Tuple with the length of every bottom dimension, in bottom-dimension order.
CK_TILE_HOST_DEVICE constexpr auto GetBottomDimensionLengths() const
{
    return generate_tuple(
        [&](auto idim_bottom) { return get_bottom_dimension_length(idim_bottom); },
        number<ndim_bottom_>{});
}
#endif
|
||||
|
||||
template <typename TopIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr auto calculate_bottom_index(const TopIdx& idx_top) const
|
||||
{
|
||||
static_assert(TopIdx::size() == TopDimensionHiddenIds::size(),
|
||||
"wrong! # of dimension inconsistent");
|
||||
|
||||
constexpr index_t ntransform = get_num_of_transform();
|
||||
constexpr index_t ndim_hidden = get_num_of_hidden_dimension();
|
||||
|
||||
multi_index<ndim_hidden> idx_hidden;
|
||||
|
||||
// initialize uppest index
|
||||
set_container_subset(idx_hidden, get_top_dimension_hidden_ids(), idx_top);
|
||||
|
||||
// calculate hidden index
|
||||
static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
|
||||
auto itran = itran_p1 - number<1>{};
|
||||
const auto& tran = get_transforms().at(itran);
|
||||
constexpr auto dims_low = get_lower_dimension_hidden_idss().at(itran);
|
||||
constexpr auto dims_up = get_upper_dimension_hidden_idss().at(itran);
|
||||
|
||||
const auto idx_up = get_container_subset(idx_hidden, dims_up);
|
||||
|
||||
multi_index<dims_low.size()> idx_low;
|
||||
|
||||
tran.calculate_lower_index(idx_low, idx_up);
|
||||
|
||||
set_container_subset(idx_hidden, dims_low, idx_low);
|
||||
});
|
||||
|
||||
return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
|
||||
}
|
||||
|
||||
// True when every transform type reports is_known_at_compile_time() and the
// ElementSize type is known at compile time as well.
CK_TILE_HOST_DEVICE static constexpr bool is_static()
{
    bool all_transforms_static = true;

    static_for<0, Transforms::size(), 1>{}([&](auto itran) {
        all_transforms_static &=
            remove_cvref_t<decltype(Transforms{}[itran])>::is_known_at_compile_time();
    });

    // qualified with ck_tile:: to name the trait rather than this struct's member function
    return all_transforms_static && ck_tile::is_known_at_compile_time<ElementSize>::value;
}
|
||||
|
||||
// Alias of is_static().
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
    return is_static();
}
|
||||
|
||||
// Propagate vector lengths and strides through all transforms, from first to last,
// and return those of the top dimensions as a tuple (lengths, strides).
//
//   guaranteed_vector_lengths : per hidden dimension; an entry other than -1
//                               overrides the value computed by the transform
//   guaranteed_vector_strides : same convention, for strides
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides(
    const array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
    const array<index_t, ndim_hidden_>& guaranteed_vector_strides)
{
    auto hidden_lengths = guaranteed_vector_lengths;
    auto hidden_strides = guaranteed_vector_strides;

    static_for<0, get_num_of_transform(), 1>{}([&](auto itran) {
        constexpr auto lower_ids = get_lower_dimension_hidden_idss().at(itran);
        constexpr auto upper_ids = get_upper_dimension_hidden_idss().at(itran);

        const auto upper_guaranteed_lengths =
            get_container_subset(guaranteed_vector_lengths, upper_ids);
        const auto upper_guaranteed_strides =
            get_container_subset(guaranteed_vector_strides, upper_ids);

        // only need type of transform
        auto [upper_lengths, upper_strides] =
            Transforms{}.at(itran).calculate_upper_dimension_safe_vector_length_strides(
                get_container_subset(hidden_lengths, lower_ids),
                get_container_subset(hidden_strides, lower_ids));

        if constexpr(upper_ids.size() > 0)
        {
            for(index_t i = 0; i < upper_ids.size(); ++i)
            {
                // a guaranteed value (anything but -1) wins over the computed one
                if(upper_guaranteed_lengths[i] != -1)
                {
                    upper_lengths(i) = upper_guaranteed_lengths[i];
                }

                if(upper_guaranteed_strides[i] != -1)
                {
                    upper_strides(i) = upper_guaranteed_strides[i];
                }
            }
        }

        set_container_subset(hidden_lengths, upper_ids, upper_lengths);
        set_container_subset(hidden_strides, upper_ids, upper_strides);
    });

    constexpr auto top_ids = TopDimensionHiddenIds{};

    return make_tuple(get_container_subset(hidden_lengths, top_ids),
                      get_container_subset(hidden_strides, top_ids));
}
|
||||
|
||||
// Writes a one-line description of this adaptor with printf: the transforms first, then the
// lower / upper / bottom / top hidden-dimension id lists, all wrapped in "tensor_adaptor{...}".
// NOTE(review): inside this member, unqualified lookup of `print` finds this zero-argument
// member first -- confirm the one-argument calls below resolve to the intended print overloads.
CK_TILE_HOST_DEVICE void print() const
{
    printf("tensor_adaptor{");

    // the transforms held by this object
    printf("transforms: ");
    print(transforms_);
    printf(", ");

    // per-transform lower dimension ids
    printf("LowerDimensionHiddenIds: ");
    print(LowerDimensionHiddenIdss{});
    printf(", ");

    // per-transform upper dimension ids
    printf("UpperDimensionHiddenIds: ");
    print(UpperDimensionHiddenIdss{});
    printf(", ");

    // ids of the bottom dimensions
    printf("BottomDimensionHiddenIds: ");
    print(BottomDimensionHiddenIds{});
    printf(", ");

    // ids of the top dimensions
    printf("TopDimensionHiddenIds: ");
    print(TopDimensionHiddenIds{});

    printf("}");
}
|
||||
|
||||
private:
|
||||
Transforms transforms_;
|
||||
ElementSize element_size_;
|
||||
};
|
||||
|
||||
// Transforms: Tuple<transforms...>
|
||||
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
|
||||
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
|
||||
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
|
||||
LowerDimensionOldTopIdss,
|
||||
UpperDimensionNewTopIdss)
|
||||
{
|
||||
constexpr index_t ntransform = Transforms::size();
|
||||
|
||||
static_assert(LowerDimensionOldTopIdss::size() == ntransform &&
|
||||
UpperDimensionNewTopIdss::size() == ntransform,
|
||||
"wrong!");
|
||||
|
||||
// sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
|
||||
constexpr auto all_low_dim_old_top_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});
|
||||
|
||||
constexpr auto all_up_dim_new_top_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
|
||||
|
||||
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
|
||||
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size();
|
||||
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size();
|
||||
|
||||
// low_dim_hidden_idss
|
||||
constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};
|
||||
|
||||
// up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
|
||||
constexpr auto up_dim_hidden_idss = generate_tuple(
|
||||
[](auto itran) { return UpperDimensionNewTopIdss{}[itran] + number<ndim_old_top>{}; },
|
||||
number<ntransform>{});
|
||||
|
||||
// bottom_dim_hidden_ids
|
||||
constexpr auto bottom_dim_hidden_ids =
|
||||
typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};
|
||||
|
||||
// top_dim_hidden_ids
|
||||
constexpr auto top_dim_hidden_ids =
|
||||
typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + number<ndim_old_top>{};
|
||||
|
||||
return tensor_adaptor<remove_cvref_t<Transforms>,
|
||||
remove_cvref_t<decltype(low_dim_hidden_idss)>,
|
||||
remove_cvref_t<decltype(up_dim_hidden_idss)>,
|
||||
remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
|
||||
remove_cvref_t<decltype(top_dim_hidden_ids)>>{transforms};
|
||||
}
|
||||
|
||||
// TODO: How to fix this? It uses an struct instead of lambda because lambda
|
||||
// doesn't have constructor, and to put it outside the scope where it is used
|
||||
// (transform_tensor_adaptor) because template cannot be defined inside a function
|
||||
// template
|
||||
template <typename NewTransforms>
|
||||
struct lambda_get_up_dim_num
|
||||
{
|
||||
template <typename I>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(I) const
|
||||
{
|
||||
using Tran = remove_reference_t<decltype(NewTransforms{}.at(I{}))>;
|
||||
return number<Tran::get_num_of_upper_dimension()>{};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OldTensorAdaptor,
|
||||
typename NewTransforms,
|
||||
typename NewLowerDimensionOldTopIdss,
|
||||
typename NewUpperDimensionNewTopIdss>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
|
||||
const NewTransforms& new_transforms,
|
||||
NewLowerDimensionOldTopIdss,
|
||||
NewUpperDimensionNewTopIdss)
|
||||
{
|
||||
// sanity check
|
||||
{
|
||||
static_assert(NewTransforms::size() == NewLowerDimensionOldTopIdss::size() &&
|
||||
NewTransforms::size() == NewUpperDimensionNewTopIdss::size(),
|
||||
"wrong! inconsitent number of transform");
|
||||
|
||||
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
|
||||
NewLowerDimensionOldTopIdss{});
|
||||
|
||||
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
|
||||
NewUpperDimensionNewTopIdss{});
|
||||
|
||||
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
|
||||
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
// lower dimension's hidden idss
|
||||
// convert lower dimension top idss (tuple of sequences) to hidden idss (tuple of
|
||||
// sequences)
|
||||
constexpr auto low_dim_hidden_idss = transform_tuples(
|
||||
// convert lower dimension top ids (a sequence) to hidden ids (a sequence)
|
||||
[](auto low_dim_top_ids) constexpr {
|
||||
return transform_sequences(
|
||||
// convert lower dimension top id to hidden id
|
||||
[](auto low_dim_top_id) constexpr {
|
||||
return OldTensorAdaptor::get_top_dimension_hidden_ids()[low_dim_top_id];
|
||||
},
|
||||
low_dim_top_ids);
|
||||
},
|
||||
NewLowerDimensionOldTopIdss{});
|
||||
|
||||
constexpr index_t num_new_transform = NewTransforms::size();
|
||||
|
||||
// upper dimension's hidden idss
|
||||
constexpr index_t old_hidden_dim_number = OldTensorAdaptor::get_num_of_hidden_dimension();
|
||||
|
||||
constexpr auto up_dim_numbers =
|
||||
generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, number<num_new_transform>{});
|
||||
|
||||
constexpr auto up_dim_numbers_scan = merge_sequences(
|
||||
sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));
|
||||
|
||||
constexpr auto up_dim_hidden_idss = generate_tuple(
|
||||
[ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
|
||||
return
|
||||
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
|
||||
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
|
||||
1>::type{};
|
||||
},
|
||||
number<num_new_transform>{});
|
||||
|
||||
// new top dimension's hidden ids
|
||||
constexpr auto unordered_new_top_dim_hidden_ids = unpack(
|
||||
[](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
|
||||
|
||||
constexpr auto new_top_dim_unordered2ordered = unpack(
|
||||
[](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});
|
||||
|
||||
constexpr auto new_top_dim_hidden_ids =
|
||||
unordered_new_top_dim_hidden_ids.reorder_old_to_new(new_top_dim_unordered2ordered);
|
||||
|
||||
// put everything together
|
||||
const auto all_transforms =
|
||||
container_concat(old_tensor_adaptor.get_transforms(), new_transforms);
|
||||
|
||||
constexpr auto all_low_dim_hidden_idss =
|
||||
container_concat(OldTensorAdaptor::get_lower_dimension_hidden_idss(), low_dim_hidden_idss);
|
||||
|
||||
constexpr auto all_up_dim_hidden_idss =
|
||||
container_concat(OldTensorAdaptor::get_upper_dimension_hidden_idss(), up_dim_hidden_idss);
|
||||
|
||||
return tensor_adaptor<
|
||||
remove_cvref_t<decltype(all_transforms)>,
|
||||
remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
|
||||
remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
|
||||
remove_cvref_t<decltype(OldTensorAdaptor::get_bottom_dimension_hidden_ids())>,
|
||||
remove_cvref_t<decltype(new_top_dim_hidden_ids)>>{all_transforms};
|
||||
}
|
||||
|
||||
template <typename TensorAdaptor0, typename TensorAdaptor1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
|
||||
const TensorAdaptor1& adaptor1)
|
||||
{
|
||||
static_assert(TensorAdaptor0::get_num_of_top_dimension() ==
|
||||
TensorAdaptor1::get_num_of_bottom_dimension(),
|
||||
"wrong!");
|
||||
|
||||
// all_transforms = transform0 + transform1
|
||||
const auto all_transforms =
|
||||
container_concat(adaptor0.get_transforms(), adaptor1.get_transforms());
|
||||
|
||||
// shift
|
||||
constexpr index_t adaptor0_max_hidden_id = [&]() {
|
||||
index_t adaptor0_max_hidden_id_ = numeric<index_t>::min();
|
||||
|
||||
static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) {
|
||||
constexpr index_t ndim_low =
|
||||
TensorAdaptor0{}.get_transforms()[itran].get_num_of_lower_dimension();
|
||||
|
||||
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
|
||||
adaptor0_max_hidden_id_ =
|
||||
max(adaptor0_max_hidden_id_,
|
||||
TensorAdaptor0::get_lower_dimension_hidden_idss()[itran][idim_low].value);
|
||||
});
|
||||
|
||||
constexpr index_t ndim_up =
|
||||
TensorAdaptor0{}.get_transforms()[itran].get_num_of_upper_dimension();
|
||||
|
||||
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
|
||||
adaptor0_max_hidden_id_ =
|
||||
max(adaptor0_max_hidden_id_,
|
||||
TensorAdaptor0::get_upper_dimension_hidden_idss()[itran][idim_up].value);
|
||||
});
|
||||
});
|
||||
|
||||
return adaptor0_max_hidden_id_;
|
||||
}();
|
||||
|
||||
constexpr index_t adaptor1_min_hidden_id = [&]() {
|
||||
index_t adaptor1_min_hidden_id_ = numeric<index_t>::max();
|
||||
|
||||
static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) {
|
||||
constexpr index_t ndim_low =
|
||||
TensorAdaptor1{}.get_transforms()[itran].get_num_of_lower_dimension();
|
||||
|
||||
// get the min of all lower dimenions, but not bottom dimension (because their id will
|
||||
// be matched with top id from adaptor0)
|
||||
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
|
||||
constexpr index_t low_dim_hidden_id =
|
||||
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran][idim_low].value;
|
||||
|
||||
bool is_bottom_dim = false;
|
||||
static_for<0, TensorAdaptor1::get_num_of_bottom_dimension(), 1>{}([&](auto i) {
|
||||
if constexpr(low_dim_hidden_id ==
|
||||
TensorAdaptor1::get_bottom_dimension_hidden_ids()[i])
|
||||
{
|
||||
is_bottom_dim = true;
|
||||
}
|
||||
});
|
||||
|
||||
if(!is_bottom_dim)
|
||||
{
|
||||
adaptor1_min_hidden_id_ = min(adaptor1_min_hidden_id_, low_dim_hidden_id);
|
||||
}
|
||||
});
|
||||
|
||||
constexpr index_t ndim_up =
|
||||
TensorAdaptor1{}.get_transforms()[itran].get_num_of_upper_dimension();
|
||||
|
||||
// get the min of all upper dimensions
|
||||
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
|
||||
adaptor1_min_hidden_id_ =
|
||||
min(adaptor1_min_hidden_id_,
|
||||
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran][idim_up].value);
|
||||
});
|
||||
});
|
||||
|
||||
return adaptor1_min_hidden_id_;
|
||||
}();
|
||||
|
||||
constexpr index_t adaptor1_hidden_id_shift =
|
||||
adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;
|
||||
|
||||
constexpr index_t ndim_bottom_1 = TensorAdaptor1::get_num_of_bottom_dimension();
|
||||
|
||||
// all_low_dim_hidden_idss =
|
||||
// low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hiden_idss_1))
|
||||
constexpr auto low_dim_hidden_idss_1 = generate_tuple(
|
||||
// generate sequence of ids for a transform
|
||||
[&](auto itran) {
|
||||
constexpr auto ndim_low_1 =
|
||||
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran].size();
|
||||
|
||||
constexpr auto low_dim_hidden_ids_1 =
|
||||
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran];
|
||||
|
||||
// sequence in, sequence out
|
||||
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
|
||||
{
|
||||
auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
|
||||
|
||||
// shift hidden id so every dim id is unique
|
||||
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
|
||||
low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
|
||||
});
|
||||
|
||||
// match hidden id
|
||||
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
|
||||
static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
|
||||
// if this low dim is bottom dim, then do id matching
|
||||
if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
|
||||
TensorAdaptor1::get_bottom_dimension_hidden_ids()
|
||||
[idim_bottom_1])
|
||||
{
|
||||
low_dim_hidden_ids_1_mod_(idim_low_1) =
|
||||
TensorAdaptor0::get_top_dimension_hidden_ids()[idim_bottom_1];
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return low_dim_hidden_ids_1_mod_;
|
||||
}
|
||||
();
|
||||
|
||||
return generate_sequence_v2(
|
||||
[&](auto i) constexpr { return number<low_dim_hidden_ids_1_mod[i]>{}; },
|
||||
number<ndim_low_1>{});
|
||||
},
|
||||
number<TensorAdaptor1::get_num_of_transform()>{});
|
||||
|
||||
constexpr auto all_low_dim_hidden_idss =
|
||||
container_concat(TensorAdaptor0::get_lower_dimension_hidden_idss(), low_dim_hidden_idss_1);
|
||||
|
||||
// all_up_dim_hidden_idss =
|
||||
// up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hiden_idss_1)
|
||||
constexpr auto up_dim_hidden_idss_1 = generate_tuple(
|
||||
// generate sequence of ids for a transform
|
||||
[&](auto itran) {
|
||||
constexpr auto ndim_up_1 =
|
||||
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran].size();
|
||||
|
||||
constexpr auto up_dim_hidden_ids_1 =
|
||||
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran];
|
||||
|
||||
// sequence in, constexpr tuple out
|
||||
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
|
||||
{
|
||||
auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
|
||||
|
||||
// shift hidden id
|
||||
static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
|
||||
up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
|
||||
});
|
||||
|
||||
return up_dim_hidden_ids_1_mod_;
|
||||
}
|
||||
();
|
||||
|
||||
// constexpr tuple to sequence
|
||||
return generate_sequence_v2(
|
||||
[&](auto i) constexpr { return number<up_dim_hidden_ids_1_mod[i]>{}; },
|
||||
number<ndim_up_1>{});
|
||||
},
|
||||
number<TensorAdaptor1::get_num_of_transform()>{});
|
||||
|
||||
constexpr auto all_up_dim_hidden_idss =
|
||||
container_concat(TensorAdaptor0::get_upper_dimension_hidden_idss(), up_dim_hidden_idss_1);
|
||||
|
||||
// bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
|
||||
constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::get_bottom_dimension_hidden_ids();
|
||||
|
||||
// top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
|
||||
constexpr auto top_dim_hidden_ids =
|
||||
TensorAdaptor1::get_top_dimension_hidden_ids() + number<adaptor1_hidden_id_shift>{};
|
||||
|
||||
// put everything together
|
||||
return tensor_adaptor<remove_cvref_t<decltype(all_transforms)>,
|
||||
remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
|
||||
remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
|
||||
remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
|
||||
remove_cvref_t<decltype(top_dim_hidden_ids)>>{all_transforms};
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename... Xs,
|
||||
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
|
||||
{
|
||||
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
// Macro function
// Constructs a constexpr tensor_adaptor from a constexpr encoding. Transform parameters are
// decoded from the meta data as plain index_t / array<index_t, N> values and handed to the
// make_*_transform factories.
//
// encoded_tensor_adaptor is a tuple of the following objects (tuple index in brackets):
//   [0] encoded transforms (array of fixed size). Each encoded transform is a tuple of:
//       [0] name (coord_transform_enum)
//       [1] meta data for the constructor of the transform
//       [2] num of lower dimensions (index_t)
//       [3] lower dimension ids (array of fixed size)
//       [4] num of upper dimensions (index_t)
//       [5] upper dimension ids (array of fixed size)
//   [1] num of transforms (index_t)
//   [2] encoded bottom dimension ids (array of fixed size)
//   [3] num of bottom dimensions (index_t)
//   [4] encoded top dimension ids (array of fixed size)
//   [5] num of top dimensions (index_t)
//
// Supported transform names: pass_through, pad, embed, merge, unmerge, replicate; anything else
// trips the static_assert. meta_data.pop<T>(pos) is called repeatedly with the same `pos`, so
// the fields of a transform are read in the order in which they are popped below.
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \
    [encoded_tensor_adaptor]() { \
        using namespace ck_tile; \
        \
        constexpr auto encoded_transforms  = encoded_tensor_adaptor.template at<0>(); \
        constexpr index_t num_transform    = encoded_tensor_adaptor.template at<1>(); \
        constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \
        constexpr index_t num_bottom_dim   = encoded_tensor_adaptor.template at<3>(); \
        constexpr auto encoded_top_dims    = encoded_tensor_adaptor.template at<4>(); \
        constexpr index_t num_top_dim      = encoded_tensor_adaptor.template at<5>(); \
        \
        /* decode every transform */ \
        constexpr auto trans = [&encoded_transforms]() { \
            return generate_tuple( \
                [&encoded_transforms](auto i) constexpr { \
                    constexpr auto name        = encoded_transforms[i].template at<0>(); \
                    constexpr auto meta_data   = encoded_transforms[i].template at<1>(); \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
                    constexpr auto num_up_dim  = encoded_transforms[i].template at<4>(); \
                    \
                    static_assert(name == coord_transform_enum::pass_through || \
                                      name == coord_transform_enum::pad || \
                                      name == coord_transform_enum::embed || \
                                      name == coord_transform_enum::merge || \
                                      name == coord_transform_enum::unmerge || \
                                      name == coord_transform_enum::replicate, \
                                  ""); \
                    \
                    if constexpr(name == coord_transform_enum::pass_through) \
                    { \
                        index_t pos  = 0; \
                        auto low_len = meta_data.template pop<index_t>(pos); \
                        \
                        return make_pass_through_transform(low_len); \
                    } \
                    else if constexpr(name == coord_transform_enum::pad) \
                    { \
                        index_t pos    = 0; \
                        auto low_len   = meta_data.template pop<index_t>(pos); \
                        auto left_pad  = meta_data.template pop<index_t>(pos); \
                        auto right_pad = meta_data.template pop<index_t>(pos); \
                        \
                        return make_pad_transform(low_len, left_pad, right_pad); \
                    } \
                    else if constexpr(name == coord_transform_enum::embed) \
                    { \
                        index_t pos  = 0; \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
                        auto coefficients = \
                            meta_data.template pop<array<index_t, num_up_dim>>(pos); \
                        \
                        return make_embed_transform(up_lens, coefficients); \
                    } \
                    else if constexpr(name == coord_transform_enum::merge) \
                    { \
                        index_t pos   = 0; \
                        auto low_lens = meta_data.template pop<array<index_t, num_low_dim>>(pos); \
                        \
                        return make_merge_transform(low_lens); \
                    } \
                    else if constexpr(name == coord_transform_enum::unmerge) \
                    { \
                        index_t pos  = 0; \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
                        \
                        return make_unmerge_transform(up_lens); \
                    } \
                    else if constexpr(name == coord_transform_enum::replicate) \
                    { \
                        index_t pos  = 0; \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
                        \
                        return make_replicate_transform(up_lens); \
                    } \
                }, \
                number<num_transform>{}); \
        }(); \
        \
        /* num_transform is used only as a template argument, so it needs no capture */ \
        constexpr auto low_dim_idss = [&encoded_transforms]() { \
            return generate_tuple( \
                [&encoded_transforms](auto i) { \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
                    constexpr auto low_dims    = encoded_transforms[i].template at<3>(); \
                    \
                    return TO_SEQUENCE(low_dims, num_low_dim); \
                }, \
                number<num_transform>{}); \
        }(); \
        \
        constexpr auto up_dim_idss = [&encoded_transforms]() { \
            return generate_tuple( \
                [&encoded_transforms](auto i) { \
                    constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
                    constexpr auto up_dims    = encoded_transforms[i].template at<5>(); \
                    \
                    return TO_SEQUENCE(up_dims, num_up_dim); \
                }, \
                number<num_transform>{}); \
        }(); \
        \
        constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \
        constexpr auto top_dim_ids    = TO_SEQUENCE(encoded_top_dims, num_top_dim); \
        \
        return tensor_adaptor<remove_cvref_t<decltype(trans)>, \
                              remove_cvref_t<decltype(low_dim_idss)>, \
                              remove_cvref_t<decltype(up_dim_idss)>, \
                              remove_cvref_t<decltype(bottom_dim_ids)>, \
                              remove_cvref_t<decltype(top_dim_ids)>>{trans}; \
    }()
|
||||
|
||||
// Macro function
// Constructs a static tensor_adaptor from a constexpr encoding. Unlike
// CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING, every decoded transform parameter is wrapped into
// number<> (or a tuple of number<> via TO_TUPLE_OF_NUMBER) before being handed to the
// make_*_transform factories.
//
// encoded_tensor_adaptor is a tuple of the following objects (tuple index in brackets):
//   [0] encoded transforms (array of fixed size). Each encoded transform is a tuple of:
//       [0] name (coord_transform_enum)
//       [1] meta data for the constructor of the transform
//       [2] num of lower dimensions (index_t)
//       [3] lower dimension ids (array of fixed size)
//       [4] num of upper dimensions (index_t)
//       [5] upper dimension ids (array of fixed size)
//   [1] num of transforms (index_t)
//   [2] encoded bottom dimension ids (array of fixed size)
//   [3] num of bottom dimensions (index_t)
//   [4] encoded top dimension ids (array of fixed size)
//   [5] num of top dimensions (index_t)
//
// Supported transform names: pass_through, pad, embed, merge, unmerge, replicate; anything else
// trips the static_assert. Fields are read with meta_data.get<T>(offset), where the offset of a
// field is the sum of the sizeof() of the fields stored before it.
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \
    [encoded_tensor_adaptor]() { \
        using namespace ck_tile; \
        \
        constexpr auto encoded_transforms  = encoded_tensor_adaptor.template at<0>(); \
        constexpr index_t num_transform    = encoded_tensor_adaptor.template at<1>(); \
        constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \
        constexpr index_t num_bottom_dim   = encoded_tensor_adaptor.template at<3>(); \
        constexpr auto encoded_top_dims    = encoded_tensor_adaptor.template at<4>(); \
        constexpr index_t num_top_dim      = encoded_tensor_adaptor.template at<5>(); \
        \
        /* decode every transform */ \
        constexpr auto trans = [&encoded_transforms]() { \
            return generate_tuple( \
                [&encoded_transforms](auto i) constexpr { \
                    constexpr auto name        = encoded_transforms[i].template at<0>(); \
                    constexpr auto meta_data   = encoded_transforms[i].template at<1>(); \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
                    constexpr auto num_up_dim  = encoded_transforms[i].template at<4>(); \
                    \
                    static_assert(name == coord_transform_enum::pass_through || \
                                      name == coord_transform_enum::pad || \
                                      name == coord_transform_enum::embed || \
                                      name == coord_transform_enum::merge || \
                                      name == coord_transform_enum::unmerge || \
                                      name == coord_transform_enum::replicate, \
                                  ""); \
                    \
                    if constexpr(name == coord_transform_enum::pass_through) \
                    { \
                        constexpr index_t low_len = meta_data.template get<index_t>(0); \
                        \
                        return make_pass_through_transform(number<low_len>{}); \
                    } \
                    else if constexpr(name == coord_transform_enum::pad) \
                    { \
                        constexpr index_t low_len = meta_data.template get<index_t>(0); \
                        \
                        constexpr index_t left_pad = \
                            meta_data.template get<index_t>(sizeof(low_len)); \
                        \
                        /* get, not pop: the offset is a temporary, pop needs a mutable pos */ \
                        constexpr index_t right_pad = \
                            meta_data.template get<index_t>(sizeof(low_len) + sizeof(left_pad)); \
                        \
                        return make_pad_transform( \
                            number<low_len>{}, number<left_pad>{}, number<right_pad>{}); \
                    } \
                    else if constexpr(name == coord_transform_enum::embed) \
                    { \
                        constexpr auto up_lens = \
                            meta_data.template get<array<index_t, num_up_dim>>(0); \
                        \
                        constexpr auto coefficients = \
                            meta_data.template get<array<index_t, num_up_dim>>(sizeof(up_lens)); \
                        \
                        return make_embed_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim), \
                                                    TO_TUPLE_OF_NUMBER(coefficients, num_up_dim)); \
                    } \
                    else if constexpr(name == coord_transform_enum::merge) \
                    { \
                        constexpr auto low_lens = \
                            meta_data.template get<array<index_t, num_low_dim>>(0); \
                        \
                        return make_merge_transform(TO_TUPLE_OF_NUMBER(low_lens, num_low_dim)); \
                    } \
                    else if constexpr(name == coord_transform_enum::unmerge) \
                    { \
                        constexpr auto up_lens = \
                            meta_data.template get<array<index_t, num_up_dim>>(0); \
                        \
                        return make_unmerge_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
                    } \
                    else if constexpr(name == coord_transform_enum::replicate) \
                    { \
                        constexpr auto up_lens = \
                            meta_data.template get<array<index_t, num_up_dim>>(0); \
                        \
                        return make_replicate_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
                    } \
                }, \
                number<num_transform>{}); \
        }(); \
        \
        constexpr auto low_dim_idss = [&encoded_transforms]() { \
            return generate_tuple( \
                [&encoded_transforms](auto i) { \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
                    constexpr auto low_dims    = encoded_transforms[i].template at<3>(); \
                    \
                    return TO_SEQUENCE(low_dims, num_low_dim); \
                }, \
                number<num_transform>()); \
        }(); \
        \
        constexpr auto up_dim_idss = [&encoded_transforms] { \
            return generate_tuple( \
                [&encoded_transforms](auto i) { \
                    constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
                    constexpr auto up_dims    = encoded_transforms[i].template at<5>(); \
                    \
                    return TO_SEQUENCE(up_dims, num_up_dim); \
                }, \
                number<num_transform>()); \
        }(); \
        \
        constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \
        constexpr auto top_dim_ids    = TO_SEQUENCE(encoded_top_dims, num_top_dim); \
        \
        return tensor_adaptor<remove_cvref_t<decltype(trans)>, \
                              remove_cvref_t<decltype(low_dim_idss)>, \
                              remove_cvref_t<decltype(up_dim_idss)>, \
                              remove_cvref_t<decltype(bottom_dim_ids)>, \
                              remove_cvref_t<decltype(top_dim_ids)>>{trans}; \
    }()
|
||||
257
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp
Normal file
257
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp
Normal file
@@ -0,0 +1,257 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t NDimHidden, typename BottomDimensionHiddenIds, typename TopDimensionHiddenIds>
|
||||
struct tensor_adaptor_coordinate
|
||||
{
|
||||
static constexpr index_t ndim_bottom_ = BottomDimensionHiddenIds::size();
|
||||
static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size();
|
||||
|
||||
using HiddenIndex = multi_index<NDimHidden>;
|
||||
using BottomIndex = multi_index<ndim_bottom_>;
|
||||
using TopIndex = multi_index<ndim_top_>;
|
||||
|
||||
public:
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate(const HiddenIndex& idx_hidden)
|
||||
: idx_hidden_{idx_hidden}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_top_index() const
|
||||
{
|
||||
return get_container_subset(idx_hidden_, TopDimensionHiddenIds{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_bottom_index() const
|
||||
{
|
||||
return get_container_subset(idx_hidden_, BottomDimensionHiddenIds{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const { return idx_hidden_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get_hidden_index() { return idx_hidden_; }
|
||||
|
||||
//
|
||||
HiddenIndex idx_hidden_;
|
||||
};
|
||||
|
||||
template <typename Adaptor, typename TopIndex>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor& adaptor,
|
||||
const TopIndex& idx_top)
|
||||
{
|
||||
static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
|
||||
"wrong! # of dimension inconsistent");
|
||||
|
||||
constexpr index_t ntransform = Adaptor::get_num_of_transform();
|
||||
constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
|
||||
constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids();
|
||||
constexpr auto top_dim_ids = Adaptor::get_top_dimension_hidden_ids();
|
||||
|
||||
multi_index<ndim_hidden> idx_hidden;
|
||||
|
||||
// initialize visible index
|
||||
set_container_subset(idx_hidden, top_dim_ids, idx_top);
|
||||
|
||||
// calculate hidden index
|
||||
static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
|
||||
auto itran = itran_p1 - number<1>{};
|
||||
const auto& tran = adaptor.get_transforms().at(itran);
|
||||
constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
|
||||
constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
|
||||
|
||||
const auto idx_up = get_container_subset(idx_hidden, dims_up);
|
||||
|
||||
multi_index<dims_low.size()> idx_low;
|
||||
|
||||
tran.calculate_lower_index(idx_low, idx_up);
|
||||
|
||||
set_container_subset(idx_hidden, dims_low, idx_low);
|
||||
});
|
||||
|
||||
return tensor_adaptor_coordinate<ndim_hidden,
|
||||
remove_cvref_t<decltype(bottom_dim_ids)>,
|
||||
remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
|
||||
}
|
||||
|
||||
template <bool JudgeDoTransforms = true,
|
||||
typename Adaptor,
|
||||
typename AdaptorCoord,
|
||||
typename TopIndex,
|
||||
typename BottomIndex>
|
||||
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
|
||||
AdaptorCoord& coord,
|
||||
const TopIndex& idx_diff_top,
|
||||
BottomIndex& idx_diff_bottom)
|
||||
{
|
||||
constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
|
||||
constexpr index_t ndim_top = Adaptor::get_num_of_top_dimension();
|
||||
// constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
|
||||
constexpr index_t ntransform = Adaptor::get_num_of_transform();
|
||||
|
||||
// static_assert(TopIndex::size() == ndim_top && BottomIndex::size() == ndim_bottom, "");
|
||||
|
||||
// judge whether calculation of lower diff is needed for each transform
|
||||
// use index_t for boolean type
|
||||
auto do_transforms = make_zero_multi_index<ntransform>();
|
||||
|
||||
if constexpr(JudgeDoTransforms)
|
||||
{
|
||||
auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
|
||||
|
||||
// decide do_transform by checkout non-zero index diff components
|
||||
multi_index<ndim_top> non_zero_diff_pick_top;
|
||||
|
||||
static_for<0, ndim_top, 1>{}(
|
||||
[&](auto i) { non_zero_diff_pick_top(i) = (idx_diff_top[i] != 0); });
|
||||
|
||||
set_container_subset(
|
||||
is_non_zero_diff, Adaptor::get_top_dimension_hidden_ids(), non_zero_diff_pick_top);
|
||||
|
||||
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
|
||||
constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
|
||||
constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
|
||||
|
||||
const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);
|
||||
|
||||
multi_index<dims_low.size()> non_zero_diff_pick_low;
|
||||
|
||||
// if any of upper index diff components is non-zero, then
|
||||
// 1) Need to do this transform
|
||||
// 2) all components of lower index diff will assume to be non-zero and need to be
|
||||
// computed
|
||||
const bool idx_diff_up_has_non_zero = container_reduce(
|
||||
non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);
|
||||
|
||||
do_transforms(itran) = idx_diff_up_has_non_zero;
|
||||
|
||||
static_for<0, dims_low.size(), 1>{}(
|
||||
[&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });
|
||||
|
||||
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<ntransform - 1, -1, -1>{}([&](auto itran) { do_transforms(itran) = 1; });
|
||||
}
|
||||
|
||||
// this is what needs to be calculated
|
||||
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
|
||||
|
||||
// initialize top index diff
|
||||
set_container_subset(idx_diff_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_diff_top);
|
||||
|
||||
// this is what needs to be updated
|
||||
auto& idx_hidden = coord.get_hidden_index();
|
||||
|
||||
// update top index
|
||||
auto idx_hidden_pick_top =
|
||||
get_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids());
|
||||
|
||||
idx_hidden_pick_top += idx_diff_top;
|
||||
|
||||
set_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_hidden_pick_top);
|
||||
|
||||
// update rest of hidden index
|
||||
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
|
||||
if(do_transforms[itran])
|
||||
{
|
||||
const auto& tran = adaptor.get_transforms().at(itran);
|
||||
constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
|
||||
constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
|
||||
|
||||
const auto idx_up_new = get_container_subset(idx_hidden, dims_up);
|
||||
auto idx_low = get_container_subset(idx_hidden, dims_low);
|
||||
const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);
|
||||
|
||||
multi_index<dims_low.size()> idx_diff_low;
|
||||
|
||||
tran.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up_new);
|
||||
|
||||
set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
|
||||
set_container_subset(idx_hidden, dims_low, idx_low);
|
||||
}
|
||||
});
|
||||
|
||||
// set bottom index diff
|
||||
idx_diff_bottom =
|
||||
get_container_subset(idx_diff_hidden, Adaptor::get_bottom_dimension_hidden_ids());
|
||||
}
|
||||
|
||||
template <bool JudgeDoTransforms = true, typename Adaptor, typename AdaptorCoord, typename TopIndex>
|
||||
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
|
||||
AdaptorCoord& coord,
|
||||
const TopIndex& idx_diff_top)
|
||||
{
|
||||
constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
|
||||
|
||||
multi_index<ndim_bottom> tmp;
|
||||
|
||||
move_tensor_adaptor_coordinate<JudgeDoTransforms>(adaptor, coord, idx_diff_top, tmp);
|
||||
}
|
||||
|
||||
template <typename Adaptor, typename AdaptorCoord>
|
||||
CK_TILE_HOST_DEVICE constexpr bool
|
||||
adaptor_coordinate_is_valid_assuming_top_index_is_valid(const Adaptor& adaptor,
|
||||
const AdaptorCoord& coord)
|
||||
{
|
||||
bool valid = true;
|
||||
|
||||
constexpr index_t ntransform = Adaptor::get_num_of_transform();
|
||||
|
||||
const auto& idx_hidden = coord.get_hidden_index();
|
||||
|
||||
static_for<ntransform - 1, -1, -1>{}([&adaptor, &idx_hidden, &valid](auto itran) {
|
||||
const auto tran = adaptor.get_transforms().at(itran);
|
||||
|
||||
// check validity, only if current transformation does not always has a valid mapping
|
||||
if constexpr(!decltype(tran)::is_valid_upper_index_always_mapped_to_valid_lower_index())
|
||||
{
|
||||
const auto idx_up = get_container_subset(
|
||||
idx_hidden, Adaptor::get_upper_dimension_hidden_idss().at(itran));
|
||||
|
||||
// Comment: using valid = valid && .. will result in weird control flow in ISA
|
||||
valid &= tran.is_valid_upper_index_mapped_to_valid_lower_index(idx_up);
|
||||
}
|
||||
});
|
||||
|
||||
return valid;
|
||||
}
|
||||
|
||||
template <typename Adaptor, typename AdpatorCoord>
|
||||
CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor,
|
||||
const AdpatorCoord& coord)
|
||||
{
|
||||
// check top index
|
||||
const auto& idx_top = coord.get_top_index();
|
||||
|
||||
bool is_top_index_valid = true;
|
||||
|
||||
static_for<0, Adaptor::get_num_of_dimension(), 1>{}(
|
||||
[&is_top_index_valid, &idx_top, &adaptor](auto i) {
|
||||
is_top_index_valid =
|
||||
is_top_index_valid && (idx_top[i] >= 0 && idx_top[i] < adaptor.get_length(i));
|
||||
});
|
||||
|
||||
// check other hidden index
|
||||
return is_top_index_valid &&
|
||||
adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
92
include/ck_tile/core/tensor/tensor_coordinate.hpp
Normal file
92
include/ck_tile/core/tensor/tensor_coordinate.hpp
Normal file
@@ -0,0 +1,92 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t NDimHidden, typename TopDimensionHiddenIds>
|
||||
struct tensor_coordinate
|
||||
: public tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>
|
||||
{
|
||||
using Base = tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>;
|
||||
|
||||
// TODO make these private
|
||||
static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size();
|
||||
|
||||
using HiddenIndex = multi_index<NDimHidden>;
|
||||
using TopIndex = multi_index<ndim_top_>;
|
||||
|
||||
public:
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_coordinate() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const HiddenIndex& idx_hidden)
|
||||
: Base{idx_hidden}
|
||||
{
|
||||
}
|
||||
|
||||
// construct from TensorAdaptorCoordinte base class
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const Base& adaptor_coord) : Base{adaptor_coord}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_index() const { return Base::get_top_index(); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_offset() const
|
||||
{
|
||||
return Base::get_bottom_index()[number<0>{}];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const
|
||||
{
|
||||
return Base::get_hidden_index();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE auto& get_hidden_index() { return Base::get_hidden_index(); }
|
||||
};
|
||||
|
||||
template <typename TensorDesc, typename TopIndex>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
|
||||
const TopIndex& idx_top)
|
||||
{
|
||||
const auto adaptor_coord = make_tensor_adaptor_coordinate(tensor_desc, idx_top);
|
||||
|
||||
return tensor_coordinate<TensorDesc::get_num_of_hidden_dimension(),
|
||||
remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
|
||||
adaptor_coord};
|
||||
}
|
||||
|
||||
template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, const Index& coord_step)
|
||||
{
|
||||
move_tensor_adaptor_coordinate(tensor_desc, coord, coord_step);
|
||||
}
|
||||
|
||||
template <typename TensorDesc, typename TensorCoord>
|
||||
CK_TILE_HOST_DEVICE constexpr bool
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor_desc,
|
||||
const TensorCoord& coord)
|
||||
{
|
||||
return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord);
|
||||
}
|
||||
|
||||
template <typename TensorDesc, typename TensorCoord>
|
||||
CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
|
||||
const TensorCoord& coord)
|
||||
{
|
||||
return adaptor_coordinate_is_valid(tensor_desc, coord);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
467
include/ck_tile/core/tensor/tensor_descriptor.hpp
Normal file
467
include/ck_tile/core/tensor/tensor_descriptor.hpp
Normal file
@@ -0,0 +1,467 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Transforms: Tuple<transforms...>
|
||||
// LowerDimensionHiddenIdss : Tuple<sequence<...>, ...>
|
||||
// UpperDimensionHiddenIdss : Tuple<sequence<...>, ...>
|
||||
// TopDimensionHiddenIds> : sequence<...>
|
||||
template <typename Transforms,
|
||||
typename LowerDimensionHiddenIdss,
|
||||
typename UpperDimensionHiddenIdss,
|
||||
typename TopDimensionHiddenIds,
|
||||
typename ElementSpaceSize,
|
||||
typename GuaranteedVectorLengths_,
|
||||
typename GuaranteedVectorSrides_>
|
||||
struct tensor_descriptor : public tensor_adaptor<Transforms,
|
||||
LowerDimensionHiddenIdss,
|
||||
UpperDimensionHiddenIdss,
|
||||
sequence<0>,
|
||||
TopDimensionHiddenIds>
|
||||
{
|
||||
using Base = tensor_adaptor<Transforms,
|
||||
LowerDimensionHiddenIdss,
|
||||
UpperDimensionHiddenIdss,
|
||||
sequence<0>,
|
||||
TopDimensionHiddenIds>;
|
||||
|
||||
using ElementSpaceSizeType = ElementSpaceSize;
|
||||
|
||||
constexpr static index_t ntransform_ = Base::get_num_of_transform();
|
||||
constexpr static index_t ndim_hidden_ = Base::get_num_of_hidden_dimension();
|
||||
constexpr static index_t ndim_top_ = Base::get_num_of_top_dimension();
|
||||
|
||||
using GuaranteedVectorLengths = GuaranteedVectorLengths_;
|
||||
using GuaranteedVectorStrides = GuaranteedVectorSrides_;
|
||||
|
||||
static_assert(GuaranteedVectorLengths::size() == ndim_hidden_ &&
|
||||
GuaranteedVectorStrides::size() == ndim_hidden_,
|
||||
"wrong! inconsistent # of hidden dimensions");
|
||||
|
||||
using TopIndex = multi_index<ndim_top_>;
|
||||
using HiddenIndex = multi_index<ndim_hidden_>;
|
||||
|
||||
public:
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_descriptor() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Transforms& transforms,
|
||||
ElementSpaceSize element_space_size)
|
||||
: Base{transforms}, element_space_size_{element_space_size}
|
||||
|
||||
{
|
||||
static_assert(Transforms::size() == ntransform_ &&
|
||||
LowerDimensionHiddenIdss::size() == ntransform_ &&
|
||||
UpperDimensionHiddenIdss::size() == ntransform_,
|
||||
"wrong! inconsistent # of transformations");
|
||||
|
||||
// TODO check dependency of dimensions is valid
|
||||
}
|
||||
|
||||
// construct from tensor_adaptor base class
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Base& adaptor,
|
||||
ElementSpaceSize element_space_size)
|
||||
: Base{adaptor}, element_space_size_{element_space_size}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
|
||||
{
|
||||
return Base::get_num_of_top_dimension();
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_length(number<IDim> idim) const
|
||||
{
|
||||
return Base::get_top_dimension_length(idim);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_lengths() const
|
||||
{
|
||||
return Base::get_top_dimension_lengths();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const
|
||||
{
|
||||
return element_space_size_;
|
||||
}
|
||||
|
||||
template <typename Idx>
|
||||
CK_TILE_HOST_DEVICE constexpr index_t calculate_offset(const Idx& idx) const
|
||||
{
|
||||
return Base::calculate_bottom_index(idx)[number<0>{}];
|
||||
}
|
||||
|
||||
// TODO make these private
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const
|
||||
{
|
||||
return Base::get_transforms();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
|
||||
{
|
||||
return Base::get_lower_dimension_hidden_idss();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
|
||||
{
|
||||
return Base::get_upper_dimension_hidden_idss();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
|
||||
{
|
||||
return Base::get_top_dimension_hidden_ids();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static()
|
||||
{
|
||||
return Base::is_known_at_compile_time() &&
|
||||
ck_tile::is_known_at_compile_time<ElementSpaceSize>::value;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides()
|
||||
{
|
||||
return Base::get_top_dimension_safe_vector_length_strides(
|
||||
to_array<index_t, ndim_hidden_>(GuaranteedVectorLengths{}),
|
||||
to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_descriptor{");
|
||||
|
||||
// tensor_adaptor
|
||||
Base::print();
|
||||
printf(", ");
|
||||
|
||||
// element_space_size_
|
||||
printf("element_space_size_: ");
|
||||
print(element_space_size_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// TODO make these private
|
||||
ElementSpaceSize element_space_size_;
|
||||
};
|
||||
|
||||
template <typename Adaptor, typename ElementSpaceSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
|
||||
const ElementSpaceSize& element_space_size)
|
||||
{
|
||||
constexpr index_t NDimHidden = Adaptor::get_num_of_hidden_dimension();
|
||||
|
||||
return tensor_descriptor<remove_cvref_t<decltype(adaptor.get_transforms())>,
|
||||
remove_cvref_t<decltype(adaptor.get_lower_dimension_hidden_idss())>,
|
||||
remove_cvref_t<decltype(adaptor.get_upper_dimension_hidden_idss())>,
|
||||
remove_cvref_t<decltype(adaptor.get_top_dimension_hidden_ids())>,
|
||||
remove_cvref_t<decltype(element_space_size)>,
|
||||
typename uniform_sequence_gen<NDimHidden, -1>::type,
|
||||
typename uniform_sequence_gen<NDimHidden, -1>::type>{
|
||||
adaptor, element_space_size};
|
||||
}
|
||||
|
||||
template <typename OldTensorDescriptor,
|
||||
typename NewTransforms,
|
||||
typename NewLowerDimensionOldTopIdss,
|
||||
typename NewUpperDimensionNewTopIdss>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
||||
const NewTransforms& new_transforms,
|
||||
NewLowerDimensionOldTopIdss,
|
||||
NewUpperDimensionNewTopIdss)
|
||||
{
|
||||
const auto element_space_size = old_tensor_desc.get_element_space_size();
|
||||
|
||||
const auto new_tensor_adaptor = transform_tensor_adaptor(old_tensor_desc,
|
||||
new_transforms,
|
||||
NewLowerDimensionOldTopIdss{},
|
||||
NewUpperDimensionNewTopIdss{});
|
||||
|
||||
constexpr index_t NDimHiddenOld = OldTensorDescriptor::get_num_of_hidden_dimension();
|
||||
constexpr index_t NDimHiddenNew = decltype(new_tensor_adaptor)::get_num_of_hidden_dimension();
|
||||
|
||||
using NewGuaranteedVectorLengths = typename sequence_merge<
|
||||
typename OldTensorDescriptor::GuaranteedVectorLengths,
|
||||
typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;
|
||||
|
||||
using NewGuaranteedVectorStrides = typename sequence_merge<
|
||||
typename OldTensorDescriptor::GuaranteedVectorStrides,
|
||||
typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;
|
||||
|
||||
return tensor_descriptor<
|
||||
remove_cvref_t<decltype(new_tensor_adaptor.get_transforms())>,
|
||||
remove_cvref_t<decltype(new_tensor_adaptor.get_lower_dimension_hidden_idss())>,
|
||||
remove_cvref_t<decltype(new_tensor_adaptor.get_upper_dimension_hidden_idss())>,
|
||||
remove_cvref_t<decltype(new_tensor_adaptor.get_top_dimension_hidden_ids())>,
|
||||
remove_cvref_t<decltype(element_space_size)>,
|
||||
NewGuaranteedVectorLengths,
|
||||
NewGuaranteedVectorStrides>{new_tensor_adaptor, element_space_size};
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename Lengths, typename Strides, index_t I, typename AccOld>
|
||||
CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
|
||||
const Strides& strides,
|
||||
number<I> i,
|
||||
AccOld acc_old)
|
||||
{
|
||||
auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i];
|
||||
|
||||
if constexpr(i.value < Lengths::size() - 1)
|
||||
{
|
||||
return calculate_element_space_size_impl(lengths, strides, i + number<1>{}, acc_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return acc_new;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/*
|
||||
* These functions create naive tensor descriptor
|
||||
*/
|
||||
|
||||
// Lengths..., Strides... could be:
|
||||
// 1) index_t, which is known at run-time, or
|
||||
// 2) number<>, which is known at compile-time
|
||||
// element_space_size could be:
|
||||
// 1) long_index_t, or
|
||||
// 2) long_number<>
|
||||
template <typename... Lengths,
|
||||
typename... Strides,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
index_t GuaranteedLastDimensionVectorStride = -1,
|
||||
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_descriptor(const tuple<Lengths...>& lengths,
|
||||
const tuple<Strides...>& strides,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
|
||||
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
|
||||
{
|
||||
constexpr index_t N = sizeof...(Lengths);
|
||||
|
||||
const auto transforms = make_tuple(make_embed_transform(lengths, strides));
|
||||
|
||||
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
|
||||
|
||||
constexpr auto up_dim_hidden_idss =
|
||||
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
|
||||
|
||||
const auto element_space_size =
|
||||
detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{});
|
||||
|
||||
using GuaranteedVectorLengths =
|
||||
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
|
||||
sequence<GuaranteedLastDimensionVectorLength>>::type;
|
||||
|
||||
using GuaranteedVectorStrides =
|
||||
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
|
||||
sequence<GuaranteedLastDimensionVectorStride>>::type;
|
||||
|
||||
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
|
||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(visible_dim_hidden_ids)>,
|
||||
remove_cv_t<decltype(element_space_size)>,
|
||||
GuaranteedVectorLengths,
|
||||
GuaranteedVectorStrides>{transforms, element_space_size};
|
||||
}
|
||||
|
||||
// tensor descriptor with offset, the offset will not be added into element space size
|
||||
// only have an information of the starting offset, and will impact on offset calculation
|
||||
template <typename... Lengths,
|
||||
typename... Strides,
|
||||
typename offset,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
index_t GuaranteedLastDimensionVectorStride = -1,
|
||||
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
|
||||
const tuple<Strides...>& strides,
|
||||
const offset& os,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
|
||||
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
|
||||
{
|
||||
const auto desc_0 = [&]() {
|
||||
const auto element_space_size = detail::calculate_element_space_size_impl(
|
||||
lengths, strides, number<0>{}, long_number<1>{});
|
||||
|
||||
const auto transforms = make_tuple(make_offset_transform(element_space_size, os));
|
||||
|
||||
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
|
||||
|
||||
constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = sequence<1>{};
|
||||
|
||||
using GuaranteedVectorLengths =
|
||||
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
|
||||
sequence<GuaranteedLastDimensionVectorLength>>::type;
|
||||
|
||||
using GuaranteedVectorStrides =
|
||||
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
|
||||
sequence<GuaranteedLastDimensionVectorStride>>::type;
|
||||
|
||||
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
|
||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(visible_dim_hidden_ids)>,
|
||||
remove_cv_t<decltype(element_space_size)>,
|
||||
GuaranteedVectorLengths,
|
||||
GuaranteedVectorStrides>{transforms, element_space_size};
|
||||
}();
|
||||
|
||||
constexpr index_t N = sizeof...(Lengths);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
desc_0,
|
||||
make_tuple(make_embed_transform(lengths, strides)),
|
||||
make_tuple(sequence<0>{}),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
|
||||
}
|
||||
|
||||
// Lengths... could be:
|
||||
// 1) index_t, which is known at run-time, or
|
||||
// 2) number<>, which is known at compile-time
|
||||
// element_space_size could be:
|
||||
// 1) long_index_t, or
|
||||
// 2) long_number<>
|
||||
template <typename... Lengths, index_t GuaranteedLastDimensionVectorLength = -1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
|
||||
{
|
||||
constexpr index_t N = sizeof...(Lengths);
|
||||
|
||||
const auto transforms = make_tuple(make_unmerge_transform(lengths));
|
||||
|
||||
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
|
||||
|
||||
constexpr auto up_dim_hidden_idss =
|
||||
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
|
||||
|
||||
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
|
||||
|
||||
using GuaranteedVectorLengths =
|
||||
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
|
||||
sequence<GuaranteedLastDimensionVectorLength>>::type;
|
||||
|
||||
using GuaranteedVectorStrides =
|
||||
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type, sequence<1>>::type;
|
||||
|
||||
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
|
||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(visible_dim_hidden_ids)>,
|
||||
remove_cv_t<decltype(element_space_size)>,
|
||||
GuaranteedVectorLengths,
|
||||
GuaranteedVectorStrides>{transforms, element_space_size};
|
||||
}
|
||||
|
||||
template <typename... Lengths,
|
||||
typename... Strides,
|
||||
typename Offset,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offset(
|
||||
const tuple<Lengths...>& lengths,
|
||||
const Offset& offset,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
|
||||
{
|
||||
const auto desc_0 = [&]() {
|
||||
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
|
||||
|
||||
const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));
|
||||
|
||||
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
|
||||
|
||||
constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = sequence<1>{};
|
||||
|
||||
using GuaranteedVectorLengths =
|
||||
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
|
||||
sequence<GuaranteedLastDimensionVectorLength>>::type;
|
||||
|
||||
using GuaranteedVectorStrides =
|
||||
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type, sequence<1>>::type;
|
||||
|
||||
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
|
||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(visible_dim_hidden_ids)>,
|
||||
remove_cv_t<decltype(element_space_size)>,
|
||||
GuaranteedVectorLengths,
|
||||
GuaranteedVectorStrides>{transforms, element_space_size};
|
||||
}();
|
||||
|
||||
constexpr index_t N = sizeof...(Lengths);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
desc_0,
|
||||
make_tuple(make_unmerge_transform(lengths)),
|
||||
make_tuple(sequence<0>{}),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
|
||||
}
|
||||
|
||||
// Lengths... could be:
|
||||
// 1) index_t, which is known at run-time, or
|
||||
// 2) number<>, which is known at compile-time
|
||||
// align could be:
|
||||
// 1) index_t, or
|
||||
// 2) number<>
|
||||
template <typename... Lengths, typename Align>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_descriptor_aligned(const tuple<Lengths...>& lengths, Align align)
|
||||
{
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
constexpr index_t N = sizeof...(Lengths);
|
||||
|
||||
const auto stride_n_minus_2 = integer_least_multiple(lengths[number<N - 1>{}], align);
|
||||
|
||||
auto strides = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i.value == N - 1)
|
||||
{
|
||||
return I1;
|
||||
}
|
||||
else if constexpr(i.value == N - 2)
|
||||
{
|
||||
return number<stride_n_minus_2>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return container_reduce(
|
||||
lengths, multiplies{}, number<stride_n_minus_2>{}, i + I1, number<N - 1>{}, I1);
|
||||
}
|
||||
},
|
||||
number<N>{});
|
||||
|
||||
return make_naive_tensor_descriptor(lengths, strides);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
533
include/ck_tile/core/tensor/tensor_view.hpp
Normal file
533
include/ck_tile/core/tensor/tensor_view.hpp
Normal file
@@ -0,0 +1,533 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
* tensor_view
|
||||
* abstract the underneath memory buffer(global, LDS, etc...)
|
||||
* and provide a unified get/set function for access
|
||||
*
|
||||
* For addressing into the buffer we use 2 variable to control:
|
||||
* coord : ND tensor coordinate, will calculate the actual offset inside
|
||||
* linear_offset : 1D offset, will be used in the immediate field of
|
||||
* the buffer instruction to help reduce register usage
|
||||
*
|
||||
* User can use either of the field, or both to indexing into the tensor
|
||||
*
|
||||
* We usually provide 2 set of API for buffer get/set, e.g.
|
||||
* get_vectorized_elements()/get_vectorized_elements_raw()
|
||||
* the former usually will call intrinsic or normal C function, the later
|
||||
* usually will call inline-asm function
|
||||
*
|
||||
*/
|
||||
template <typename BufferView_,
|
||||
typename TensorDesc_,
|
||||
memory_operation_enum DstInMemOp_ = memory_operation_enum::set>
|
||||
struct tensor_view
|
||||
{
|
||||
using buffer_view = remove_reference_t<BufferView_>;
|
||||
using DataType = typename buffer_view::type;
|
||||
using TensorDesc = remove_cvref_t<TensorDesc_>;
|
||||
using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
|
||||
using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
|
||||
static constexpr auto DstInMemOp = DstInMemOp_;
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_view() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_view(const buffer_view& buffer_view,
|
||||
const TensorDesc& desc)
|
||||
: buf_{buffer_view}, desc_{desc}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
|
||||
{
|
||||
return TensorDesc::get_num_of_top_dimension();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_buffer_view() const { return buf_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get_buffer_view() { return buf_; }
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template get<X>(
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element, // flag
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template get<X>(coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
|
||||
dst,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(dst,
|
||||
coord.get_offset() /
|
||||
PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset) const
|
||||
{
|
||||
return buf_.template async_get<X>(
|
||||
smem,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element) const
|
||||
{
|
||||
return buf_.template async_get<X>(smem,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template async_get_raw<X>(
|
||||
smem,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template async_get_raw<X>(smem,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set<X, oob_conditional_check>(
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set<X, oob_conditional_check>(
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set_raw<X, oob_conditional_check>(
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set_raw<X, oob_conditional_check>(
|
||||
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template update<DstInMemOp, X, oob_conditional_check>(
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template update<DstInMemOp, X, oob_conditional_check>(
|
||||
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
|
||||
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_view{");
|
||||
|
||||
// buf_
|
||||
printf("buf_: ");
|
||||
print(buf_);
|
||||
printf(", ");
|
||||
|
||||
// desc_
|
||||
printf("desc_: ");
|
||||
print(desc_);
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
// member
|
||||
buffer_view buf_;
|
||||
TensorDesc desc_;
|
||||
};
|
||||
|
||||
// placeholder type if we want to opt-out a tile view parameter
|
||||
struct null_tensor_view
|
||||
{
|
||||
};
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename DataType,
|
||||
typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
|
||||
const tensor_descriptor<Ts...>& desc)
|
||||
{
|
||||
auto buffer_view =
|
||||
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
memory_operation_enum DstInMemOp = memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
typename... Strides,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
index_t GuaranteedLastDimensionVectorStride = -1,
|
||||
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_view(DataType* p,
|
||||
const tuple<Lengths...>& lengths,
|
||||
const tuple<Strides...>& strides,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
|
||||
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
|
||||
{
|
||||
auto desc = make_naive_tensor_descriptor(lengths,
|
||||
strides,
|
||||
number<GuaranteedLastDimensionVectorLength>{},
|
||||
number<GuaranteedLastDimensionVectorStride>{});
|
||||
|
||||
auto buffer_view =
|
||||
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_view_packed(DataType* p,
|
||||
const tuple<Lengths...>& lengths,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
|
||||
{
|
||||
auto desc =
|
||||
make_naive_tensor_descriptor_packed(lengths, number<GuaranteedLastDimensionVectorLength>{});
|
||||
|
||||
auto buffer_view =
|
||||
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <typename OldTensorView,
|
||||
typename NewTransforms,
|
||||
typename NewLowerDimensionOldVisibleIdss,
|
||||
typename NewUpperDimensionNewVisibleIdss>
|
||||
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& old_tensor_view,
|
||||
const NewTransforms& new_transforms,
|
||||
NewLowerDimensionOldVisibleIdss,
|
||||
NewUpperDimensionNewVisibleIdss)
|
||||
{
|
||||
auto new_desc = transform_tensor_descriptor(old_tensor_view.desc_,
|
||||
new_transforms,
|
||||
NewLowerDimensionOldVisibleIdss{},
|
||||
NewUpperDimensionNewVisibleIdss{});
|
||||
|
||||
return tensor_view<typename OldTensorView::buffer_view,
|
||||
remove_cvref_t<decltype(new_desc)>,
|
||||
remove_cvref_t<OldTensorView>::DstInMemOp>{old_tensor_view.buf_, new_desc};
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
typename TileLengths, // tuple<...>
|
||||
typename DoPads> // sequence<bool, bool, ...>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads)
|
||||
{
|
||||
constexpr index_t num_dim = DoPads::size();
|
||||
|
||||
static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(),
|
||||
"wrong! inconsistent # of dimensions");
|
||||
|
||||
// transforms
|
||||
const auto transforms = generate_tuple(
|
||||
[&](auto idim) {
|
||||
const auto old_length = tensor_view.get_tensor_descriptor().get_length(idim);
|
||||
|
||||
const auto tile_length = tile_lengths[idim];
|
||||
|
||||
const auto new_length = integer_divide_ceil(old_length, tile_length) * tile_length;
|
||||
|
||||
const auto pad_length = new_length - old_length;
|
||||
|
||||
constexpr bool DoPad = DoPads::at(idim);
|
||||
|
||||
const auto transform =
|
||||
conditional_expr<DoPad>(make_right_pad_transform(old_length, pad_length),
|
||||
make_pass_through_transform(old_length));
|
||||
|
||||
return transform;
|
||||
},
|
||||
number<num_dim>{});
|
||||
|
||||
// lower dimension Id
|
||||
const auto lower_dimss =
|
||||
generate_tuple([&](auto idim) { return sequence<idim.value>{}; }, number<num_dim>{});
|
||||
|
||||
// upper dimension Id
|
||||
const auto upper_dimss = lower_dimss;
|
||||
|
||||
return transform_tensor_view(tensor_view, transforms, lower_dimss, upper_dimss);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
673
include/ck_tile/core/tensor/tile_distribution.hpp
Normal file
673
include/ck_tile/core/tensor/tile_distribution.hpp
Normal file
@@ -0,0 +1,673 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/meta_data_buffer.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
template <typename Distribution>
|
||||
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
|
||||
{
|
||||
return Distribution::_get_partition_index();
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// distributed span
|
||||
template <index_t... PartialHsLengths>
|
||||
struct tile_distributed_span
|
||||
{
|
||||
using Impl = sequence<PartialHsLengths...>;
|
||||
|
||||
static constexpr auto impl_ = Impl{};
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
|
||||
};
|
||||
|
||||
// distributed index
|
||||
template <index_t... PartialHsIndices>
|
||||
struct tile_distributed_index
|
||||
{
|
||||
using Impl = sequence<PartialHsIndices...>;
|
||||
|
||||
static constexpr auto impl_ = Impl{};
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_span(sequence<Is...>)
|
||||
{
|
||||
return tile_distributed_span<Is...>{};
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_index(sequence<Is...>)
|
||||
{
|
||||
return tile_distributed_index<Is...>{};
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename PsYs2XsAdaptor_,
|
||||
typename Ys2DDescriptor_,
|
||||
typename StaticTileDistributionEncoding_,
|
||||
typename TileDistributionDetail_> // FIXME: this is for hold ad-hoc but useful info,
|
||||
// should be more elegant
|
||||
struct tile_distribution
|
||||
{
|
||||
using PsYs2XsAdaptor = remove_cvref_t<PsYs2XsAdaptor_>;
|
||||
using Ys2DDescriptor = remove_cvref_t<Ys2DDescriptor_>;
|
||||
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
|
||||
using DstrDetail = remove_cvref_t<TileDistributionDetail_>;
|
||||
|
||||
static_assert(PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static(),
|
||||
"wrong! should be static");
|
||||
|
||||
static constexpr index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension();
|
||||
static constexpr index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension();
|
||||
static constexpr index_t NDimP = PsYs2XsAdaptor::get_num_of_top_dimension() - NDimY;
|
||||
static constexpr index_t NDimR = StaticTileDistributionEncoding_::NDimR;
|
||||
|
||||
PsYs2XsAdaptor ps_ys_to_xs_;
|
||||
Ys2DDescriptor ys_to_d_;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_x() { return NDimX; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_y() { return NDimY; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
|
||||
|
||||
// Returns the partition (P) index of the calling thread:
//   NDimP == 1 -> {lane id}
//   NDimP == 2 -> {warp id, lane id}
CK_TILE_HOST_DEVICE static auto _get_partition_index()
{
    // only support warp-tile and block-tile
    static_assert(NDimP == 1 or NDimP == 2, "wrong!");

    if constexpr(NDimP == 2)
    {
        return array<index_t, 2>{get_warp_id(), get_lane_id()};
    }
    else if constexpr(NDimP == 1)
    {
        return array<index_t, 1>{get_lane_id()};
    }
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
|
||||
{
|
||||
#if 0
|
||||
// FIXME: tensor_adaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed
|
||||
ps_ys_to_xs_.GetBottomDimensionLengths();
|
||||
#else
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr index_t x_length =
|
||||
container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies{}, 1);
|
||||
|
||||
return number<x_length>{};
|
||||
},
|
||||
number<NDimX>{});
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_ps_ys_to_xs_adaptor() const
|
||||
{
|
||||
return ps_ys_to_xs_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_ys_to_d_descriptor() const { return ys_to_d_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_static_tile_distribution_encoding()
|
||||
{
|
||||
return DstrEncode{};
|
||||
}
|
||||
|
||||
#if 1
|
||||
// Calculate Replication index [R0, R1, ...] based on Partition index
|
||||
// FIXME: very nasty implementation
|
||||
template <typename PartitionIndex>
|
||||
CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex& ps_idx) const
|
||||
{
|
||||
static_assert(PartitionIndex::size() == NDimP, "wrong!");
|
||||
|
||||
const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
|
||||
|
||||
const auto dummy_adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
|
||||
|
||||
array<index_t, NDimR> rs_idx;
|
||||
|
||||
static_for<0, NDimP, 1>{}([&](auto idim_p) {
|
||||
constexpr index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size();
|
||||
|
||||
static_for<0, ndim_low, 1>{}([&](auto i) {
|
||||
constexpr index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i];
|
||||
constexpr index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i];
|
||||
|
||||
// 0-th rh_major is the replicate dimension
|
||||
if constexpr(rh_major == 0)
|
||||
{
|
||||
constexpr index_t adaptor_hidden_id =
|
||||
DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor];
|
||||
|
||||
// fill in
|
||||
rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id];
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return rs_idx;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename PartitionIndex = decltype(_get_partition_index())>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const
|
||||
{
|
||||
const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
|
||||
const auto window_adaptor_thread_coord_tmp =
|
||||
make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
|
||||
return window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
}
|
||||
|
||||
// Returns a tuple with one tile_distributed_span per X dimension, built from
// the span lengths recorded in DstrEncode::detail.
CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
{
    constexpr auto all_span_lengths = DstrEncode::detail::distributed_spans_lengthss_;
    constexpr auto all_span_ndims   = DstrEncode::detail::ndims_distributed_spans_minor_;

    return generate_tuple(
        [&](auto ix) {
            constexpr auto span_lengths = all_span_lengths[ix];
            constexpr index_t span_ndim = all_span_ndims[ix];

            // keep only the first span_ndim entries, as a sequence
            constexpr auto span_seq = TO_SEQUENCE(span_lengths, span_ndim);

            return detail::make_tile_distributed_span(span_seq);
        },
        number<NDimX>{});
}
|
||||
|
||||
// FIXME: it's hacky to get Y index from Distributed-Index
|
||||
template <typename DistributedIndices>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
get_y_indices_from_distributed_indices(DistributedIndices)
|
||||
{
|
||||
constexpr auto ys_idx_arr = [] {
|
||||
array<index_t, NDimY> ys_idx;
|
||||
|
||||
static_for<0, NDimY, 1>{}([&](auto i) {
|
||||
constexpr index_t span_major = DstrEncode::detail::ys_to_span_major_[i];
|
||||
constexpr index_t span_minor = DstrEncode::detail::ys_to_span_minor_[i];
|
||||
|
||||
constexpr auto dstr_index = DistributedIndices{}[number<span_major>{}];
|
||||
|
||||
ys_idx(i) = dstr_index.impl_[span_minor];
|
||||
});
|
||||
|
||||
return ys_idx;
|
||||
}();
|
||||
|
||||
constexpr index_t ndim_y = NDimY;
|
||||
|
||||
return TO_SEQUENCE(ys_idx_arr, ndim_y);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static()
|
||||
{
|
||||
return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tile_distribution{");
|
||||
//
|
||||
printf("tile_distribution_encoding: ");
|
||||
print(DstrEncode{});
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_ys_to_xs_: ");
|
||||
print(ps_ys_to_xs_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_d_: ");
|
||||
print(ys_to_d_);
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <index_t NDimMax>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t iend)
|
||||
{
|
||||
array<index_t, NDimMax> arr{0};
|
||||
|
||||
for(index_t i = 0; i < iend - ibegin; ++i)
|
||||
{
|
||||
arr(i) = ibegin + i;
|
||||
}
|
||||
|
||||
return arr;
|
||||
}
|
||||
|
||||
// this returns a constexpr encoding of tile_distribution
|
||||
template <typename StaticTileDistributionEncoding_>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
|
||||
{
|
||||
using RsLengths = typename StaticTileDistributionEncoding_::RsLengths;
|
||||
using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss;
|
||||
using Ps2RHssMajor = typename StaticTileDistributionEncoding_::Ps2RHssMajor;
|
||||
using Ps2RHssMinor = typename StaticTileDistributionEncoding_::Ps2RHssMinor;
|
||||
using Ys2RHsMajor = typename StaticTileDistributionEncoding_::Ys2RHsMajor;
|
||||
using Ys2RHsMinor = typename StaticTileDistributionEncoding_::Ys2RHsMinor;
|
||||
|
||||
// FIXME: increase max value if fail
|
||||
constexpr index_t kMaxNumTransforms = 20;
|
||||
constexpr index_t kMaxMetaDataSize = 128;
|
||||
constexpr index_t kMaxNumDim = 10;
|
||||
|
||||
using Name = coord_transform_enum;
|
||||
using MetaData = meta_data_buffer<kMaxMetaDataSize>;
|
||||
using NumDim = index_t;
|
||||
using Dims = array<index_t, kMaxNumDim>;
|
||||
using Lengths = array<index_t, kMaxNumDim>;
|
||||
|
||||
// Tile Adaptor
|
||||
// bottom dims [x0, x1, x2, ...]
|
||||
// top dims [p0, p1, ..., y0, y1, ...]
|
||||
constexpr index_t ndim_x = HsLengthss::size();
|
||||
|
||||
// Dim Ids: [idim_x_major, idim_x_minor] to [idim_hidden]
|
||||
array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_ids;
|
||||
array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_lengths;
|
||||
|
||||
auto trans = array<tuple<Name, MetaData, NumDim, Dims, NumDim, Dims>, kMaxNumTransforms>{};
|
||||
|
||||
index_t num_tran = 0;
|
||||
index_t hidden_dim_cnt = ndim_x;
|
||||
|
||||
// this is replicate transform
|
||||
{
|
||||
constexpr index_t ndim_r_minor = RsLengths::size();
|
||||
|
||||
constexpr auto r_minor_lengths = RsLengths{};
|
||||
|
||||
trans(num_tran++) = {
|
||||
coord_transform_enum::replicate,
|
||||
MetaData{to_array<index_t, ndim_r_minor>(r_minor_lengths)},
|
||||
NumDim{0},
|
||||
Dims{},
|
||||
NumDim{ndim_r_minor},
|
||||
make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)};
|
||||
|
||||
for(index_t i = 0; i < ndim_r_minor; ++i)
|
||||
{
|
||||
rh_major_minor_to_hidden_ids(0)(i) = hidden_dim_cnt;
|
||||
rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i];
|
||||
|
||||
hidden_dim_cnt++;
|
||||
}
|
||||
};
|
||||
|
||||
// these are Unmerge transforms for X dimesions
|
||||
static_for<0, ndim_x, 1>{}([&trans,
|
||||
&num_tran,
|
||||
&hidden_dim_cnt,
|
||||
&rh_major_minor_to_hidden_ids,
|
||||
&rh_major_minor_to_hidden_lengths](auto idim_x) {
|
||||
// typename HsLengthss::base{}.foo();
|
||||
constexpr auto h_minor_lengths =
|
||||
HsLengthss{}.get(idim_x); // std::tuple_element_t<idim_x, HsLengthss>{};
|
||||
// constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
|
||||
|
||||
constexpr index_t ndim_h_minor = h_minor_lengths.size();
|
||||
|
||||
trans(num_tran++) = {
|
||||
coord_transform_enum::unmerge,
|
||||
MetaData{to_array<index_t, ndim_h_minor>(h_minor_lengths)},
|
||||
NumDim{1},
|
||||
Dims{idim_x},
|
||||
NumDim{ndim_h_minor},
|
||||
make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)};
|
||||
|
||||
for(index_t i = 0; i < ndim_h_minor; ++i)
|
||||
{
|
||||
rh_major_minor_to_hidden_ids(idim_x + 1)(i) = hidden_dim_cnt;
|
||||
rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i];
|
||||
|
||||
hidden_dim_cnt++;
|
||||
}
|
||||
});
|
||||
|
||||
// transform: P dimensions
|
||||
constexpr index_t ndim_p = Ps2RHssMajor::size();
|
||||
|
||||
Dims hidden_dim_id_ps;
|
||||
|
||||
static_for<0, ndim_p, 1>{}([&](auto iDimP) {
|
||||
//
|
||||
index_t hidden_dim_id_p = hidden_dim_cnt++;
|
||||
|
||||
hidden_dim_id_ps(iDimP) = hidden_dim_id_p;
|
||||
|
||||
constexpr auto p2RHsMajor = Ps2RHssMajor{}[iDimP];
|
||||
constexpr auto p2RHsMinor = Ps2RHssMinor{}[iDimP];
|
||||
|
||||
static_assert(p2RHsMajor.size() == p2RHsMinor.size(), "wrong!");
|
||||
|
||||
constexpr index_t ndim_low = p2RHsMajor.size();
|
||||
|
||||
Dims low_dims;
|
||||
Lengths low_lengths;
|
||||
|
||||
for(index_t i = 0; i < ndim_low; ++i)
|
||||
{
|
||||
index_t rh_major = p2RHsMajor[i];
|
||||
index_t rh_minor = p2RHsMinor[i];
|
||||
low_dims(i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
|
||||
low_lengths(i) = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
|
||||
}
|
||||
|
||||
trans(num_tran++) = {coord_transform_enum::merge,
|
||||
MetaData{to_array<index_t, ndim_low>(low_lengths)},
|
||||
NumDim{ndim_low},
|
||||
low_dims,
|
||||
NumDim{1},
|
||||
Dims{hidden_dim_id_p}};
|
||||
});
|
||||
|
||||
constexpr index_t ndim_bottom = ndim_x;
|
||||
|
||||
constexpr auto bottom_dim_ids = make_sequential_index<kMaxNumDim>(0, ndim_bottom);
|
||||
|
||||
constexpr auto ys_to_rhs_major = Ys2RHsMajor{};
|
||||
constexpr auto ys_to_rhs_minor = Ys2RHsMinor{};
|
||||
|
||||
constexpr index_t ndim_y = Ys2RHsMajor::size();
|
||||
constexpr index_t ndim_top = ndim_p + ndim_y;
|
||||
|
||||
auto top_dim_ids = hidden_dim_id_ps;
|
||||
|
||||
{
|
||||
for(index_t i = 0; i < ndim_y; ++i)
|
||||
{
|
||||
index_t rh_major = ys_to_rhs_major[i];
|
||||
index_t rh_minor = ys_to_rhs_minor[i];
|
||||
top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
const auto ps_ys_to_xs_adaptor_encoding =
|
||||
make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top);
|
||||
|
||||
// descriptor: [y0, y1, ...] to [d]
|
||||
Lengths y_lengths;
|
||||
index_t d_length = 1;
|
||||
|
||||
for(index_t i = 0; i < ndim_y; ++i)
|
||||
{
|
||||
index_t rh_major = ys_to_rhs_major[i];
|
||||
index_t rh_minor = ys_to_rhs_minor[i];
|
||||
index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
|
||||
y_lengths(i) = y_length;
|
||||
d_length *= y_length;
|
||||
}
|
||||
|
||||
auto tran = make_tuple(coord_transform_enum::unmerge,
|
||||
MetaData{to_array<index_t, ndim_y>(y_lengths)},
|
||||
NumDim{1},
|
||||
Dims{0},
|
||||
NumDim{ndim_y},
|
||||
make_sequential_index<kMaxNumDim>(1, ndim_y + 1));
|
||||
|
||||
const auto ys_to_d_adaptor_encoding = make_tuple(
|
||||
make_tuple(tran), 1, Dims{0}, 1, make_sequential_index<kMaxNumDim>(1, ndim_y + 1), ndim_y);
|
||||
|
||||
return make_tuple(ps_ys_to_xs_adaptor_encoding,
|
||||
ys_to_d_adaptor_encoding,
|
||||
d_length,
|
||||
rh_major_minor_to_hidden_ids);
|
||||
}
|
||||
|
||||
// FIXME: this is nasty. Move it inside TileDistributionEncoding::detail
|
||||
// Compile-time lookup table attached to a tile_distribution.
// The template argument is a tuple of sequences indexed as [rh_major][rh_minor]; it is
// converted once into a nested array so that it can be indexed with plain integers.
template <typename RhMajorMinor2AdaptorHiddenIdss> // tuple<sequence<...>, ...>
struct tile_distribution_detail
{
    // [rh_major][rh_minor] -> hidden dimension id of the adaptor
    static constexpr auto rh_major_minor_to_adaptor_hidden_idss_ = [] {
        constexpr RhMajorMinor2AdaptorHiddenIdss hidden_idss{};
        return to_array_of_array(hidden_idss);
    }();
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
#if 0
// this returns a constexpr tile_distribution
// (disabled; differs from make_static_tile_distribution() only in using
//  CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING and in passing d_length as a plain index_t)
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
{
    using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;

    // tuple: <ps_ys->xs adaptor encoding, ys->d adaptor encoding, d_length, rh->hidden ids>
    constexpr auto encoding_pack =
        detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});

    constexpr auto ps_ys_to_xs_adaptor_impl          = encoding_pack.template at<0>();
    constexpr auto ys_to_d_adaptor_impl              = encoding_pack.template at<1>();
    constexpr index_t d_length                       = encoding_pack.template at<2>();
    constexpr auto rh_major_minor_to_hidden_ids_impl = encoding_pack.template at<3>();

    // materialize both adaptors from their encodings
    constexpr auto ps_ys_to_xs_adaptor =
        CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
    constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);

    constexpr auto ys_to_d_descriptor =
        make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length);

    // turn the [rh_major][rh_minor] -> hidden id array into a tuple of sequences
    constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
    constexpr auto ndims_rhs_minor  = DstrEncode::detail::ndims_rhs_minor_;

    constexpr auto rh_major_minor_to_hidden_ids =
        TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);

    using PsYs2XsAdaptor = remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>;
    using Ys2DDescriptor = remove_cvref_t<decltype(ys_to_d_descriptor)>;
    using DstrDetail =
        detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>;

    return tile_distribution<PsYs2XsAdaptor, Ys2DDescriptor, remove_cvref_t<DstrEncode>, DstrDetail>{
        ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
#endif
|
||||
|
||||
// this returns a static tile_distribution
|
||||
template <typename StaticTileDistributionEncoding_>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
|
||||
{
|
||||
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
|
||||
|
||||
constexpr auto adaptor_impl =
|
||||
detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
|
||||
|
||||
constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
|
||||
constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
|
||||
constexpr index_t d_length = adaptor_impl.template at<2>();
|
||||
constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
|
||||
|
||||
constexpr auto ps_ys_to_xs_adaptor =
|
||||
CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
|
||||
|
||||
constexpr auto ys_to_d_adaptor =
|
||||
CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
|
||||
|
||||
constexpr auto ys_to_d_descriptor =
|
||||
make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, number<d_length>{});
|
||||
|
||||
//
|
||||
constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
|
||||
constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
|
||||
|
||||
constexpr auto rh_major_minor_to_hidden_ids =
|
||||
TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
|
||||
|
||||
return tile_distribution<
|
||||
remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
|
||||
remove_cvref_t<decltype(ys_to_d_descriptor)>,
|
||||
remove_cvref_t<DstrEncode>,
|
||||
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
|
||||
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
|
||||
}
|
||||
|
||||
//***********************************************************************************
|
||||
|
||||
namespace detail {
|
||||
//
|
||||
// slice tensor from x_dim, result in split in y_dim, not p_dim.
|
||||
// We don't support slicing across p_dim (i.e. slicing across different threads)
|
||||
// also, sliced along y_dim need be the first dim of current dim.
|
||||
// Multiply Y dim before sliced dim does not make sense
|
||||
//
|
||||
// e.g
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, is the first dim of X1, totally 4 slices
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, the P dim is 1 in the left, so is OK
|
||||
// totally 16 slices
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
|
||||
// |--> slice along this P dim, will split threads, not supported
|
||||
//
|
||||
// X0 X1
|
||||
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
|
||||
// Y P P Y P Y P Y
|
||||
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
|
||||
// |--> slice along this Y dim, but this Y dim needs to be split into 2
|
||||
// sub-dims
|
||||
// the P dim in the left is 1, means actually not crossing P
|
||||
//
|
||||
template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
|
||||
CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
Distribution, sequence<XSliceBegins...> x_slice_begins, sequence<XSliceEnds...> x_slice_ends)
|
||||
{
|
||||
// NOTE: this function need to be called under constexpr context,
|
||||
// due to https://wg21.link/p2280r0 we have to use non-reference type for distribution
|
||||
using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
|
||||
|
||||
static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
|
||||
|
||||
constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;
|
||||
|
||||
constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
|
||||
constexpr auto src_y_info = Encoding::detail::get_sorted_y_info();
|
||||
constexpr auto src_y_dims = src_y_info[number<0>{}];
|
||||
constexpr auto src_y_maps = src_y_info[number<1>{}];
|
||||
constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
|
||||
|
||||
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr
|
||||
{
|
||||
auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
|
||||
auto y_slice_lengths = Encoding::detail::ys_lengths_;
|
||||
|
||||
// This lambda will modify some value outside, so c++ will not treat return value as
|
||||
// constexpr
|
||||
// TODO: ugly
|
||||
auto new_h_lengths = transform_tuples(
|
||||
[&](auto h_len, auto id) {
|
||||
constexpr auto sliced_h =
|
||||
reverse_slice_sequence(h_len, number<x_slice_lengths[id]>{});
|
||||
|
||||
constexpr auto sliced_h_lens = sliced_h[number<0>{}];
|
||||
constexpr auto sliced_h_index = sliced_h[number<2>{}];
|
||||
|
||||
// update y_slice_lengths
|
||||
constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
|
||||
constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
|
||||
|
||||
static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
|
||||
"not sliced at y dim, please check");
|
||||
|
||||
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
|
||||
y_slice_lengths(src_y_maps[found_y_index - i]) =
|
||||
sliced_h_lens[sliced_h_index - i];
|
||||
});
|
||||
// TODO: add validations not across p dim
|
||||
|
||||
// NOTE: this y_origin is for all dims, not only current dim
|
||||
// will later use pick to select target dim
|
||||
constexpr auto y_origin = [&]() {
|
||||
constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
|
||||
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
|
||||
h_trans.calculate_lower_index(h_origin_, sequence<x_slice_begins[id].value>{});
|
||||
|
||||
auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
|
||||
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
|
||||
y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
|
||||
});
|
||||
return y_origin_;
|
||||
}();
|
||||
|
||||
constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
|
||||
src_y_prefix_sum[id + 1],
|
||||
1>::type{};
|
||||
|
||||
set_container_subset(
|
||||
y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));
|
||||
return sliced_h_lens;
|
||||
},
|
||||
typename Encoding::HsLengthss{},
|
||||
typename arithmetic_sequence_gen<0, Encoding::HsLengthss::size(), 1>::type{});
|
||||
|
||||
auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
|
||||
|
||||
return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
|
||||
}
|
||||
();
|
||||
|
||||
constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}];
|
||||
constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
|
||||
constexpr auto sliced_y_origins_size = sliced_y_origins_array.size();
|
||||
constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[number<2>{}];
|
||||
constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.size();
|
||||
|
||||
constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
|
||||
constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
|
||||
|
||||
return make_tuple(
|
||||
make_static_tile_distribution(
|
||||
tile_distribution_encoding<typename Encoding::RsLengths,
|
||||
remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
|
||||
// change the
|
||||
// h_lengths type
|
||||
typename Encoding::Ps2RHssMajor,
|
||||
typename Encoding::Ps2RHssMinor,
|
||||
typename Encoding::Ys2RHsMajor,
|
||||
typename Encoding::Ys2RHsMinor>{}),
|
||||
sliced_y_origins,
|
||||
sliced_y_lengths);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace ck_tile
|
||||
760
include/ck_tile/core/tensor/tile_distribution_encoding.hpp
Normal file
760
include/ck_tile/core/tensor/tile_distribution_encoding.hpp
Normal file
@@ -0,0 +1,760 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename RsLengths_, // sequence<...>
|
||||
typename HsLengthss_, // tuple<sequence<...>, ...>
|
||||
typename Ps2RHssMajor_, // tuple<sequence<...>, ...>
|
||||
typename Ps2RHssMinor_, // tuple<sequence<...>, ...>
|
||||
typename Ys2RHsMajor_, // sequence<...>
|
||||
typename Ys2RHsMinor_> // sequence<...>
|
||||
struct tile_distribution_encoding
|
||||
{
|
||||
using RsLengths = remove_cvref_t<RsLengths_>;
|
||||
using HsLengthss = remove_cvref_t<HsLengthss_>;
|
||||
using Ps2RHssMajor = remove_cvref_t<Ps2RHssMajor_>;
|
||||
using Ps2RHssMinor = remove_cvref_t<Ps2RHssMinor_>;
|
||||
using Ys2RHsMajor = remove_cvref_t<Ys2RHsMajor_>;
|
||||
using Ys2RHsMinor = remove_cvref_t<Ys2RHsMinor_>;
|
||||
|
||||
static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!");
|
||||
static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!");
|
||||
|
||||
static constexpr index_t NDimX = HsLengthss::size();
|
||||
static constexpr index_t NDimP = Ps2RHssMajor::size();
|
||||
static constexpr index_t NDimY = Ys2RHsMajor::size();
|
||||
static constexpr index_t NDimR = RsLengths::size();
|
||||
|
||||
// FIXME: move into detail
|
||||
static constexpr auto rs_lengths_ = RsLengths{};
|
||||
static constexpr auto hs_lengthss_ = HsLengthss{};
|
||||
static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{};
|
||||
static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{};
|
||||
static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{};
|
||||
static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
|
||||
|
||||
// redundant but useful info
|
||||
// TODO: really bad code, should be over-hauled
|
||||
struct detail
|
||||
{
|
||||
// ndim_rh_major_, ndim_span_mainor_
|
||||
static constexpr index_t ndim_rh_major_ = NDimX + 1;
|
||||
static constexpr index_t ndim_span_major_ = NDimX;
|
||||
|
||||
// ndims_rhs_minor_[ndim_rh_major_]
|
||||
static constexpr auto ndims_rhs_minor_ = generate_array(
|
||||
[](auto i) {
|
||||
if constexpr(i.value == 0)
|
||||
{
|
||||
return rs_lengths_.size();
|
||||
}
|
||||
else
|
||||
{
|
||||
return hs_lengthss_[i - number<1>{}].size();
|
||||
}
|
||||
},
|
||||
number<ndim_rh_major_>{});
|
||||
|
||||
// max_ndim_rh_minor_
|
||||
static constexpr index_t max_ndim_rh_minor_ =
|
||||
container_reduce(ndims_rhs_minor_, maximize<index_t>{}, 0);
|
||||
|
||||
// rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]
|
||||
static constexpr auto rhs_lengthss_ =
|
||||
to_array_of_array(container_concat(make_tuple(rs_lengths_), hs_lengthss_));
|
||||
|
||||
// ys_lengths_
|
||||
static constexpr auto ys_lengths_ = [] {
|
||||
array<index_t, NDimY> ys_lengths_tmp{-1};
|
||||
|
||||
for(index_t i = 0; i < NDimY; i++)
|
||||
{
|
||||
index_t rh_major = ys_to_rhs_major_[i];
|
||||
index_t rh_minor = ys_to_rhs_minor_[i];
|
||||
|
||||
ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor];
|
||||
}
|
||||
|
||||
return ys_lengths_tmp;
|
||||
}();
|
||||
|
||||
// rhs_major_minor_to_ys_[ndim_rh_majpr_][max_ndim_rh_minor_]
|
||||
static constexpr auto rhs_major_minor_to_ys_ = [] {
|
||||
array<array<index_t, max_ndim_rh_minor_>, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}};
|
||||
|
||||
static_for<0, NDimY, 1>{}([&](auto i) {
|
||||
constexpr index_t rh_major = ys_to_rhs_major_[i];
|
||||
constexpr index_t rh_minor = ys_to_rhs_minor_[i];
|
||||
|
||||
rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
|
||||
});
|
||||
|
||||
return rhs_major_minor_to_ys_tmp;
|
||||
}();
|
||||
|
||||
// ndims_span_minor_[NDimY]
|
||||
static constexpr auto ndims_span_minor_ = [] {
|
||||
array<index_t, NDimX> ndims_span_minor{0};
|
||||
|
||||
for(index_t i = 0; i < NDimY; i++)
|
||||
{
|
||||
const index_t span_major = ys_to_rhs_major_[i] - 1;
|
||||
|
||||
ndims_span_minor(span_major)++;
|
||||
}
|
||||
|
||||
return ndims_span_minor;
|
||||
}();
|
||||
|
||||
// max_ndim_span_minor_
|
||||
static constexpr index_t max_ndim_span_minor_ =
|
||||
container_reduce(ndims_span_minor_, maximize<index_t>{}, 0);
|
||||
|
||||
// rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_]
|
||||
static constexpr auto rhs_major_minor_to_span_minor_ = [] {
|
||||
array<array<index_t, max_ndim_rh_minor_>, ndim_rh_major_> rhs_major_minor_to_span_minor{
|
||||
{-1}};
|
||||
|
||||
static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) {
|
||||
constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major];
|
||||
|
||||
index_t cnt_ndim_span_minor = 0;
|
||||
|
||||
static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) {
|
||||
constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor];
|
||||
|
||||
if(idim_y >= 0)
|
||||
{
|
||||
rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;
|
||||
|
||||
cnt_ndim_span_minor++;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return rhs_major_minor_to_span_minor;
|
||||
}();
|
||||
|
||||
// ys_to_span_major_[NDimY]
|
||||
static constexpr auto ys_to_span_major_ =
|
||||
generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, number<NDimY>{});
|
||||
|
||||
// ys_to_span_minor_[NDimY]
|
||||
static constexpr auto ys_to_span_minor_ = generate_array(
|
||||
[](auto i) {
|
||||
return rhs_major_minor_to_span_minor_[ys_to_rhs_major_[i]][ys_to_rhs_minor_[i]];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
// distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_]
|
||||
static constexpr auto distributed_spans_lengthss_ = [] {
|
||||
array<array<index_t, max_ndim_span_minor_>, ndim_span_major_>
|
||||
distributed_spans_lengthss{{-1}};
|
||||
|
||||
static_for<0, NDimY, 1>{}([&](auto i) {
|
||||
const index_t rh_major = ys_to_rhs_major_[i];
|
||||
const index_t rh_minor = ys_to_rhs_minor_[i];
|
||||
|
||||
const index_t h_length = hs_lengthss_[number<rh_major - 1>{}][rh_minor];
|
||||
|
||||
const index_t span_major = rh_major - 1;
|
||||
const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor];
|
||||
|
||||
distributed_spans_lengthss(span_major)(span_minor) = h_length;
|
||||
});
|
||||
|
||||
return distributed_spans_lengthss;
|
||||
}();
|
||||
|
||||
// ndims_distributed_spans_minor_[ndim_span_major_]
|
||||
static constexpr auto ndims_distributed_spans_minor_ = [] {
|
||||
array<index_t, ndim_span_major_> ndims_distributed_spans_minor{0};
|
||||
|
||||
static_for<0, NDimY, 1>{}([&](auto i) {
|
||||
const index_t span_major = ys_to_rhs_major_[i] - 1;
|
||||
|
||||
ndims_distributed_spans_minor(span_major)++;
|
||||
});
|
||||
|
||||
return ndims_distributed_spans_minor;
|
||||
}();
|
||||
|
||||
// does_p_own_r_[NDimP][NDimR]
|
||||
static constexpr auto does_p_own_r_ = [] {
|
||||
if constexpr(NDimR > 0)
|
||||
{
|
||||
array<array<bool, NDimR>, NDimP> does_p_own_r{{false}};
|
||||
|
||||
static_for<0, NDimP, 1>{}([&](auto idim_p) {
|
||||
constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
|
||||
|
||||
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
|
||||
constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
|
||||
constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
|
||||
|
||||
if constexpr(rh_major == 0)
|
||||
{
|
||||
does_p_own_r(idim_p)(rh_minor) = true;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return does_p_own_r;
|
||||
}
|
||||
else
|
||||
{
|
||||
return array<array<bool, NDimR>, NDimP>{};
|
||||
}
|
||||
}();
|
||||
|
||||
// ps_over_rs_derivative_[NDimP][NDimR]
|
||||
static constexpr auto ps_over_rs_derivative_ = [] {
|
||||
if constexpr(NDimR > 0)
|
||||
{
|
||||
array<array<index_t, NDimR>, NDimP> ps_over_rs_derivative{{0}};
|
||||
|
||||
static_for<0, NDimP, 1>{}([&](auto idim_p) {
|
||||
constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
|
||||
|
||||
index_t p_over_rh_derivative = 1;
|
||||
|
||||
static_for<ndim_low - 1, -1, -1>{}([&](auto idim_low) {
|
||||
constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
|
||||
constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
|
||||
|
||||
constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor];
|
||||
|
||||
if constexpr(rh_major == 0)
|
||||
{
|
||||
ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
|
||||
}
|
||||
|
||||
p_over_rh_derivative *= rh_length;
|
||||
});
|
||||
});
|
||||
|
||||
return ps_over_rs_derivative;
|
||||
}
|
||||
else
|
||||
{
|
||||
return array<array<index_t, NDimR>, NDimP>{};
|
||||
}
|
||||
}();
|
||||
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
|
||||
{
|
||||
// <len_d0, len_d1, ...>
|
||||
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
|
||||
constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
constexpr index_t size = HsLengthss{}[i].size();
|
||||
return number<size>{};
|
||||
},
|
||||
number<NDimX>{});
|
||||
|
||||
// <0, len_d0, len_d0+len_d1, ...>
|
||||
// e.g. seq<3, 5> --> seq<0, 3, 8>
|
||||
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths);
|
||||
|
||||
return h_dim_prefix_sum;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
|
||||
{
|
||||
constexpr auto all_ys_2_rhss = transform_sequences(
|
||||
[](auto major, auto minor) constexpr {
|
||||
// <0, 0, len_d0, len_d0+len_d1, ...>
|
||||
constexpr auto x_dim_prefix_sum = merge_sequences(
|
||||
sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum());
|
||||
return x_dim_prefix_sum.at(major) + minor;
|
||||
},
|
||||
Ys2RHsMajor{},
|
||||
Ys2RHsMinor{});
|
||||
|
||||
return all_ys_2_rhss;
|
||||
}
|
||||
|
||||
// return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
|
||||
template <typename IdxSeq, typename PrefixSumSeq>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
|
||||
{
|
||||
using sorted_idx = sequence_unique_sort<IdxSeq, less<index_t>, equal<index_t>>;
|
||||
|
||||
constexpr auto sorted_dims = typename sorted_idx::type{};
|
||||
constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{};
|
||||
|
||||
constexpr auto sorted_histogram =
|
||||
histogram_sorted_sequence(sorted_dims, PrefixSumSeq{});
|
||||
constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram);
|
||||
|
||||
return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info()
|
||||
{
|
||||
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tile_distribution_encoding::detail{");
|
||||
//
|
||||
printf("ndim_rh_major_: ");
|
||||
print(ndim_rh_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndim_span_major_: ");
|
||||
print(ndim_span_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndims_rhs_minor_: ");
|
||||
print(ndims_rhs_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndim_rh_major_: ");
|
||||
print(ndim_rh_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("max_ndim_rh_minor_: ");
|
||||
print(max_ndim_rh_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("rhs_lengthss_: ");
|
||||
print(rhs_lengthss_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_lengths_: ");
|
||||
print(ys_lengths_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("rhs_major_minor_to_ys_: ");
|
||||
print(rhs_major_minor_to_ys_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndims_span_minor_: ");
|
||||
print(ndims_span_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("max_ndim_span_minor_: ");
|
||||
print(max_ndim_span_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_span_major_: ");
|
||||
print(ys_to_span_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_span_minor_: ");
|
||||
print(ys_to_span_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("distributed_spans_lengthss_: ");
|
||||
print(distributed_spans_lengthss_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ndims_distributed_spans_minor_: ");
|
||||
print(ndims_distributed_spans_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_over_rs_derivative_: ");
|
||||
print(ps_over_rs_derivative_);
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tile_distribution_encoding{");
|
||||
//
|
||||
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
|
||||
//
|
||||
printf("rs_lengths_: ");
|
||||
print(rs_lengths_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("hs_lengthss_: ");
|
||||
print(hs_lengthss_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_to_rhss_major_: ");
|
||||
print(ps_to_rhss_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ps_to_rhss_minor_: ");
|
||||
print(ps_to_rhss_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_rhs_major_: ");
|
||||
print(ys_to_rhs_major_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("ys_to_rhs_minor_: ");
|
||||
print(ys_to_rhs_minor_);
|
||||
printf(", ");
|
||||
//
|
||||
printf("detail: ");
|
||||
print(detail{});
|
||||
//
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename OuterDstr, typename InnerDstr>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
|
||||
{
|
||||
static_assert(OuterDstr::NDimX == InnerDstr::NDimX, "wrong!");
|
||||
|
||||
constexpr index_t NDimHMajor = OuterDstr::NDimX;
|
||||
|
||||
using RsLengths =
|
||||
sequence_merge_t<typename OuterDstr::RsLengths, typename InnerDstr::RsLengths>;
|
||||
|
||||
constexpr auto hs_lengthss = generate_tuple(
|
||||
[&](auto i) {
|
||||
return merge_sequences(typename OuterDstr::HsLengthss{}[i],
|
||||
typename InnerDstr::HsLengthss{}[i]);
|
||||
},
|
||||
number<NDimHMajor>{});
|
||||
|
||||
//
|
||||
constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
|
||||
array<index_t, NDimHMajor + 1> rhs_major_2_ndim_outer_rhs_minor_;
|
||||
|
||||
// R dimension
|
||||
rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();
|
||||
|
||||
// Hs dimensions
|
||||
static_for<0, NDimHMajor, 1>{}([&](auto i) {
|
||||
rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].size();
|
||||
});
|
||||
|
||||
return rhs_major_2_ndim_outer_rhs_minor_;
|
||||
}();
|
||||
|
||||
// Ps2RHssMinor
|
||||
constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple(
|
||||
[&](auto p) {
|
||||
constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p];
|
||||
constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p];
|
||||
|
||||
constexpr index_t ndim_tmp = inner_p_2_rhss_minor.size();
|
||||
|
||||
constexpr auto updated_inner_p_2_rhss_minor = [&]() {
|
||||
array<index_t, ndim_tmp> updated_inner_p_2_rhss_minor_;
|
||||
|
||||
for(index_t i = 0; i < ndim_tmp; i++)
|
||||
{
|
||||
index_t rh_major = inner_p_2_rhss_major[i];
|
||||
|
||||
index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
|
||||
|
||||
updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
|
||||
}
|
||||
|
||||
return updated_inner_p_2_rhss_minor_;
|
||||
}();
|
||||
|
||||
return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
|
||||
},
|
||||
number<InnerDstr::NDimP>{});
|
||||
|
||||
// Ys2RHsMinor
|
||||
constexpr auto updated_inner_ys_2_rhs_minor = [&]() {
|
||||
constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{};
|
||||
constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{};
|
||||
|
||||
constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.size();
|
||||
|
||||
constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() {
|
||||
array<index_t, ndim_tmp> updated_inner_ys_2_rhs_minor__;
|
||||
|
||||
for(index_t i = 0; i < ndim_tmp; i++)
|
||||
{
|
||||
index_t rh_major = inner_ys_2_rhs_major[i];
|
||||
|
||||
index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
|
||||
|
||||
updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
|
||||
}
|
||||
|
||||
return updated_inner_ys_2_rhs_minor__;
|
||||
}();
|
||||
|
||||
return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto ps_2_rhss_major =
|
||||
container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{});
|
||||
|
||||
constexpr auto ps_2_rhss_minor =
|
||||
container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);
|
||||
|
||||
//
|
||||
constexpr auto ys_2_rhs_major =
|
||||
merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{});
|
||||
|
||||
constexpr auto ys_2_rhs_minor =
|
||||
merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);
|
||||
|
||||
return tile_distribution_encoding<RsLengths,
|
||||
remove_cvref_t<decltype(hs_lengthss)>,
|
||||
remove_cvref_t<decltype(ps_2_rhss_major)>,
|
||||
remove_cvref_t<decltype(ps_2_rhss_minor)>,
|
||||
remove_cvref_t<decltype(ys_2_rhs_major)>,
|
||||
remove_cvref_t<decltype(ys_2_rhs_minor)>>{};
|
||||
}
|
||||
|
||||
template <typename InDstr, index_t... InReduceDimXs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_reduce_tile_distribution_encoding_impl(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
|
||||
{
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
// FIXME: increase if fail
|
||||
constexpr index_t max_ndim_r_out = 20;
|
||||
constexpr index_t max_ndim_y_out = 20;
|
||||
|
||||
//
|
||||
constexpr index_t ndim_p = InDstr::NDimP;
|
||||
constexpr index_t ndim_x_in = InDstr::NDimX;
|
||||
constexpr index_t ndim_y_in = InDstr::NDimY;
|
||||
constexpr index_t ndim_rh_major_in = InDstr::NDimX + 1;
|
||||
constexpr index_t ndim_x_out = ndim_x_in - sizeof...(InReduceDimXs);
|
||||
constexpr index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;
|
||||
|
||||
// ndims_ps_low
|
||||
constexpr auto ndims_ps_low = generate_array(
|
||||
[&](auto i) { return InDstr::ps_to_rhss_major_[i].size(); }, number<ndim_p>{});
|
||||
|
||||
// is_rh_major_in_for_reduce
|
||||
array<bool, ndim_rh_major_in> is_rh_major_in_for_reduce{false};
|
||||
|
||||
for(index_t i = 0; i < reduce_dim_xs_in.size(); i++)
|
||||
{
|
||||
index_t rh_major = reduce_dim_xs_in[i] + 1;
|
||||
|
||||
is_rh_major_in_for_reduce(rh_major) = true;
|
||||
}
|
||||
|
||||
// is_y_in_for_reduce
|
||||
array<bool, ndim_y_in> is_y_in_for_reduce{false};
|
||||
|
||||
for(index_t i = 0; i < ndim_y_in; i++)
|
||||
{
|
||||
index_t rh_major = InDstr::ys_to_rhs_major_[i];
|
||||
|
||||
if(is_rh_major_in_for_reduce[rh_major])
|
||||
{
|
||||
is_y_in_for_reduce(i) = true;
|
||||
}
|
||||
}
|
||||
|
||||
// is_rh_minor_in_for_y_reduce
|
||||
array<array<bool, max_ndim_rh_minor_in>, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{{false}};
|
||||
|
||||
static_for<0, ndim_y_in, 1>{}([&](auto i) {
|
||||
index_t rh_major = InDstr::ys_to_rhs_major_[i];
|
||||
index_t rh_minor = InDstr::ys_to_rhs_minor_[i];
|
||||
|
||||
if(is_y_in_for_reduce[i])
|
||||
{
|
||||
is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true;
|
||||
}
|
||||
});
|
||||
|
||||
// in2out_rh_major
|
||||
array<index_t, ndim_rh_major_in> in2out_rh_major{-1};
|
||||
index_t cnt_ndim_rh_major_out = 0;
|
||||
|
||||
for(index_t i = 0; i < ndim_rh_major_in; i++)
|
||||
{
|
||||
if(is_rh_major_in_for_reduce[i])
|
||||
{
|
||||
in2out_rh_major(i) = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
in2out_rh_major(i) = cnt_ndim_rh_major_out;
|
||||
|
||||
cnt_ndim_rh_major_out++;
|
||||
}
|
||||
}
|
||||
|
||||
// rs_lengths_out, in2out_rh_minor
|
||||
array<index_t, max_ndim_r_out> rs_lengths_out{-1};
|
||||
array<array<index_t, max_ndim_rh_minor_in>, ndim_rh_major_in> in2out_rh_minor{{-1}};
|
||||
|
||||
// loop over input R dim
|
||||
for(index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
|
||||
{
|
||||
// rs_lengths_out
|
||||
rs_lengths_out(i) = InDstr::rs_lengths_[i];
|
||||
|
||||
// in2out_rh_minor
|
||||
in2out_rh_minor(0)(i) = i;
|
||||
}
|
||||
|
||||
// loop over input H Dim
|
||||
index_t cnt_ndim_r_out = InDstr::rs_lengths_.size();
|
||||
|
||||
static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) {
|
||||
constexpr auto h_major_in = rh_major_in - I1;
|
||||
|
||||
constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();
|
||||
|
||||
if(is_rh_major_in_for_reduce[rh_major_in])
|
||||
{
|
||||
for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
|
||||
{
|
||||
if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
|
||||
{
|
||||
// rs_lengths_out
|
||||
rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];
|
||||
|
||||
// in2out_rh_minor
|
||||
in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;
|
||||
|
||||
cnt_ndim_r_out++;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
|
||||
{
|
||||
// in2out_rh_minor
|
||||
in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// ndim_r_out
|
||||
const index_t ndim_r_out = cnt_ndim_r_out;
|
||||
|
||||
// ndims_hs_minor_out, hs_lengthss_out
|
||||
array<index_t, ndim_x_out> ndims_hs_minor_out{-1};
|
||||
array<array<index_t, max_ndim_rh_minor_in>, ndim_x_out> hs_lengthss_out{{-1}};
|
||||
|
||||
index_t cnt_ndim_x_out = 0;
|
||||
|
||||
static_for<0, ndim_x_in, 1>{}([&](auto i) {
|
||||
if(not is_rh_major_in_for_reduce[i + I1])
|
||||
{
|
||||
// ndims_hs_minor_out
|
||||
ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();
|
||||
|
||||
// hs_lengthss_out
|
||||
static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
|
||||
[&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });
|
||||
|
||||
cnt_ndim_x_out++;
|
||||
}
|
||||
});
|
||||
|
||||
// ps_to_rhss_major_out, ps_to_rhss_minor_out
|
||||
array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_major_out{{-1}};
|
||||
array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_minor_out{{-1}};
|
||||
|
||||
static_for<0, ndim_p, 1>{}([&](auto idim_p) {
|
||||
static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](auto idim_low) {
|
||||
index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
|
||||
index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];
|
||||
|
||||
ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
|
||||
ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
|
||||
});
|
||||
});
|
||||
|
||||
// ys_to_rhs_major_out, ys_to_rhs_minor_out
|
||||
array<index_t, max_ndim_y_out> ys_to_rhs_major_out{-1};
|
||||
array<index_t, max_ndim_y_out> ys_to_rhs_minor_out{-1};
|
||||
|
||||
index_t cnt_ndim_y_out = 0;
|
||||
|
||||
static_for<0, ndim_y_in, 1>{}([&](auto i) {
|
||||
if(not is_y_in_for_reduce[i])
|
||||
{
|
||||
index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
|
||||
index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];
|
||||
|
||||
ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
|
||||
ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];
|
||||
|
||||
cnt_ndim_y_out++;
|
||||
}
|
||||
});
|
||||
|
||||
// ndim_y_out
|
||||
const index_t ndim_y_out = cnt_ndim_y_out;
|
||||
|
||||
//
|
||||
return make_tuple(ndim_x_out,
|
||||
ndim_p,
|
||||
ndim_y_out,
|
||||
ndim_r_out,
|
||||
ndims_hs_minor_out,
|
||||
ndims_ps_low,
|
||||
rs_lengths_out,
|
||||
hs_lengthss_out,
|
||||
ps_to_rhss_major_out,
|
||||
ps_to_rhss_minor_out,
|
||||
ys_to_rhs_major_out,
|
||||
ys_to_rhs_minor_out);
|
||||
}
|
||||
|
||||
template <typename InDstr, index_t... InReduceDimXs>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
|
||||
{
|
||||
constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in);
|
||||
|
||||
constexpr index_t ndim_x = impl.template at<0>();
|
||||
constexpr index_t ndim_p = impl.template at<1>();
|
||||
constexpr index_t ndim_y = impl.template at<2>();
|
||||
constexpr index_t ndim_r = impl.template at<3>();
|
||||
constexpr auto ndims_hs_minor = impl.template at<4>();
|
||||
constexpr auto ndims_ps_low = impl.template at<5>();
|
||||
constexpr auto rs_lengths_impl = impl.template at<6>();
|
||||
constexpr auto hs_lengthss_impl = impl.template at<7>();
|
||||
constexpr auto ps_to_rhss_major_impl = impl.template at<8>();
|
||||
constexpr auto ps_to_rhss_minor_impl = impl.template at<9>();
|
||||
constexpr auto ys_to_rhs_major_impl = impl.template at<10>();
|
||||
constexpr auto ys_to_rhs_minor_impl = impl.template at<11>();
|
||||
|
||||
constexpr auto rs_lengths = TO_SEQUENCE(rs_lengths_impl, ndim_r);
|
||||
constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor);
|
||||
constexpr auto ps_to_rhss_major =
|
||||
TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low);
|
||||
constexpr auto ps_to_rhss_minor =
|
||||
TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low);
|
||||
constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
|
||||
constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);
|
||||
|
||||
return tile_distribution_encoding<remove_cvref_t<decltype(rs_lengths)>,
|
||||
remove_cvref_t<decltype(hs_lengthss)>,
|
||||
remove_cvref_t<decltype(ps_to_rhss_major)>,
|
||||
remove_cvref_t<decltype(ps_to_rhss_minor)>,
|
||||
remove_cvref_t<decltype(ys_to_rhs_major)>,
|
||||
remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace ck_tile
|
||||
342
include/ck_tile/core/tensor/tile_elementwise.hpp
Normal file
342
include/ck_tile/core/tensor/tile_elementwise.hpp
Normal file
@@ -0,0 +1,342 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/null_tensor.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: support tensors with different distribution
|
||||
template <typename InOutElementFunc,
|
||||
typename... InOutDstrTensors,
|
||||
typename = std::enable_if_t<std::conjunction_v<
|
||||
std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, null_tensor>>...>>>
|
||||
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
|
||||
InOutDstrTensors&... inout_dstr_tensors)
|
||||
{
|
||||
// TODO: make sure all distributed tensors have same lengths and distribution
|
||||
// static_assert(xxx);
|
||||
|
||||
constexpr index_t thread_buffer_size =
|
||||
__type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();
|
||||
|
||||
static_for<0, thread_buffer_size, 1>{}(
|
||||
[&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); });
|
||||
}
|
||||
|
||||
template <typename InElementFunc,
|
||||
typename... InTensor,
|
||||
typename = std::enable_if_t<
|
||||
std::conjunction_v<std::negation<std::is_same<InTensor, null_tensor>>...>>>
|
||||
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
|
||||
const InTensor&... in_dstr_tensors)
|
||||
{
|
||||
using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));
|
||||
|
||||
// TODO: make sure all distributed tensors have same lengths and distribution
|
||||
// static_assert(xxx);
|
||||
constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution();
|
||||
|
||||
constexpr index_t thread_buffer_size =
|
||||
__type_pack_element<0, InTensor...>::get_thread_buffer_size();
|
||||
|
||||
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
|
||||
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
out_dstr_tensor.get_thread_buffer()(i) =
|
||||
in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
|
||||
});
|
||||
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
|
||||
template <typename DstrTensors, typename T>
|
||||
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&value](auto& x) {
|
||||
x = type_convert<typename DstrTensors::DataType, remove_cvref_t<T>>(value);
|
||||
},
|
||||
dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
|
||||
{
|
||||
}
|
||||
|
||||
// TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
|
||||
// sub-dword tensor...
|
||||
template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
|
||||
CK_TILE_DEVICE void
|
||||
set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
|
||||
{
|
||||
using elem_type = typename DstrTensors::DataType;
|
||||
constexpr index_t elem_size = sizeof(elem_type);
|
||||
|
||||
constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;
|
||||
|
||||
// # bytes per write = 4
|
||||
if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
|
||||
{
|
||||
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
|
||||
auto& buffer = dstr_tensor.get_thread_buffer();
|
||||
|
||||
static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
|
||||
if constexpr(elem_size == 1)
|
||||
{
|
||||
// # elements per write = 4
|
||||
constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};
|
||||
|
||||
buffer[i_write * 4 + 0] = values.x;
|
||||
buffer[i_write * 4 + 1] = values.y;
|
||||
buffer[i_write * 4 + 2] = values.z;
|
||||
buffer[i_write * 4 + 3] = values.w;
|
||||
}
|
||||
else if constexpr(elem_size == 2)
|
||||
{
|
||||
// # elements per write = 2
|
||||
constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};
|
||||
|
||||
buffer[i_write * 2 + 0] = values.x;
|
||||
buffer[i_write * 2 + 1] = values.y;
|
||||
}
|
||||
else if constexpr(elem_size == 4)
|
||||
{
|
||||
// # elements per write = 1
|
||||
constexpr elem_type value = 0;
|
||||
|
||||
buffer[i_write] = value;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "type not supported");
|
||||
}
|
||||
});
|
||||
#else
|
||||
using dvec_t = array<index_t, tensor_bytes / 4>;
|
||||
auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
|
||||
for(auto i = 0; i < tensor.size(); i++)
|
||||
tensor.get(i) = v;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
|
||||
dstr_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t v>
|
||||
CK_TILE_DEVICE void set_tile(null_tensor&, number<v>)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename DstrTensors>
|
||||
CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
|
||||
{
|
||||
set_tile(dstr_tensor, 0);
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
// TODO: this is ugly
|
||||
template <typename OutDataType, typename InTensor>
|
||||
CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
// This API is designed to use the _pk_ serious of function
|
||||
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
|
||||
|
||||
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
|
||||
static_assert(thread_buffer_size % 4 == 0);
|
||||
constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4;
|
||||
|
||||
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wuninitialized"
|
||||
// __builtin_amdgcn_cvt_pk_fp8_f32() this builtin require the old value, and
|
||||
// will generate a v_mov_b32 vxxx [old] before cvt, which result in unwanted ISA
|
||||
// so we prepare an uninitialized variable purposely, and turn off the warning
|
||||
int dummy_old;
|
||||
static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) {
|
||||
uint32_t x = __builtin_amdgcn_cvt_pk_fp8_f32(
|
||||
in_dstr_tensors.get_thread_buffer()[number<4 * i + 0>{}],
|
||||
in_dstr_tensors.get_thread_buffer()[number<4 * i + 1>{}],
|
||||
dummy_old,
|
||||
false); // false -> WORD0
|
||||
|
||||
uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32(
|
||||
in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}],
|
||||
in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}],
|
||||
dummy_old,
|
||||
false); // false -> WORD0
|
||||
|
||||
constexpr int32_t m0 = 0x05040100;
|
||||
using vec_t = array<OutDataType, 4>;
|
||||
|
||||
vec_t d = bit_cast<vec_t>(__builtin_amdgcn_perm(y, x, m0));
|
||||
out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
|
||||
});
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
return out_dstr_tensor;
|
||||
#else
|
||||
// fallback
|
||||
return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
|
||||
in_dstr_tensors);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename OutDataType, typename InTensor>
|
||||
CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
|
||||
// This API is designed to use the _pk_ serious of function
|
||||
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
|
||||
|
||||
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
|
||||
static_assert(thread_buffer_size % 2 == 0);
|
||||
constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;
|
||||
|
||||
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
|
||||
|
||||
// TODO: this is rtz cvt, need be very careful
|
||||
for(index_t i = 0; i < thread_buffer_size_pk; i++)
|
||||
{
|
||||
auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
|
||||
in_dstr_tensors.get_thread_buffer()[2 * i + 1]);
|
||||
|
||||
out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
|
||||
out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
|
||||
}
|
||||
|
||||
return out_dstr_tensor;
|
||||
#else
|
||||
// fallback
|
||||
return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
|
||||
in_dstr_tensors);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if CK_TILE_USE_SUBDWORD_TILE_CAST
|
||||
// this function assumes that the src or dst data type (or both) is smaller than 1 dword
|
||||
// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
|
||||
template <typename OutDataType, typename InTensor>
|
||||
CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
|
||||
{
|
||||
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
|
||||
|
||||
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
|
||||
|
||||
using i_type = remove_cvref_t<typename InTensor::DataType>;
|
||||
using o_type = remove_cvref_t<OutDataType>;
|
||||
constexpr index_t i_elem_bytes = sizeof(i_type);
|
||||
constexpr index_t o_elem_bytes = sizeof(o_type);
|
||||
static_assert(i_elem_bytes < 4 || o_elem_bytes < 4);
|
||||
|
||||
constexpr index_t bulk_size =
|
||||
(i_elem_bytes >= o_elem_bytes) ? (4 / o_elem_bytes) : (4 / i_elem_bytes);
|
||||
static_assert(bulk_size != 0);
|
||||
|
||||
using o_bulk_type =
|
||||
std::conditional_t<i_elem_bytes >= o_elem_bytes, float, array<o_type, bulk_size>>;
|
||||
|
||||
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
|
||||
|
||||
constexpr index_t iters = thread_buffer_size / bulk_size;
|
||||
constexpr index_t rems = thread_buffer_size % bulk_size;
|
||||
|
||||
// cast the sequence per-bulk
|
||||
static_for<0, iters, 1>{}([&](auto i) {
|
||||
union bulk_wrapper
|
||||
{
|
||||
o_bulk_type bulk{};
|
||||
o_type data[bulk_size];
|
||||
} o_bulk;
|
||||
|
||||
// TODO: should use below function, but somehow will result in spill (same as c-forloop)
|
||||
static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib) {
|
||||
o_bulk.data[ib.value] = static_cast<o_type>(
|
||||
in_dstr_tensors.get_thread_buffer()
|
||||
.template get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
|
||||
});
|
||||
|
||||
// TODO: fixme, should use above!
|
||||
// static_assert(sizeof(i_type) / sizeof(o_type) == 2);
|
||||
// o_bulk.data[0] = static_cast<o_type>(
|
||||
// in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
|
||||
// o_bulk.data[1] = static_cast<o_type>(
|
||||
// in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);
|
||||
|
||||
out_dstr_tensor.get_thread_buffer().template set_as<o_bulk_type>(i, o_bulk.bulk);
|
||||
});
|
||||
|
||||
static_for<0, rems, 1>{}([&](auto r) {
|
||||
// TODO: introducing local scratch pad?
|
||||
auto idx = number<iters * bulk_size + r>{};
|
||||
out_dstr_tensor.get_thread_buffer().at(idx) =
|
||||
static_cast<o_type>(in_dstr_tensors.get_thread_buffer().at(idx));
|
||||
});
|
||||
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
#endif
|
||||
} // namespace impl
|
||||
|
||||
template <typename DstType, typename SrcTensor>
|
||||
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
|
||||
{
|
||||
if constexpr((std::is_same_v<DstType, fp8_t> ||
|
||||
std::is_same_v<DstType, bf8_t>)&&std::is_same_v<typename SrcTensor::DataType,
|
||||
float> &&
|
||||
(SrcTensor::get_thread_buffer_size() % 4 == 0))
|
||||
{
|
||||
return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
|
||||
}
|
||||
#if CK_TILE_USE_PK_FP16_TILE_CAST
|
||||
else if constexpr(std::is_same_v<DstType, fp16_t> &&
|
||||
std::is_same_v<typename SrcTensor::DataType, float> &&
|
||||
(SrcTensor::get_thread_buffer_size() % 2 == 0))
|
||||
{
|
||||
return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
|
||||
}
|
||||
#endif
|
||||
#if CK_TILE_USE_SUBDWORD_TILE_CAST
|
||||
else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
|
||||
{
|
||||
return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
|
||||
}
|
||||
#endif
|
||||
else
|
||||
return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>, src_tensor);
|
||||
}
|
||||
|
||||
// no-op function for null_tensor arguments
|
||||
template <typename InOutElementFunc,
|
||||
typename... MaybeNullTensor,
|
||||
typename = std::enable_if_t<
|
||||
std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
|
||||
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
|
||||
{
|
||||
}
|
||||
|
||||
// no-op function for null_tensor arguments
|
||||
template <typename InElementFunc,
|
||||
typename... MaybeNullTensor,
|
||||
typename = std::enable_if_t<
|
||||
std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
|
||||
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
|
||||
{
|
||||
return null_tensor{};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
1167
include/ck_tile/core/tensor/tile_window.hpp
Normal file
1167
include/ck_tile/core/tensor/tile_window.hpp
Normal file
File diff suppressed because it is too large
Load Diff
1218
include/ck_tile/core/tensor/tile_window_linear.hpp
Normal file
1218
include/ck_tile/core/tensor/tile_window_linear.hpp
Normal file
File diff suppressed because it is too large
Load Diff
54
include/ck_tile/core/tensor/tile_window_utils.hpp
Normal file
54
include/ck_tile/core/tensor/tile_window_utils.hpp
Normal file
@@ -0,0 +1,54 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
#pragma once
|
||||
namespace ck_tile {
|
||||
|
||||
// input a lds store tile, extract some information from it
|
||||
// used to set the m0 value for the gfx9 series
|
||||
template <typename LdsTileWindow_>
|
||||
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_&& lds_tile)
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
const index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
|
||||
sizeof(LdsDataType);
|
||||
|
||||
const index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
|
||||
return make_tuple(m0_init_value, size_per_issue);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
232
include/ck_tile/core/tensor/transpose_tile.hpp
Normal file
232
include/ck_tile/core/tensor/transpose_tile.hpp
Normal file
@@ -0,0 +1,232 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace detail {
|
||||
|
||||
template <typename OutTensor, typename InTensor>
|
||||
CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
|
||||
const InTensor& in_tensor)
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
static_assert(std::is_same_v<typename InTensor::DataType, typename OutTensor::DataType>,
|
||||
"Data type for InTensor and OutTensor must be the same!");
|
||||
|
||||
using DataType = typename InTensor::DataType;
|
||||
|
||||
constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
|
||||
constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
|
||||
|
||||
// y_dim_out_to_in
|
||||
// For swapped Hs tile case I need only get_rh_minor_to_y
|
||||
// since rh_major are already swapped due to swapped Hs.
|
||||
constexpr auto get_rh_minor_to_y = [](auto dstr_tensor) {
|
||||
using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
|
||||
|
||||
map<index_t, index_t> rh_minor_to_y_;
|
||||
|
||||
static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
|
||||
constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
|
||||
|
||||
rh_minor_to_y_(rh_minor) = i;
|
||||
});
|
||||
|
||||
return rh_minor_to_y_;
|
||||
};
|
||||
|
||||
// In swapped Hs case <Y,X> -> <X,Y> tile
|
||||
// we have same rh_major, but reversed rh_minor!
|
||||
constexpr auto rh_minor_to_y_in = get_rh_minor_to_y(InTensor{});
|
||||
constexpr auto rh_minor_to_y_out = get_rh_minor_to_y(OutTensor{});
|
||||
|
||||
// Is this really needed?? Should we have simple reverse here??
|
||||
constexpr auto y_dim_out_to_in = [&] {
|
||||
map<index_t, index_t> y_dim_out_to_in_;
|
||||
|
||||
for(const auto& [rh_minor, y_out] : rh_minor_to_y_out)
|
||||
{
|
||||
y_dim_out_to_in_(y_out) = rh_minor_to_y_in[rh_minor];
|
||||
}
|
||||
|
||||
return y_dim_out_to_in_;
|
||||
}();
|
||||
|
||||
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
|
||||
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
|
||||
|
||||
// input and output vector dim in the order of input Y dims
|
||||
constexpr index_t y_dim_vec_in = NDimY - 1;
|
||||
constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
|
||||
|
||||
// vector lengths
|
||||
constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
|
||||
constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
|
||||
|
||||
// # of vectors
|
||||
constexpr index_t num_vec_in = vec_length_out;
|
||||
constexpr index_t num_vec_out = vec_length_in;
|
||||
|
||||
// SFC
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
|
||||
|
||||
using SFC_Y = space_filling_curve<decltype(y_lengths),
|
||||
typename arithmetic_sequence_gen<0, NDimY, 1>::type,
|
||||
decltype(scalars_per_access)>;
|
||||
|
||||
constexpr index_t num_access = SFC_Y::get_num_of_access();
|
||||
|
||||
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
|
||||
|
||||
if constexpr(num_vec_in == 1 || num_vec_out == 1)
|
||||
{
|
||||
// loop over SFC
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
// data index [y0, y1, ...] in the order of input tensor
|
||||
constexpr auto idx_y = SFC_Y::get_index(iAccess);
|
||||
|
||||
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y);
|
||||
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y);
|
||||
|
||||
if constexpr(vec_length_in == 1)
|
||||
{
|
||||
out_tensor.get_thread_buffer()[number<out_offset>{}] =
|
||||
in_tensor.get_thread_buffer()[number<in_offset>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
using Vec = array<DataType, vec_length_in>;
|
||||
out_tensor.get_thread_buffer().template get_as<Vec>(
|
||||
number<out_offset / vec_length_in>{}) =
|
||||
in_tensor.get_thread_buffer().template get_as<Vec>(
|
||||
number<in_offset / vec_length_in>{});
|
||||
}
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
using InVec = array<DataType, vec_length_in>;
|
||||
using OutVec = array<DataType, vec_length_out>;
|
||||
|
||||
// in/out vectors to be transposed
|
||||
thread_buffer<InVec, num_vec_in> in_vectors;
|
||||
thread_buffer<OutVec, num_vec_out> out_vectors;
|
||||
|
||||
// loop over SFC and do transpose
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
// data index [y0, y1, ...] in the order of input tensor
|
||||
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
|
||||
|
||||
// get input vectors
|
||||
static_for<0, num_vec_in, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_in = generate_tuple(
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
|
||||
static_assert(in_offset % vec_length_in == 0);
|
||||
|
||||
in_vectors(i).template get_as<InVec>()(I0) =
|
||||
in_tensor.get_thread_buffer()
|
||||
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
|
||||
});
|
||||
|
||||
// transpose
|
||||
transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
|
||||
|
||||
// set output vectors
|
||||
static_for<0, num_vec_out, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_out_tmp = generate_array(
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_y_out =
|
||||
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
|
||||
|
||||
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
|
||||
static_assert(out_offset % vec_length_out == 0);
|
||||
|
||||
out_tensor.get_thread_buffer().template set_as<OutVec>(
|
||||
number<out_offset / vec_length_out>{},
|
||||
out_vectors[i].template get_as<OutVec>()[I0]);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename OutTensor, typename InTensor>
|
||||
CK_TILE_DEVICE void transpose_tile2d(OutTensor& out, const InTensor& in)
|
||||
{
|
||||
using InDataType = typename InTensor::DataType;
|
||||
using OutDataType = typename OutTensor::DataType;
|
||||
|
||||
using InTileDistr = typename InTensor::StaticTileDistribution;
|
||||
using OutTileDistr = typename OutTensor::StaticTileDistribution;
|
||||
|
||||
using InDstrEncode = typename InTileDistr::DstrEncode;
|
||||
using OutDstrEncode = typename OutTileDistr::DstrEncode;
|
||||
|
||||
using InThreadTensorDesc = typename InTensor::ThreadTensorDesc;
|
||||
using OutThreadTensorDesc = typename OutTensor::ThreadTensorDesc;
|
||||
|
||||
// Ys:
|
||||
constexpr auto in_thread_desc_lengths = InThreadTensorDesc{}.get_lengths();
|
||||
constexpr auto out_thread_desc_lengths = OutThreadTensorDesc{}.get_lengths();
|
||||
|
||||
// type convert
|
||||
const auto in_tmp = [&]() {
|
||||
if constexpr(std::is_same_v<OutDataType, InDataType>)
|
||||
{
|
||||
return in;
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_elementwise_in(type_convert<OutDataType, InDataType>, in);
|
||||
}
|
||||
}();
|
||||
|
||||
// Scenario where we switch from tile <Y, X> -> <X, Y> - only 2D tiles!
|
||||
// we preserve Ps but swap Ys: <Y1, Y0> -> <Y0, Y1>
|
||||
if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
|
||||
InDstrEncode::hs_lengthss_ == tuple_reverse(OutDstrEncode::hs_lengthss_) &&
|
||||
InDstrEncode::NDimY == OutDstrEncode::NDimY && InDstrEncode::NDimY == 2 &&
|
||||
in_thread_desc_lengths == tuple_reverse(out_thread_desc_lengths))
|
||||
// Any condition on Ps ??
|
||||
// InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
|
||||
// InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
|
||||
{
|
||||
detail::transpose_tile2d_impl_in_thread(out, in_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Provided tensors could not be transposed!");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
105
include/ck_tile/core/tensor/update_tile.hpp
Normal file
105
include/ck_tile/core/tensor/update_tile.hpp
Normal file
@@ -0,0 +1,105 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
tile_dstr);
|
||||
|
||||
tile_window.update(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename DataType_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void
|
||||
update_tile(tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
tile_window.update(dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename DataType_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void
|
||||
update_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.update_raw(dstr_tensor,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
typename DataType_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto update_tile_raw(
|
||||
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
|
||||
tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.update_raw(dstr_tensor,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
19
include/ck_tile/core/utility/bit_cast.hpp
Normal file
19
include/ck_tile/core/utility/bit_cast.hpp
Normal file
@@ -0,0 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X& x)
|
||||
{
|
||||
static_assert(__has_builtin(__builtin_bit_cast), "");
|
||||
static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
|
||||
|
||||
return __builtin_bit_cast(Y, x);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
208
include/ck_tile/core/utility/env.hpp
Normal file
208
include/ck_tile/core/utility/env.hpp
Normal file
@@ -0,0 +1,208 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <iterator>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// Writes one error line to `std::cerr`: the prefix "[ERROR] ", then every argument
/// streamed in order with `operator<<`, then `std::endl`.
///
/// The text is assembled in a local string stream first and sent to `std::cerr`
/// in a single insertion.
template <typename... Args>
void CK_TILE_ERROR(Args&&... args) noexcept
{
    std::ostringstream message;
    message << "[ERROR] ";
    (message << ... << args);
    std::cerr << message.str() << std::endl;
}
|
||||
|
||||
namespace internal {
|
||||
|
||||
/// Returns true when `str` is exactly equal to one of the C strings in `names`.
template <size_t N>
bool is_any_of(const char* const (&names)[N], const std::string& str)
{
    return std::any_of(std::begin(names), std::end(names), [&](const char* candidate) {
        return str == candidate;
    });
}

/// Converts the raw text of an environment variable into a value of type `T`.
/// The primary template is intentionally empty; only the specializations below
/// provide `parse_env_var_value`.
template <typename T>
struct ParseEnvVal
{
};

/// Boolean parser: accepts the spellings listed in `enabled_names` /
/// `disabled_names`, compared case-insensitively.
template <>
struct ParseEnvVal<bool>
{
    /// \param vp NUL-terminated text of the variable (must not be null)
    /// \return true for an "enabled" spelling, false for a "disabled" one
    /// \throws std::runtime_error when the text matches neither list
    static bool parse_env_var_value(const char* vp)
    {
        std::string value_env_str{vp};

        // <cctype> functions require a value representable as unsigned char;
        // passing a negative plain char is undefined behaviour, so convert first.
        std::transform(value_env_str.begin(),
                       value_env_str.end(),
                       value_env_str.begin(),
                       [](unsigned char c) { return static_cast<char>(std::tolower(c)); });

        if(is_any_of(enabled_names, value_env_str))
        {
            return true;
        }
        if(is_any_of(disabled_names, value_env_str))
        {
            return false;
        }
        throw std::runtime_error("Invalid value for env variable");
    }

    private:
    static constexpr const char* enabled_names[] = {"enable", "enabled", "1", "yes", "on", "true"};
    static constexpr const char* disabled_names[] = {
        "disable", "disabled", "0", "no", "off", "false"};
};

// Supports hexadecimals (with leading "0x"), octals (if prefix is "0") and decimals (default).
// Returns 0 if environment variable is in wrong format (strtoull fails to parse the string).
template <>
struct ParseEnvVal<uint64_t>
{
    /// \param vp NUL-terminated text of the variable
    static uint64_t parse_env_var_value(const char* vp) { return std::strtoull(vp, nullptr, 0); }
};

/// String parser: returns the text unchanged as a `std::string`.
template <>
struct ParseEnvVal<std::string>
{
    /// \param vp NUL-terminated text of the variable
    static std::string parse_env_var_value(const char* vp) { return std::string{vp}; }
};
|
||||
|
||||
template <typename T>
|
||||
struct EnvVar
|
||||
{
|
||||
private:
|
||||
T value{};
|
||||
bool is_unset = true;
|
||||
|
||||
public:
|
||||
const T& GetValue() const { return value; }
|
||||
|
||||
bool IsUnset() const { return is_unset; }
|
||||
|
||||
void Unset() { is_unset = true; }
|
||||
|
||||
void UpdateValue(const T& val)
|
||||
{
|
||||
is_unset = false;
|
||||
value = val;
|
||||
}
|
||||
|
||||
explicit EnvVar(const char* const name, const T& def_val)
|
||||
{
|
||||
// NOLINTNEXTLINE (concurrency-mt-unsafe)
|
||||
const char* vp = std::getenv(name);
|
||||
if(vp != nullptr) // a value was provided
|
||||
{
|
||||
is_unset = false;
|
||||
value = ParseEnvVal<T>::parse_env_var_value(vp);
|
||||
}
|
||||
else // no value provided, use default value
|
||||
{
|
||||
value = def_val;
|
||||
}
|
||||
}
|
||||
};
|
||||
} // end namespace internal
|
||||
|
||||
// Declares the tag type `ck_tile::env::<name>` for the environment variable `<name>`.
// The tag exposes `value_type` and `Ref()`, which returns the cached EnvVar object.
// The cache is a function-local static, which hides the variable and provides
// thread-safe initialization.
// Must be used in the global namespace (enforced by the static_assert).
#define CK_TILE_DECLARE_ENV_VAR(name, type, default_val)                            \
    namespace ck_tile::env {                                                        \
    struct name                                                                     \
    {                                                                               \
        static_assert(std::is_same_v<name, ::ck_tile::env::name>,                   \
                      "CK_TILE_DECLARE_ENV* must be used in the global namespace"); \
        using value_type = type;                                                    \
        static ck_tile::internal::EnvVar<type>& Ref()                               \
        {                                                                           \
            static ck_tile::internal::EnvVar<type> instance{#name, default_val};    \
            return instance;                                                        \
        }                                                                           \
    };                                                                              \
    }

// Convenience declarations: bool defaults to false, uint64_t to 0, string to "".
#define CK_TILE_DECLARE_ENV_VAR_BOOL(name) CK_TILE_DECLARE_ENV_VAR(name, bool, false)

#define CK_TILE_DECLARE_ENV_VAR_UINT64(name) CK_TILE_DECLARE_ENV_VAR(name, uint64_t, 0)

#define CK_TILE_DECLARE_ENV_VAR_STR(name) CK_TILE_DECLARE_ENV_VAR(name, std::string, "")

// Creates an instance of the tag type declared for `name`, to be passed to the Env* helpers.
#define CK_TILE_ENV(name) \
    ck_tile::env::name {}
|
||||
|
||||
/// Returns the cached value of a string environment variable.
/// Only compiles for tags whose `value_type` is `std::string`.
template <class EnvVar>
inline const std::string& EnvGetString(EnvVar)
{
    using stored_type = typename EnvVar::value_type;
    static_assert(std::is_same_v<stored_type, std::string>);

    auto&& slot = EnvVar::Ref();
    return slot.GetValue();
}
|
||||
|
||||
/// Returns true only when the boolean variable is set and its value is true.
/// Only compiles for tags whose `value_type` is `bool`.
template <class EnvVar>
inline bool EnvIsEnabled(EnvVar)
{
    using stored_type = typename EnvVar::value_type;
    static_assert(std::is_same_v<stored_type, bool>);

    auto&& slot = EnvVar::Ref();
    // an unset variable is never "enabled"; the value is not read in that case
    return !(slot.IsUnset() || !slot.GetValue());
}
|
||||
|
||||
/// Returns true only when the boolean variable is set and its value is false.
/// Only compiles for tags whose `value_type` is `bool`.
template <class EnvVar>
inline bool EnvIsDisabled(EnvVar)
{
    using stored_type = typename EnvVar::value_type;
    static_assert(std::is_same_v<stored_type, bool>);

    auto&& slot = EnvVar::Ref();
    // an unset variable is never "disabled"; the value is not read in that case
    return !(slot.IsUnset() || slot.GetValue());
}
|
||||
|
||||
/// Returns the cached value of an unsigned 64-bit environment variable.
/// Only compiles for tags whose `value_type` is `uint64_t`.
template <class EnvVar>
inline uint64_t EnvValue(EnvVar)
{
    using stored_type = typename EnvVar::value_type;
    static_assert(std::is_same_v<stored_type, uint64_t>);

    auto&& slot = EnvVar::Ref();
    return slot.GetValue();
}
|
||||
|
||||
/// Returns the result of `IsUnset()` on the cached object of the given tag.
template <class EnvVar>
inline bool EnvIsUnset(EnvVar)
{
    auto&& slot = EnvVar::Ref();
    return slot.IsUnset();
}
|
||||
|
||||
/// Calls `Unset()` on the cached object of the given tag.
template <class EnvVar>
void EnvUnset(EnvVar)
{
    auto&& slot = EnvVar::Ref();
    slot.Unset();
}
|
||||
|
||||
/// Updates the cached value of an environment variable.
/// `ValueType` must be exactly the tag's `value_type`; the value is passed to
/// `UpdateValue` on the cached object.
template <typename EnvVar, typename ValueType>
void UpdateEnvVar(EnvVar, const ValueType& val)
{
    using stored_type = typename EnvVar::value_type;
    static_assert(std::is_same_v<stored_type, ValueType>);

    auto&& slot = EnvVar::Ref();
    slot.UpdateValue(val);
}
|
||||
|
||||
template <typename EnvVar>
|
||||
void UpdateEnvVar(EnvVar, const std::string_view& val)
|
||||
{
|
||||
EnvVar::Ref().UpdateValue(
|
||||
ck_tile::internal::ParseEnvVal<typename EnvVar::value_type>::parse_env_var_value(
|
||||
val.data()));
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
// environment variable to enable logging:
|
||||
// export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED
|
||||
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
|
||||
232
include/ck_tile/core/utility/functional.hpp
Normal file
232
include/ck_tile/core/utility/functional.hpp
Normal file
@@ -0,0 +1,232 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include <stdint.h>
|
||||
#include <utility>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
struct swallow
|
||||
{
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template <class>
|
||||
struct static_for_impl;
|
||||
|
||||
template <index_t... Is>
|
||||
struct static_for_impl<sequence<Is...>>
|
||||
{
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
swallow{(f(number<Is>{}), 0)...};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// F signature: F(number<Iter>)
|
||||
template <index_t NBegin, index_t NEnd, index_t Increment>
|
||||
struct static_for
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_for()
|
||||
{
|
||||
static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
|
||||
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
|
||||
static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
|
||||
"wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
|
||||
"NBegin >= NEnd)");
|
||||
}
|
||||
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
|
||||
f);
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename T, T... Is>
|
||||
struct applier
|
||||
{
|
||||
template <typename F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
// tweak -fbracket-depth if compilation fails. Clang default limit is 256
|
||||
(f(number<Is>{}), ...);
|
||||
}
|
||||
};
|
||||
|
||||
template <int32_t Size> // == sizeof...(Is)
|
||||
using make_applier = __make_integer_seq<applier, index_t, Size>;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <index_t N>
|
||||
struct static_for<0, N, 1> : detail::make_applier<N>
|
||||
{
|
||||
using detail::make_applier<N>::operator();
|
||||
};
|
||||
|
||||
struct identity
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept
|
||||
{
|
||||
return std::forward<T>(arg);
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
// RemainLengths: sequence<...>
|
||||
// Orders: sequence<...>
|
||||
template <class RemainLengths, class Orders>
|
||||
struct static_ford_impl
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_ford_impl()
|
||||
{
|
||||
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...>)
|
||||
// CurrentOrderedId: sequence<...>
|
||||
template <class F, class CurrentOrderedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
|
||||
{
|
||||
static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
|
||||
static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
|
||||
f, CurrentOrderedId::push_back(I));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_ford_impl<sequence<>, Orders>
|
||||
{
|
||||
// F signature: F(sequence<...>)
|
||||
// OrderedId: sequence<...>
|
||||
template <class F, class OrderedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
|
||||
{
|
||||
// retrive unordered Id
|
||||
f(OrderedId::reorder_old_to_new(Orders{}));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Lengths is sequence<...>, it is the length of each dimension for
|
||||
// N-dimensional loop
|
||||
// Orders is sequence<...>, it is the order of dimension in which static_ford
|
||||
// will loop over each
|
||||
// dimension
|
||||
template <class Lengths,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
|
||||
struct static_ford
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_ford()
|
||||
{
|
||||
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...> multi_id)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
|
||||
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename Indices>
|
||||
struct unpack_impl;
|
||||
|
||||
template <index_t... Is>
|
||||
struct unpack_impl<sequence<Is...>>
|
||||
{
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const
|
||||
{
|
||||
#if 0
|
||||
return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...);
|
||||
#else
|
||||
return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Seq0, typename Seq1>
|
||||
struct unpack2_impl;
|
||||
|
||||
// TODO: remove this, after properly implementing unpack that takes any number of containers
|
||||
template <index_t... Is, index_t... Js>
|
||||
struct unpack2_impl<sequence<Is...>, sequence<Js...>>
|
||||
{
|
||||
template <typename F, typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const
|
||||
{
|
||||
#if 0
|
||||
return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...,
|
||||
std::forward<Y>(y).at(number<Js>{})...);
|
||||
#else
|
||||
return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...,
|
||||
std::forward<Y>(y).template at<Js>()...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x)
|
||||
{
|
||||
using X_ = remove_reference_t<X>;
|
||||
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type>{}(
|
||||
std::forward<F>(f), std::forward<X>(x));
|
||||
}
|
||||
|
||||
// TODO: properly implement unpack that takes any number of containers
|
||||
template <typename F, typename X, typename Y>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
|
||||
{
|
||||
using X_ = remove_reference_t<X>;
|
||||
using Y_ = remove_reference_t<Y>;
|
||||
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type,
|
||||
typename arithmetic_sequence_gen<0, Y_::size(), 1>::type>{}(
|
||||
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
|
||||
}
|
||||
|
||||
// z = predicate ? x : y
// The choice is made at compile time, so x and y may have unrelated types.
// The result is returned by value (the return type is a plain `auto`).
template <bool predicate, typename X, typename Y>
constexpr auto conditional_expr(X&& x, Y&& y)
{
    if constexpr(!predicate)
    {
        return std::forward<Y>(y);
    }
    else
    {
        return std::forward<X>(x);
    }
}
|
||||
|
||||
} // namespace ck_tile
|
||||
173
include/ck_tile/core/utility/functional_with_tuple.hpp
Normal file
173
include/ck_tile/core/utility/functional_with_tuple.hpp
Normal file
@@ -0,0 +1,173 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// This file should not be included inside tuple.hpp!
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include <stdint.h>
|
||||
#include <utility>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// RemainLengths: sequence<...>
|
||||
// Orders: sequence<...>
|
||||
template <class RemainLengths, class RamainUnpacks, class Orders>
|
||||
struct static_uford_impl
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_uford_impl()
|
||||
{
|
||||
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
|
||||
static_assert(RamainUnpacks::size() > 0, "wrong! should not get here");
|
||||
}
|
||||
|
||||
template <class F, class CurrentUnpackIds>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds) const
|
||||
{
|
||||
constexpr index_t pack_len = RamainUnpacks::front();
|
||||
static_for<0, RemainLengths::front(), pack_len>{}([=](auto I) {
|
||||
constexpr auto new_pack = generate_tuple(
|
||||
[&](auto idx_) {
|
||||
constexpr auto i_new_pack = number<I + idx_ % pack_len>{};
|
||||
constexpr auto i_pre_pack = number<idx_ / pack_len>{};
|
||||
return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
|
||||
},
|
||||
number<CurrentUnpackIds::size() * pack_len>{});
|
||||
|
||||
static_uford_impl<decltype(RemainLengths::pop_front()),
|
||||
decltype(RamainUnpacks::pop_front()),
|
||||
Orders>{}(f, new_pack);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_uford_impl<sequence<>, sequence<>, Orders>
|
||||
{
|
||||
template <class F, class PackedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId) const
|
||||
{
|
||||
constexpr auto origin_packs = transform_tuples(
|
||||
[](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
|
||||
unpack(f, origin_packs);
|
||||
}
|
||||
};
|
||||
|
||||
template <class RemainLengths, class RamainUnpacks, class Orders>
|
||||
struct static_uford_one_shot_impl
|
||||
{
|
||||
template <class F, class CurrentUnpackIds, index_t current_acc>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number<current_acc>) const
|
||||
{
|
||||
constexpr auto r_lens_stride =
|
||||
reverse_exclusive_scan_sequence(RemainLengths{}, multiplies{}, number<1>{});
|
||||
constexpr auto r_upks_stride =
|
||||
reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies{}, number<1>{});
|
||||
|
||||
constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front();
|
||||
constexpr index_t pack_len = RamainUnpacks::front();
|
||||
constexpr index_t current_idx = (current_acc / current_stride) * pack_len;
|
||||
|
||||
constexpr auto new_pack = generate_tuple(
|
||||
[&](auto idx_) {
|
||||
constexpr auto i_new_pack = number<current_idx + idx_ % pack_len>{};
|
||||
constexpr auto i_pre_pack = number<idx_ / pack_len>{};
|
||||
return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
|
||||
},
|
||||
number<CurrentUnpackIds::size() * pack_len>{});
|
||||
|
||||
static_uford_one_shot_impl<decltype(RemainLengths::pop_front()),
|
||||
decltype(RamainUnpacks::pop_front()),
|
||||
Orders>{}(f, new_pack, number<current_acc % current_stride>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_uford_one_shot_impl<sequence<>, sequence<>, Orders>
|
||||
{
|
||||
template <class F, class PackedId, index_t current_acc>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId, number<current_acc>) const
|
||||
{
|
||||
constexpr auto origin_packs = transform_tuples(
|
||||
[](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
|
||||
unpack(f, origin_packs);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// TODO: we may unify static_ford/static_uford in the future
|
||||
//
|
||||
// loop over nd space(sequence) with packs
|
||||
// you must make sure the function passed in has same number of argument
|
||||
//
|
||||
// e.g.
|
||||
// Lengths=seq<2, 3, 4>, Unpacks=<1, 1, 2>
|
||||
// static_uford<Lengths, Unpacks>{}([&](auto i_0, auto i_1){}); // require 2 args(packs)
|
||||
//
|
||||
// loop #0, i_0=seq<0, 0, 0>, i_1=<0, 0, 1>
|
||||
// loop #1, i_0=seq<0, 0, 2>, i_1=<0, 0, 3>
|
||||
// loop #2, i_0=seq<0, 1, 0>, i_1=<0, 1, 1>
|
||||
// loop #3, i_0=seq<0, 1, 2>, i_1=<0, 1, 3>
|
||||
// loop #4, i_0=seq<0, 2, 0>, i_1=<0, 2, 1>
|
||||
// loop #5, i_0=seq<0, 2, 2>, i_1=<0, 2, 3>
|
||||
// loop #6, i_0=seq<1, 0, 0>, i_1=<1, 0, 1>
|
||||
// ...
|
||||
template <class Lengths,
|
||||
class Unpacks = typename uniform_sequence_gen<Lengths::size(), 1>::type,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
|
||||
struct static_uford
|
||||
{
|
||||
static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies{}, number<1>{});
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr static_uford()
|
||||
{
|
||||
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::size() == Unpacks::size(), "wrong! inconsistent size");
|
||||
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
|
||||
static_for<0, Lengths::size(), 1>{}(
|
||||
[&](auto i) { static_assert(Lengths{}.at(i) % Unpacks{}.at(i) == 0); });
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
|
||||
{
|
||||
using L_ = decltype(Lengths{} / Unpacks{});
|
||||
|
||||
return reduce_on_sequence(L_{}, multiplies{}, number<1>{});
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...> multi_id...)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
|
||||
constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
|
||||
detail::static_uford_impl<decltype(ordered_lengths), decltype(ordered_unpacks), Orders>{}(
|
||||
f, make_tuple(sequence<>{}));
|
||||
}
|
||||
|
||||
// this version is friendly for issue function one by one
|
||||
template <class F, index_t i_access>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, number<i_access>) const
|
||||
{
|
||||
static_assert(i_access < get_num_of_access());
|
||||
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
|
||||
constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
|
||||
detail::static_uford_one_shot_impl<decltype(ordered_lengths),
|
||||
decltype(ordered_unpacks),
|
||||
Orders>{}(
|
||||
f, make_tuple(sequence<>{}), number<i_access>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
22
include/ck_tile/core/utility/ignore.hpp
Normal file
22
include/ck_tile/core/utility/ignore.hpp
Normal file
@@ -0,0 +1,22 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// https://en.cppreference.com/w/cpp/utility/tuple/ignore
|
||||
|
||||
namespace ck_tile {

namespace detail {
// Type of `ck_tile::ignore`: anything can be assigned to it and the assignment
// does nothing. Unlike std::ignore, the assignment operator returns void.
struct ignore_t
{
    template <typename T>
    constexpr void operator=(T&&) const noexcept
    {
    }
};
} // namespace detail

// Placeholder object: `ignore = expr;` evaluates expr and discards the result.
inline constexpr detail::ignore_t ignore{};

} // namespace ck_tile
|
||||
22
include/ck_tile/core/utility/literals.hpp
Normal file
22
include/ck_tile/core/utility/literals.hpp
Normal file
@@ -0,0 +1,22 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
namespace ck_tile {
namespace literals {
// [P0330] Literal Suffix for (signed) size_t (C++23)
// ref: https://wg21.link/p0330r8

// `123_uz` yields the literal value converted to std::size_t.
inline constexpr std::size_t operator""_uz(unsigned long long value)
{
    return static_cast<std::size_t>(value);
}

// `123_zu` is an alternative spelling with the same result as `123_uz`.
inline constexpr std::size_t operator""_zu(unsigned long long value)
{
    return operator""_uz(value);
}
} // namespace literals
} // namespace ck_tile
|
||||
257
include/ck_tile/core/utility/magic_div.hpp
Normal file
257
include/ck_tile/core/utility/magic_div.hpp
Normal file
@@ -0,0 +1,257 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// magic number division
// Caution:
//   1. For uint32_t as dividend: the magic number division implementation being used produces
//   the correct result if the dividend is uint32_t and its value is within the 31-bit value range.
//   2. For int32_t as dividend: magic number division for an int32_t dividend has not been
//   implemented; the int32_t dividend is bit-wise interpreted as uint32_t and the magic number
//   division implementation for uint32_t is then used. Therefore, the dividend value needs to be
//   non-negative.
// TODO:
//   1. Implement magic number division for int32_t
//   2. Implement magic number division for uint32_t with 32-bit value range
|
||||
struct magic_division32_bit_range
|
||||
{
|
||||
// uint32_t
|
||||
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
|
||||
{
|
||||
// WARNING: magic division is only valid for division inside this range.
|
||||
// assert(divisor >= 1 && divisor <= INT32_MAX)
|
||||
|
||||
uint32_t shift_u32 = 0;
|
||||
|
||||
while((1U << shift_u32) < divisor)
|
||||
{
|
||||
shift_u32++;
|
||||
};
|
||||
|
||||
uint64_t tmp_u64 = static_cast<uint64_t>((1UL << shift_u32) - divisor) << 32;
|
||||
uint32_t multiplier_u32 = tmp_u64 / divisor + 1;
|
||||
|
||||
return make_tuple(multiplier_u32, shift_u32);
|
||||
}
|
||||
|
||||
template <auto Divisor, typename = std::enable_if_t<(0 < Divisor)>>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
|
||||
{
|
||||
constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
|
||||
|
||||
constexpr uint32_t multiplier = tmp[number<0>{}];
|
||||
constexpr uint32_t shift = tmp[number<1>{}];
|
||||
|
||||
return make_tuple(constant<multiplier>{}, constant<shift>{});
|
||||
}
|
||||
|
||||
// magic division for uint32_t
|
||||
CK_TILE_DEVICE static constexpr uint32_t
|
||||
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
if(__builtin_is_constant_evaluated())
|
||||
{
|
||||
uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t tmp = __umulhi(dividend, multiplier);
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr uint32_t
|
||||
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
|
||||
// magic division for int32_t
|
||||
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
|
||||
// non-negative for result to be correct
|
||||
// TODO: figure out how to do magic number divison for int32_t as dividended
|
||||
CK_TILE_DEVICE static constexpr int32_t
|
||||
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
if(__builtin_is_constant_evaluated())
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = __umulhi(dividend_u32, multiplier);
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr int32_t
|
||||
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
};
|
||||
|
||||
// magic number division
|
||||
// This version on works for divisor and dividended between [0, 1 << 16]
|
||||
struct magic_division16_bit_range
|
||||
{
|
||||
// uint32_t
|
||||
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
|
||||
{
|
||||
// WARNING: magic division is only valid for division inside this range.
|
||||
// assert(divisor >= 1 && divisor <= (1U << 16));
|
||||
|
||||
uint32_t shift_u32 = 0;
|
||||
|
||||
while((1U << shift_u32) < divisor)
|
||||
{
|
||||
shift_u32++;
|
||||
};
|
||||
|
||||
uint32_t one = 1;
|
||||
uint32_t multiplier_u32 = ((one << 16) * ((one << shift_u32) - divisor)) / divisor + 1;
|
||||
|
||||
return make_tuple(multiplier_u32, shift_u32);
|
||||
}
|
||||
|
||||
// integral_constant<uint32_t, .>
|
||||
template <auto Divisor>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
|
||||
{
|
||||
constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
|
||||
|
||||
constexpr uint32_t multiplier = tmp[number<0>{}];
|
||||
constexpr uint32_t shift = tmp[number<1>{}];
|
||||
|
||||
return make_tuple(constant<multiplier>{}, constant<shift>{});
|
||||
}
|
||||
|
||||
// magic division for uint32_t
|
||||
CK_TILE_DEVICE static constexpr uint32_t
|
||||
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t tmp = (dividend * multiplier) >> 16;
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr uint32_t
|
||||
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t tmp = (dividend * multiplier) >> 16;
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
|
||||
// magic division for int32_t
|
||||
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
|
||||
// non-negative for result to be correct
|
||||
// TODO: figure out how to do magic number divison for int32_t as dividended
|
||||
CK_TILE_DEVICE static constexpr int32_t
|
||||
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = (dividend_u32 * multiplier) >> 16;
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr int32_t
|
||||
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = (dividend_u32 * multiplier) >> 16;
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
};
|
||||
|
||||
// Default magic-division implementation: the 32-bit-range version.
// mdiv and mdiv2 route their magic-number computation and division through this alias.
|
||||
using magic_division = magic_division32_bit_range;
|
||||
|
||||
struct mdiv
|
||||
{
|
||||
// 1 dword -> 3 dword storage
|
||||
uint32_t divisor;
|
||||
uint32_t multiplier;
|
||||
uint32_t shift; // TODO: 8 bit is enough
|
||||
|
||||
// prefer construct on host
|
||||
CK_TILE_HOST_DEVICE mdiv(uint32_t divisor_) : divisor(divisor_)
|
||||
{
|
||||
auto tmp = magic_division::calculate_magic_numbers(divisor_);
|
||||
|
||||
multiplier = tmp[number<0>{}];
|
||||
shift = tmp[number<1>{}];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE mdiv() : divisor(0), multiplier(0), shift(0) {}
|
||||
|
||||
CK_TILE_HOST_DEVICE void update(uint32_t divisor_)
|
||||
{
|
||||
divisor = divisor_;
|
||||
auto tmp = magic_division::calculate_magic_numbers(divisor_);
|
||||
|
||||
multiplier = tmp[number<0>{}];
|
||||
shift = tmp[number<1>{}];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
|
||||
{
|
||||
return magic_division::do_magic_division(dividend_, multiplier, shift);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void
|
||||
divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
|
||||
{
|
||||
quotient_ = div(dividend_);
|
||||
remainder_ = dividend_ - (quotient_ * divisor);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint32_t get() const { return divisor; }
|
||||
};
|
||||
|
||||
struct mdiv2
|
||||
{
|
||||
// 1 dword -> 2 dword storage, divisor need compute from runtime
|
||||
uint32_t multiplier;
|
||||
uint32_t shift; // TODO: 8 bit is enough
|
||||
|
||||
// prefer construct on host
|
||||
CK_TILE_HOST_DEVICE mdiv2(uint32_t divisor_)
|
||||
{
|
||||
auto tmp = magic_division::calculate_magic_numbers(divisor_);
|
||||
|
||||
multiplier = tmp[number<0>{}];
|
||||
shift = tmp[number<1>{}];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE mdiv2() : multiplier(0), shift(0) {}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
|
||||
{
|
||||
return magic_division::do_magic_division(dividend_, multiplier, shift);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void
|
||||
divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
|
||||
{
|
||||
quotient_ = div(dividend_);
|
||||
remainder_ = dividend_ - (quotient_ * divisor_);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
122
include/ck_tile/core/utility/philox_rand.hpp
Normal file
122
include/ck_tile/core/utility/philox_rand.hpp
Normal file
@@ -0,0 +1,122 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh
|
||||
class philox
|
||||
{
|
||||
public:
|
||||
CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_)
|
||||
: seed(reinterpret_cast<const uint2&>(seed_))
|
||||
{
|
||||
|
||||
ull2* tmp = reinterpret_cast<ull2*>(&counter);
|
||||
tmp->x = offset_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const
|
||||
{
|
||||
|
||||
uint4 counter_ = counter;
|
||||
ull2* tmp = reinterpret_cast<ull2*>(&counter_);
|
||||
tmp->y = subsequence;
|
||||
|
||||
uint2 key_ = seed;
|
||||
// 7-round philox
|
||||
#pragma unroll
|
||||
for(int i = 0; i < 6; i++)
|
||||
{
|
||||
counter_ = philox_single_round(counter_, key_);
|
||||
key_.x += kPhilox10A;
|
||||
key_.y += kPhilox10B;
|
||||
}
|
||||
uint4 output = philox_single_round(counter_, key_);
|
||||
return output;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t* out,
|
||||
const unsigned long long subsequence) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
|
||||
out_tmp[0] = tmp_ph.x;
|
||||
out_tmp[1] = tmp_ph.y;
|
||||
out_tmp[2] = tmp_ph.z;
|
||||
out_tmp[3] = tmp_ph.w;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
|
||||
const unsigned long long subsequence,
|
||||
const index_t start_idx) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
uint32x4_t tmp;
|
||||
tmp[0] = tmp_ph.x;
|
||||
tmp[1] = tmp_ph.y;
|
||||
tmp[2] = tmp_ph.z;
|
||||
tmp[3] = tmp_ph.w;
|
||||
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
out_tmp[0] = tmp[start_idx];
|
||||
out_tmp[1] = tmp[start_idx + 2];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
|
||||
const unsigned long long subsequence,
|
||||
const index_t start_idx) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
tmp_ph = get_philox_4x32(subsequence);
|
||||
|
||||
uint32x4_t tmp;
|
||||
tmp[0] = tmp_ph.x;
|
||||
tmp[1] = tmp_ph.y;
|
||||
tmp[2] = tmp_ph.z;
|
||||
tmp[3] = tmp_ph.w;
|
||||
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
|
||||
out_tmp[0] = tmp[start_idx];
|
||||
}
|
||||
|
||||
private:
|
||||
struct ull2
|
||||
{
|
||||
uint64_t x;
|
||||
uint64_t y;
|
||||
};
|
||||
uint4 counter;
|
||||
const uint2 seed;
|
||||
|
||||
CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const
|
||||
{
|
||||
uint2* res;
|
||||
unsigned long long tmp;
|
||||
tmp = static_cast<unsigned long long>(a) * b;
|
||||
res = reinterpret_cast<uint2*>(&tmp);
|
||||
return *res;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const
|
||||
{
|
||||
|
||||
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
|
||||
uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
|
||||
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
|
||||
return ret;
|
||||
}
|
||||
|
||||
static const unsigned long kPhilox10A = 0x9E3779B9;
|
||||
static const unsigned long kPhilox10B = 0xBB67AE85;
|
||||
static const unsigned long kPhiloxSA = 0xD2511F53;
|
||||
static const unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
58
include/ck_tile/core/utility/random.hpp
Normal file
58
include/ck_tile/core/utility/random.hpp
Normal file
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include <stdint.h>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// return 0 if data is not fp16 or fp32
|
||||
template <typename T, uint32_t seed_>
|
||||
struct prand_generator_t
|
||||
{
|
||||
CK_TILE_HOST_DEVICE uint32_t operator()(int, T, uint32_t = seed_) { return 0; }
|
||||
};
|
||||
|
||||
// version for fp32
|
||||
template <uint32_t seed_>
|
||||
struct prand_generator_t<float, seed_>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_)
|
||||
{
|
||||
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
|
||||
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
|
||||
drop_bits ^= x >> 16;
|
||||
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
|
||||
drop_bits *= 0x7000149;
|
||||
// NOTE: If id is in 64 bit, we are only using lower 32 bit.
|
||||
// So, it can have an effect of using same id for multiple elements when the id is
|
||||
// very large!
|
||||
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
|
||||
return rng;
|
||||
}
|
||||
};
|
||||
|
||||
// version for fp16
|
||||
template <uint32_t seed_>
|
||||
struct prand_generator_t<half_t, seed_>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_)
|
||||
{
|
||||
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
|
||||
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
|
||||
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
|
||||
drop_bits *= 0x7000149;
|
||||
// NOTE: If id is in 64 bit, we are only using lower 32 bit.
|
||||
// So, it can have an effect of using same id for multiple elements when the id is
|
||||
// very large!
|
||||
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
|
||||
return rng;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
95
include/ck_tile/core/utility/reduce_operator.hpp
Normal file
95
include/ck_tile/core/utility/reduce_operator.hpp
Normal file
@@ -0,0 +1,95 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace ReduceOp {
|
||||
// y = ReduceOp(y, x);
|
||||
struct Add
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return type_convert<T>(0.0f);
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return y + x;
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
|
||||
{
|
||||
float y_ = type_convert<float>(y);
|
||||
float x_ = type_convert<float>(x);
|
||||
|
||||
return type_convert<T>(y_ + x_);
|
||||
}
|
||||
};
|
||||
|
||||
struct SquareAdd
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return type_convert<T>(0.0f);
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return y + (x * x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Max
|
||||
{
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::min();
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, x);
|
||||
}
|
||||
};
|
||||
|
||||
struct AbsMax
|
||||
{
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::min();
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, abs(x));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ReduceOp
|
||||
} // namespace ck_tile
|
||||
116
include/ck_tile/core/utility/static_counter.hpp
Normal file
116
include/ck_tile/core/utility/static_counter.hpp
Normal file
@@ -0,0 +1,116 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Context, index_t Start = 0, index_t Step = 1>
|
||||
struct static_counter
|
||||
{
|
||||
public:
|
||||
template <typename Unique>
|
||||
static constexpr index_t next()
|
||||
{
|
||||
return next<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
template <unsigned long long>
|
||||
static constexpr index_t next()
|
||||
{
|
||||
struct Unique
|
||||
{
|
||||
};
|
||||
return next<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
template <typename Unique>
|
||||
static constexpr index_t current()
|
||||
{
|
||||
return current<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
template <unsigned long long>
|
||||
static constexpr index_t current()
|
||||
{
|
||||
struct Unique
|
||||
{
|
||||
};
|
||||
return current<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
private:
|
||||
template <index_t I>
|
||||
struct slot
|
||||
{
|
||||
_Pragma("GCC diagnostic push");
|
||||
_Pragma("GCC diagnostic ignored \"-Wundefined-internal\"");
|
||||
friend constexpr bool slot_allocated(slot<I>);
|
||||
_Pragma("GCC diagnostic pop");
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
struct allocate_slot
|
||||
{
|
||||
friend constexpr bool slot_allocated(slot<I>) { return true; }
|
||||
enum
|
||||
{
|
||||
value = I
|
||||
};
|
||||
};
|
||||
|
||||
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
|
||||
// the overload set...
|
||||
template <typename Unique, index_t I = 0, bool = slot_allocated(slot<I>())>
|
||||
static constexpr index_t next(index_t)
|
||||
{
|
||||
return next<Unique, I + 1>(0);
|
||||
}
|
||||
|
||||
// ...And this function will be used, instead, which will define slot_allocated(slot<I>) via
|
||||
// allocate_slot<I>.
|
||||
template <typename Unique, index_t I = 0>
|
||||
static constexpr index_t next(double)
|
||||
{
|
||||
return allocate_slot<I>::value;
|
||||
}
|
||||
|
||||
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
|
||||
// the overload set...
|
||||
template <typename Unique, index_t I = Start, bool = slot_allocated(slot<I>())>
|
||||
static constexpr index_t current(index_t)
|
||||
{
|
||||
return current<Unique, I + 1>(0);
|
||||
}
|
||||
|
||||
// ...And this function will be used, instead, which will return the current counter, or assert
|
||||
// in case next() hasn't been called yet.
|
||||
template <typename Unique, index_t I = Start>
|
||||
static constexpr index_t current(double)
|
||||
{
|
||||
static_assert(I != 0, "You must invoke next() first");
|
||||
|
||||
return I - 1;
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
// Tag template that is only declared here; each distinct integer gives a
// distinct type, which MAKE_SC() uses as the static_counter Context.
template <int Id>
struct static_counter_uniq_;
} // namespace impl
|
||||
|
||||
// Create a counter object with a unique Context (one per expansion).
#define MAKE_SC() \
    ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}
// Same, with an explicit start value and step.
#define MAKE_SC_WITH(start_, step_) \
    ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, (start_), (step_)> {}
// Advance counter c_ and yield its new value. The macro arguments are
// parenthesised so that arbitrary expressions can be passed safely.
#define NEXT_SC(c_) (c_).next<__COUNTER__>()
#define NEXT_SCI(c_, static_i_) (c_).next<__COUNTER__ + (static_i_)>()

// Usage:
// constexpr auto c = MAKE_SC();
// NEXT_SC(c) // -> constexpr 0
// NEXT_SC(c) // -> constexpr 1
// NEXT_SC(c) // -> constexpr 2
||||
} // namespace ck_tile
|
||||
73
include/ck_tile/core/utility/to_sequence.hpp
Normal file
73
include/ck_tile/core/utility/to_sequence.hpp
Normal file
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
// TODO: use c++20 nontype template with struct to implement this
|
||||
|
||||
#if 1
|
||||
// Convert the first n elements of a constexpr container `a` into
// ck_tile::sequence<a.at(0), ..., a.at(n-1)>.
// Uses a lambda with an explicit template parameter list, which clang happens
// to support (__cpp_generic_lambdas >= 201707) in c++17 mode; the matching
// -Wc++20-extensions warning is silenced around the expansion.
// NOTE: the expansion ends with "; _Pragma(...)", so the macro is meant to be
// used as the initializer of a declaration, not inside a larger expression.
#define TO_SEQUENCE(a, n)                                                       \
    _Pragma("clang diagnostic push")                                            \
    _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"")                  \
    [a]<ck_tile::index_t... TO_SEQ_IDX_>(ck_tile::sequence<TO_SEQ_IDX_...>)     \
    {                                                                           \
        return ck_tile::sequence<a.at(ck_tile::number<TO_SEQ_IDX_>{})...>{};    \
    }                                                                           \
    (ck_tile::make_index_sequence<n>{});                                        \
    _Pragma("clang diagnostic pop")
|
||||
|
||||
#else
|
||||
// Fallback macro (currently disabled by the surrounding #if 1).
// Converts a constexpr array to a sequence; both a and n must be constexpr
// variables (an rvalue such as 2 cannot be captured). Supports n <= 10.
#define TO_SEQUENCE(a, n)                                                          \
    [a, n] {                                                                       \
        static_assert(a.size() >= n, "wrong! out of bound");                       \
        static_assert(n <= 10, "not implemented");                                 \
        if constexpr(n == 0)                                                       \
        {                                                                          \
            return ck_tile::sequence<>{};                                          \
        }                                                                          \
        else if constexpr(n == 1)                                                  \
        {                                                                          \
            return ck_tile::sequence<a[0]>{};                                      \
        }                                                                          \
        else if constexpr(n == 2)                                                  \
        {                                                                          \
            return ck_tile::sequence<a[0], a[1]>{};                                \
        }                                                                          \
        else if constexpr(n == 3)                                                  \
        {                                                                          \
            return ck_tile::sequence<a[0], a[1], a[2]>{};                          \
        }                                                                          \
        else if constexpr(n == 4)                                                  \
        {                                                                          \
            return ck_tile::sequence<a[0], a[1], a[2], a[3]>{};                    \
        }                                                                          \
        else if constexpr(n == 5)                                                  \
        {                                                                          \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4]>{};              \
        }                                                                          \
        else if constexpr(n == 6)                                                  \
        {                                                                          \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5]>{};        \
        }                                                                          \
        else if constexpr(n == 7)                                                  \
        {                                                                          \
            return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6]>{};  \
        }                                                                          \
        else if constexpr(n == 8)                                                  \
        {                                                                          \
            return ck_tile::                                                       \
                sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7]>{};        \
        }                                                                          \
        else if constexpr(n == 9)                                                  \
        {                                                                          \
            return ck_tile::                                                       \
                sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]>{};  \
        }                                                                          \
        else if constexpr(n == 10)                                                 \
        {                                                                          \
            return ck_tile::sequence<a[0],                                         \
                                     a[1],                                         \
                                     a[2],                                         \
                                     a[3],                                         \
                                     a[4],                                         \
                                     a[5],                                         \
                                     a[6],                                         \
                                     a[7],                                         \
                                     a[8],                                         \
                                     a[9]>{};                                      \
        }                                                                          \
    }()
|
||||
#endif
|
||||
155
include/ck_tile/core/utility/transpose_vectors.hpp
Normal file
155
include/ck_tile/core/utility/transpose_vectors.hpp
Normal file
@@ -0,0 +1,155 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// S: scalar type (or it can be non-scalar type)
|
||||
// NX: # of vector before transpose
|
||||
// NY: # of vector after transpose
|
||||
// we got [NX, NY] amount of S data to be transposed into [NY, NX] amount of S data
|
||||
template <typename S_, index_t NX, index_t NY>
|
||||
struct transpose_vectors
|
||||
{
|
||||
static constexpr index_t s_per_x = NY;
|
||||
static constexpr index_t s_per_y = NX;
|
||||
|
||||
using S = remove_cvref_t<S_>;
|
||||
|
||||
using VX = array<S, s_per_x>;
|
||||
using VY = array<S, s_per_y>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const thread_buffer<VX, NX>& vx_tuple,
|
||||
thread_buffer<VY, NY>& vy_tuple)
|
||||
{
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
constexpr auto I3 = number<3>{};
|
||||
constexpr auto I4 = number<4>{};
|
||||
|
||||
if constexpr(sizeof(S) == 2)
|
||||
{
|
||||
static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
|
||||
|
||||
using S2 = array<S, 2>; // typename array<S, 2>::type;
|
||||
|
||||
// loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
|
||||
static_for<0, NY, 2>{}([&](auto iy) {
|
||||
static_for<0, NX, 2>{}([&](auto ix) {
|
||||
// 2 16bitx2 data from vx_tuple to be transposed
|
||||
const int32_t x_s2_0 =
|
||||
bit_cast<int32_t>(vx_tuple[ix].template get_as<S2>()[iy / I2]);
|
||||
const int32_t x_s2_1 =
|
||||
bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S2>()[iy / I2]);
|
||||
|
||||
constexpr int32_t m0 = 0x05040100;
|
||||
constexpr int32_t m1 = 0x07060302;
|
||||
|
||||
// transpose 2x2 16bit
|
||||
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
|
||||
// -- -- -- -- -- -- -- -- - - - -
|
||||
// index 7 6 5 4 3 2 1 0 33 77 44 88
|
||||
// index is reversed because of little endianness (least significant bits first)
|
||||
const int32_t y_s2_0 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m0);
|
||||
const int32_t y_s2_1 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m1);
|
||||
|
||||
// 2 16bitx2 data after transposed
|
||||
vy_tuple(iy).template get_as<S2>()(ix / I2) = bit_cast<S2>(y_s2_0);
|
||||
vy_tuple(iy + I1).template get_as<S2>()(ix / I2) = bit_cast<S2>(y_s2_1);
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(sizeof(S) == 1)
|
||||
{
|
||||
static_assert(((NX % 4 == 0 && NY % 4 == 0) || (NX % 2 == 0 && NY % 2 == 0)), "wrong!");
|
||||
|
||||
using S4 = array<S, 4>; // typename array<S, 4>::type;
|
||||
using S2 = array<S, 2>; // typename array<S, 4>::type;
|
||||
|
||||
if constexpr(NX % 4 == 0 && NY % 4 == 0)
|
||||
{
|
||||
// loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
|
||||
static_for<0, NY, 4>{}([&](auto iy) {
|
||||
static_for<0, NX, 4>{}([&](auto ix) {
|
||||
// 4 int8x4 data from vx_tuple
|
||||
const int32_t x_s4_0 =
|
||||
bit_cast<int32_t>(vx_tuple[ix].template get_as<S4>()[iy / I4]);
|
||||
const int32_t x_s4_1 =
|
||||
bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S4>()[iy / I4]);
|
||||
const int32_t x_s4_2 =
|
||||
bit_cast<int32_t>(vx_tuple[ix + I2].template get_as<S4>()[iy / I4]);
|
||||
const int32_t x_s4_3 =
|
||||
bit_cast<int32_t>(vx_tuple[ix + I3].template get_as<S4>()[iy / I4]);
|
||||
|
||||
// transpose
|
||||
int32_t t_s4_0, t_s4_1;
|
||||
int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3;
|
||||
|
||||
constexpr int32_t m0 = 0x05010400;
|
||||
constexpr int32_t m1 = 0x05040100;
|
||||
constexpr int32_t m2 = 0x07060302;
|
||||
constexpr int32_t m3 = 0x07030602;
|
||||
|
||||
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) ->
|
||||
// 0x33774488
|
||||
// -- -- -- -- -- -- -- -- - - - -
|
||||
// index 7 6 5 4 3 2 1 0 33 77 44 88
|
||||
// index is reversed because of little endianness (least significant bits
|
||||
// first)
|
||||
t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0);
|
||||
t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0);
|
||||
y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
|
||||
y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
|
||||
t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3);
|
||||
t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3);
|
||||
y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
|
||||
y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
|
||||
|
||||
// 4 int8x4 data from vy_tuple
|
||||
vy_tuple(iy).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_0);
|
||||
vy_tuple(iy + I1).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_1);
|
||||
vy_tuple(iy + I2).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_2);
|
||||
vy_tuple(iy + I3).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_3);
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(NX % 2 == 0 && NY % 2 == 0)
|
||||
{
|
||||
static_for<0, NY, 2>{}([&](auto ix) {
|
||||
static_for<0, NX, 2>{}([&](auto iy) {
|
||||
const int16_t x_s2_0 =
|
||||
bit_cast<int16_t>(vx_tuple[ix].template get_as<S2>()[iy / I2]);
|
||||
const int16_t x_s2_1 =
|
||||
bit_cast<int16_t>(vx_tuple[ix + I1].template get_as<S2>()[iy / I2]);
|
||||
constexpr int32_t m0 = 0x05040100;
|
||||
constexpr int32_t m1 = 0x07060302;
|
||||
|
||||
const int32_t x0_32 = static_cast<int32_t>(x_s2_0 & 0xFFFF);
|
||||
const int32_t x1_32 = static_cast<int32_t>(x_s2_1 & 0xFFFF);
|
||||
|
||||
const int32_t y_s2_0 = __builtin_amdgcn_perm(x1_32, x0_32, m0);
|
||||
const int32_t y_s2_1 = __builtin_amdgcn_perm(x1_32, x0_32, m1);
|
||||
|
||||
vy_tuple(iy).template get_as<S2>()[ix / I2] =
|
||||
bit_cast<S2>(static_cast<int16_t>(y_s2_0 & 0xFFFF));
|
||||
vy_tuple(iy + I1).template get_as<S2>()[ix / I2] =
|
||||
bit_cast<S2>(static_cast<int16_t>(y_s2_1 & 0xFFFF));
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "not implemented");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
130
include/ck_tile/core/utility/type_traits.hpp
Normal file
130
include/ck_tile/core/utility/type_traits.hpp
Normal file
@@ -0,0 +1,130 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Short aliases for the standard type transformations.
template <typename T>
using remove_reference_t = std::remove_reference_t<T>;

template <typename T>
using remove_cv_t = std::remove_cv_t<T>;

// Strips the reference first, then top-level const/volatile.
template <typename T>
using remove_cvref_t = typename std::remove_cv<typename std::remove_reference<T>::type>::type;

template <typename T>
using remove_pointer_t = std::remove_pointer_t<T>;
|
||||
|
||||
// copy_const<From, To>::type is To, with const added when From is const.
// Primary template: From is not const, so To is passed through unchanged.
template <typename From, typename To>
struct copy_const
{
    static_assert(!std::is_const_v<From>);

    using type = To;
};

// From is const: the result is const-qualified.
template <typename From, typename To>
struct copy_const<const From, To>
{
    using type = std::add_const_t<To>;
};

template <typename From, typename To>
using copy_const_t = typename copy_const<From, To>::type;
|
||||
|
||||
namespace detail {
// Detection idiom. Primary template: chosen when Op<Args...> is ill-formed;
// reports false and exposes the fallback type.
template <class Fallback, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
    using value_t = std::false_type;
    using type    = Fallback;
};

// Chosen when Op<Args...> is well-formed; reports true and exposes Op<Args...>.
template <class Fallback, template <class...> class Op, class... Args>
struct detector<Fallback, std::void_t<Op<Args...>>, Op, Args...>
{
    using value_t = std::true_type;
    using type    = Op<Args...>;
};
} // namespace detail
|
||||
|
||||
// Placeholder type that cannot be copied, assigned or destroyed; used as the
// fallback type of the detection idiom.
struct nonesuch
{
    nonesuch(nonesuch const&) = delete;
    void operator=(nonesuch const&) = delete;
    ~nonesuch() = delete;
};
|
||||
|
||||
template <template <class...> class Op, class... Args>
|
||||
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
|
||||
|
||||
namespace impl {
|
||||
|
||||
template <typename T>
|
||||
using has_is_static = decltype(T::is_static());
|
||||
|
||||
template <typename T>
|
||||
struct is_static_impl
|
||||
{
|
||||
static constexpr bool value = []() {
|
||||
if constexpr(is_detected<has_is_static, T>{})
|
||||
return T::is_static();
|
||||
else
|
||||
return std::is_arithmetic<T>::value;
|
||||
}();
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename T>
|
||||
using is_static = impl::is_static_impl<remove_cvref_t<T>>;
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool is_static_v = is_static<T>::value;
|
||||
|
||||
// TODO: deprecate this
|
||||
template <typename T>
|
||||
using is_known_at_compile_time = is_static<T>;
|
||||
// TODO: if evaluating a rvalue, e.g. a const integer
|
||||
// , this helper will also return false, which is not good(?)
|
||||
// do we need something like is_constexpr()?
|
||||
|
||||
// FIXME: do we need this anymore?
|
||||
template <
|
||||
typename PY,
|
||||
typename PX,
|
||||
typename std::enable_if<std::is_pointer_v<PY> && std::is_pointer_v<PX>, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
|
||||
{
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
#pragma clang diagnostic ignored "-Wcast-align"
|
||||
return (PY)p_x; // NOLINT(old-style-cast, cast-align)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
// is_any_of<T, Ts...>::value is true when T is the same type as at least one of
// Ts..., and false otherwise (including when Ts... is empty).
template <typename CompareTo, typename... Rest>
struct is_any_of : std::bool_constant<(std::is_same_v<CompareTo, Rest> || ...)>
{
};
|
||||
|
||||
} // namespace ck_tile
|
||||
69
include/ck_tile/core/utility/unary_element_function.hpp
Normal file
69
include/ck_tile/core/utility/unary_element_function.hpp
Normal file
@@ -0,0 +1,69 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename F, typename... Fs>
|
||||
struct composes : private composes<F>
|
||||
{
|
||||
template <typename FirstArg, typename... RestArgs>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit composes(FirstArg&& firstArg, RestArgs&&... restArgs)
|
||||
: composes<F>(std::forward<FirstArg>(firstArg)), inner_(std::forward<RestArgs>(restArgs)...)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Arg>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(Arg&& arg) const
|
||||
{
|
||||
return static_cast<const composes<F>&>(*this)(inner_(std::forward<Arg>(arg)));
|
||||
}
|
||||
|
||||
private:
|
||||
composes<Fs...> inner_;
|
||||
};
|
||||
|
||||
template <typename F>
|
||||
struct composes<F>
|
||||
{
|
||||
static_assert(!std::is_reference_v<F>);
|
||||
|
||||
template <typename Arg, typename = std::enable_if_t<std::is_constructible_v<F, Arg>>>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit composes(Arg&& arg) : f_(std::forward<Arg>(arg))
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Arg,
|
||||
typename = std::enable_if_t<std::is_invocable_v<std::add_const_t<F>&, Arg>>>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(Arg&& arg) const
|
||||
{
|
||||
return f_(std::forward<Arg>(arg));
|
||||
}
|
||||
|
||||
private:
|
||||
F f_;
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
template <typename... Ts>
|
||||
__host__ __device__ composes(Ts&&...)->composes<remove_cvref_t<Ts>...>;
|
||||
|
||||
template <typename SaturateType>
|
||||
struct saturates
|
||||
{
|
||||
// NOTE: this function does not return SaturateType value
|
||||
// it is user's responsiblity to do further cast or not
|
||||
template <typename AccType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const AccType& a_) const
|
||||
-> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
|
||||
{
|
||||
return clamp(a_,
|
||||
type_convert<AccType>(numeric<SaturateType>::lowest()),
|
||||
type_convert<AccType>(numeric<SaturateType>::max()));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user